import datetime
from django.shortcuts import render
from collections import defaultdict
# Create your views here.
from rest_framework import status
from rest_framework.generics import ListAPIView, CreateAPIView
from rest_framework.response import Response
from rest_framework.views import APIView
from participant.api import ParticipantApi
from .utils import response_standard
from .api import QuestionSurveyApi, ObservationApi, QuestionApi, AllowedResponseApi
from .serializers import SurveyObservationSerializer, ObservationSerializer, AllowedResponseSerializer, SurveyQuestionSerializer

from dateutil.relativedelta import relativedelta
from .api import QuestionApi
import dateutil.parser
from collections import OrderedDict

from django.db.models import Count
from itertools import groupby
from models import Observation

import logging
logger = logging.getLogger(__name__)



class SurveyObservationView(CreateAPIView, ListAPIView):
    """List and record a participant's observations for surveys.

    GET returns the requesting participant's observations, with child
    observations nested under their parent. POST stores a batch of
    observations for the survey named by the ``survey_id`` URL kwarg.
    """
    serializer_class = ObservationSerializer
    profile_surveys = [1330, 1331, 1332, 18, 20, 21, 22]
    editable_surveys = [15]
    # Profile survey id -> QuestionSurvey id whose question holds the
    # participant-specific AllowedResponse rows rebuilt on every POST.
    profile_med_question_surveys = {
        1330: 15,
        1331: 401,
        1332: 400,
        20: 12,
        18: 34,
        21: 14,
        22: 35,
    }
    # Profile surveys whose (first) observation packs several values into
    # ``str_value`` separated by ";;".
    split_value_surveys = [20, 18, 21, 22]

    def get(self, request, *args, **kwargs):
        """Return the participant's observations.

        With one or more ``survey_id`` query parameters, the payload is a list
        of ``{"survey": <id>, "observations": [...]}`` entries, one per
        requested survey. Without them, the ``survey_id`` URL kwarg selects a
        single survey and the payload is ``{"observations": [...]}``.

        NOTE(review): when pagination is enabled, ``_get`` returns the
        paginated ``Response`` object itself, which then gets nested inside
        another Response -- confirm pagination is off for this view.
        """
        def _get(survey, participant):
            queryset = self.filter_queryset(
                self.get_queryset(survey, participant, *args, **kwargs))
            page = self.paginate_queryset(queryset)
            if page is not None:
                serializer = self.get_serializer(page, many=True)
                return self.get_paginated_response(serializer.data)
            serializer = self.get_serializer(queryset, many=True)
            return self._nest_children(serializer.data)

        survey_ids = request.GET.getlist('survey_id')
        # QueryDict.getlist() returns [] (never None) for a missing key, so
        # test truthiness; a `!= None` test made the URL-kwarg path
        # unreachable. Routes without a survey_id kwarg keep the list shape.
        if survey_ids or "survey_id" not in kwargs:
            observations = [
                {
                    "survey": survey,
                    "observations": _get(
                        survey, request.participant_token.participant.id),
                }
                for survey in survey_ids
            ]
            return Response(response_standard(observations))
        return Response(response_standard({"observations": _get(
            kwargs["survey_id"],
            request.participant_token.participant.id)}))

    @staticmethod
    def _nest_children(rows):
        """Attach child observations to their parent under ``"children"``.

        Returns the top-level rows (those whose ``parent`` is missing or None)
        in the order they were serialized. Children whose parent is not among
        ``rows`` are dropped.
        """
        # TODO: this can be optimized by using LEFT JOIN on self/raw query
        children = defaultdict(list)  # parent id -> child rows
        parents = OrderedDict()  # id -> top-level row, keeps queryset order
        for obs in rows:
            parent_id = obs.get("parent")
            if parent_id is not None:
                children[parent_id].append(obs)
            else:
                parents[obs["id"]] = obs
        for parent_id, kids in children.items():
            if parent_id in parents:
                parents[parent_id]["children"] = kids
        return list(parents.values())

    def post(self, request, *args, **kwargs):
        """Store a batch of observations for ``kwargs["survey_id"]``.

        Request body keys read here: ``observations`` (list, each item may
        carry a one-level ``children`` list), optional ``entry_date`` and
        optional ``survey_type``.

        Returns 201 ``"ok"`` on success, or 400 with serializer errors for the
        first invalid observation or child.
        NOTE(review): nothing here runs inside a transaction, so deletions and
        observations created before an invalid item are kept.
        """
        survey_id = kwargs["survey_id"]
        is_survey_profile = False
        data = request.data
        participant = request.participant_token.participant
        logger.debug("survey_id=%s participant=%s", survey_id, participant)
        observations = data.get("observations", [])

        # Survey 15000: flag the participant's existing answers to each
        # submitted question_survey as edited.
        if int(survey_id) == 15000:
            for observation in observations:
                previous = ObservationApi()._filter(
                    question_survey=observation["question_survey"],
                    participant=participant.id)
                for ob in previous:
                    ob.is_edited = True
                    ob.save()

        # Profile surveys are replaced wholesale: remove prior observations and
        # the participant-specific allowed responses derived from them.
        if int(survey_id) in self.profile_surveys:
            old_observations = ObservationApi()._filter(
                survey_id_list=[survey_id], participant=participant.id)
            for ob in old_observations:
                # ob.audit_log.disable_tracking()
                ob.delete()
            med_question_survey = self._profile_question_survey(int(survey_id))
            if med_question_survey is not None:
                AllowedResponseApi()._delete_by_participant(
                    med_question_survey.question, participant)

        if 'entry_date' in data:
            self._mark_previous_entries_edited(data, survey_id, participant)

        profile_observations = defaultdict(list)  # survey id -> [Observation]
        for observ in observations:
            question_survey = observ.get("question_survey", None)
            sos = SurveyObservationSerializer(data=observ)
            if not sos.is_valid():
                return self._invalid_observation_response(sos)

            # TODO: only ONE LEVEL of children is supported; multi-level
            # children need a recursive implementation.
            validated_data = self.valid_observation(sos.data, question_survey, participant)
            if 'entry_date' in data:
                validated_data['entry_date'] = data['entry_date']
            obs_obj = self.create_observation(validated_data, participant, is_survey_profile)

            for child_obs in observ.get("children", []):
                child_sos = SurveyObservationSerializer(data=child_obs)
                if not child_sos.is_valid():
                    # Report the child's own errors, not the parent's.
                    return self._invalid_observation_response(child_sos)
                child_validated_data = self.valid_observation(
                    survey_ques=child_sos.data,
                    question_survey_id=child_obs["question_survey"],
                    participant=participant)
                if 'entry_date' in data:
                    child_validated_data['entry_date'] = data['entry_date']
                child_obs_obj = self.create_observation(
                    child_validated_data, participant, is_survey_profile)
                child_obs_obj.parent = obs_obj
                child_obs_obj.save()

            obs_survey_id = obs_obj.question_survey.survey.id
            if obs_survey_id in self.profile_surveys:
                profile_observations[obs_survey_id].append(obs_obj)

        # Rebuild the participant-specific allowed responses from the new
        # profile answers.
        for profile_survey_id, items in profile_observations.items():
            med_question_survey = self._profile_question_survey(profile_survey_id)
            if med_question_survey is None:
                continue
            if profile_survey_id in self.split_value_surveys:
                # Only the first observation is read; guard a None str_value.
                values = (items[0].str_value or "").split(";;")
            else:
                values = [obs.str_value for obs in items]
            order = 0
            for value in values:
                if value == "" or value is None:
                    continue
                AllowedResponseApi()._create(
                    med_question_survey.question,
                    value, order, None, value, None, False, participant)
                order += 1

        return Response("ok", status=status.HTTP_201_CREATED)

    def _profile_question_survey(self, survey_id):
        """Return the QuestionSurvey mapped to profile ``survey_id`` (an int), or None."""
        question_survey_id = self.profile_med_question_surveys.get(survey_id)
        if question_survey_id is None:
            return None
        return QuestionSurveyApi()._get(question_survey_id)

    def _mark_previous_entries_edited(self, data, survey_id, participant):
        """Flag the participant's earlier observations for this entry as edited.

        By default only observations whose ``entry_date`` equals
        ``data['entry_date']`` match. For ``survey_type == 'Day'`` every
        observation on the same calendar day matches. If that day cannot be
        parsed, the exact-match filter is used instead of crashing.
        """
        old = ObservationApi()._filter(survey_id_list=[survey_id],
                                       entry_date=data['entry_date'],
                                       participant=participant.id)
        if data.get('survey_type') == 'Day':
            entry_day = self._parse_entry_day(data['entry_date'])
            if entry_day is None:
                logger.debug(data)
            else:
                logger.debug("SurveyObservationView ValueError fix:")
                old = Observation.objects.filter(
                    question_survey__survey__id__in=[survey_id],
                    entry_date__startswith=entry_day,
                    participant_id=participant.id)
        for obj in old:
            obj.is_edited = True
            obj.save()

    @staticmethod
    def _parse_entry_day(entry_date):
        """Return the calendar date of an ``entry_date`` string, or None if unparseable.

        Tries the strict ``%Y-%m-%dT%H:%M:%S`` form (after dropping a literal
        ``+0000``) first, then dateutil's lenient parser, which also accepts
        fractional seconds and other UTC offsets.
        """
        str_data = entry_date.replace("+0000", "")
        try:
            return datetime.datetime.strptime(str_data, "%Y-%m-%dT%H:%M:%S").date()
        except ValueError:
            logger.debug("ValueError In SurveyObservationView!! - dump data next line:")
            logger.debug("str_data: " + str_data)
        try:
            return dateutil.parser.parse(entry_date).date()
        except (ValueError, OverflowError):
            return None

    @staticmethod
    def _invalid_observation_response(serializer):
        """Build the 400 response for an observation that failed validation.

        The ``question_survey`` entry is always (over)written with the
        "required" message, whichever field actually failed.
        """
        error = serializer.errors
        error["question_survey"] = ["question_survey is required"]
        return Response(error, status=status.HTTP_400_BAD_REQUEST)

    def create_observation(self, validated_data, participant, is_survey_profile):
        """Create an Observation from ``validated_data``.

        ``participant`` and ``is_survey_profile`` are accepted but unused;
        the participant is expected inside ``validated_data``.
        """
        return ObservationApi()._create(**validated_data)

    def check_if_profile(self, survey_id):
        """Return True if ``survey_id`` is the profile survey (id 11).

        NOTE(review): not called anywhere in this view.
        """
        return int(survey_id) == 11

    def valid_observation(self, survey_ques, question_survey_id, participant):
        """Turn serializer output into kwargs for ``ObservationApi()._create``.

        Mutates and returns ``survey_ques``. Its ``question_survey`` id is
        replaced by the QuestionSurvey object, ``participant`` is set, and
        ``parent`` is reset to None (callers link children afterwards).
        NOTE(review): the ``question_survey_id`` argument is ignored; the id
        is always read from ``survey_ques["question_survey"]``.
        """
        question_survey_id = survey_ques["question_survey"]
        questionsurvey = QuestionSurveyApi()._get(question_survey_id)
        survey_ques["question_survey"] = questionsurvey
        survey_ques["participant"] = participant
        survey_ques["parent"] = None
        return survey_ques

    def get_queryset(self, survey, participant, *args, **kwargs):
        """Return the participant's observations for ``survey``.

        Optionally narrowed by the ``entry_date``, ``start_date``,
        ``end_date`` and ``question_survey`` query parameters.
        """
        params = self.request.query_params
        return ObservationApi()._filter(survey=survey, participant=participant,
                                        entry_date=params.get("entry_date"),
                                        question_survey=params.get("question_survey"),
                                        start_date=params.get("start_date"),
                                        end_date=params.get("end_date"))



class InsightsTypeView(APIView):
    """Bucket one question's observations per day or per month for charting."""

    # Question.data_type (lower-cased) -> Observation attribute holding the value.
    value_fields = {"str": "str_value", "int": "int_value", "double": "double_value"}

    def get(self, request, *args, **kwargs):
        """Return the participant's values for a question grouped by period.

        Query parameters: ``var_x`` ("day" or "month"), ``var_y`` (question
        id), ``start_date`` and ``end_date`` (any format dateutil parses).
        Responds 400 when either date is missing or unparseable.
        """
        data = request.GET
        plot_type = data.get("var_x")  # day/month
        question_id = data.get("var_y")  # question_id
        start_date = data.get("start_date")
        end_date = data.get("end_date")

        # make_dictionary parses both dates; fail with 400 instead of a 500.
        if not start_date or not end_date:
            return Response({"error": "start_date and end_date are required"},
                            status=status.HTTP_400_BAD_REQUEST)
        try:
            dateutil.parser.parse(start_date)
            dateutil.parser.parse(end_date)
        except (ValueError, OverflowError):
            return Response({"error": "start_date and end_date must be valid dates"},
                            status=status.HTTP_400_BAD_REQUEST)

        participant_id = request.participant_token.participant.id
        observation = ObservationApi()._filter(participant=participant_id,
                                               question_id=question_id,
                                               start_date=start_date,
                                               end_date=end_date)

        results = self.format_results(observations=observation,
                                      plot_type=plot_type,
                                      start_date=start_date,
                                      end_date=end_date,
                                      question_id=question_id)
        return Response(response_standard(results), status=status.HTTP_200_OK)

    @staticmethod
    def _day_key(moment):
        """Bucket key "D M YYYY" (no zero padding) for a date/datetime."""
        return "{0} {1} {2}".format(moment.day, moment.month, moment.year)

    @staticmethod
    def _month_key(moment):
        """Bucket key "M YYYY" (no zero padding) for a date/datetime."""
        return "{0} {1}".format(moment.month, moment.year)

    def format_results(self, observations, plot_type, start_date, end_date, question_id):
        """Add observation values to their period bucket.

        Returns a list of ``(key, [values])`` pairs in ``make_dictionary``
        order, or None when ``plot_type`` is neither "day" nor "month". The
        value attribute is chosen from the question's ``data_type``; an
        unknown type leaves every bucket empty. Observations whose period was
        not pre-seeded get a new bucket appended at the end instead of
        raising KeyError.
        """
        result_dict = self.make_dictionary(plot_type, start_date, end_date)
        question = QuestionApi()._get(id=question_id)
        if plot_type == "day":
            period_key = self._day_key
        elif plot_type == "month":
            period_key = self._month_key
        else:
            return None

        value_field = self.value_fields.get(question.data_type.lower())
        if value_field is not None:
            for obs in observations:
                key = period_key(obs.entry_date)
                result_dict.setdefault(key, []).append(getattr(obs, value_field))
        # list of pairs keeps the key order when serialized
        return list(result_dict.items())

    def make_dictionary(self, plot_type, start_date, end_date):
        """Build an OrderedDict of empty buckets covering the date range.

        "day": one "D M YYYY" key per day from ``start_date`` for
        ``(end - start).days`` days, so the final day is not included.
        "month": one "M YYYY" key per month. Within a single year the end
        month is always included. Across years the end month is excluded
        when ``end_date`` falls on the 1st.
        Returns None for any other ``plot_type``.
        """
        first_date = dateutil.parser.parse(start_date)
        final_date = dateutil.parser.parse(end_date)
        initial_dict = OrderedDict()
        if plot_type == "day":
            # NOTE(review): .days truncates and range() excludes the end, so
            # the final day is never seeded -- confirm end_date is exclusive.
            for offset in range((final_date - first_date).days):
                initial_dict[self._day_key(first_date + relativedelta(days=offset))] = []
            return initial_dict
        if plot_type == "month":
            month, year = first_date.month, first_date.year
            end_month, end_year = final_date.month, final_date.year
            end_day = final_date.day
            if year == end_year:
                for current in range(month, end_month + 1):
                    initial_dict["{0} {1}".format(current, year)] = []
                return initial_dict
            while year <= end_year:
                initial_dict["{0} {1}".format(month, year)] = []
                month += 1
                if month > 12:
                    month = 1
                    year += 1
                if year == end_year and month > end_month:
                    break
                # An end date on the 1st excludes its month. The year check is
                # required; without it the range stopped at the first
                # matching month of the start year.
                if year == end_year and month == end_month and end_day == 1:
                    break
            return initial_dict


class AllowedResponseView(APIView):
    """Create allowed responses for a question and list the caller's own."""

    def post(self, request, **kwargs):
        """Create one AllowedResponse per item of the validated request body.

        Each item's position becomes its order. Responds 201 with
        ``{"ids": "<comma-separated new ids>"}``, or 400 with serializer
        errors.
        NOTE(review): iterating ``serializer.data`` assumes the serializer
        yields a list of option dicts -- confirm AllowedResponseSerializer is
        list-shaped.
        """
        question_id = kwargs.get("question_id")
        question_obj = QuestionApi()._get(question_id)
        serializer = AllowedResponseSerializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        created_ids = []
        for index, option in enumerate(serializer.data):
            ar = AllowedResponseApi()._create(question_obj,
                                              option.get("value"),
                                              index,
                                              display_name=option.get("display_name"),
                                              is_expected_value=option.get("is_expected"))
            created_ids.append(str(ar.id))
        return Response({"ids": ",".join(created_ids)}, status=status.HTTP_201_CREATED)

    def get(self, request, **kwargs):
        """List the requesting participant's allowed responses.

        Responds 200; this is a read, so 201 Created was the wrong status.
        """
        responses = AllowedResponseApi()._filter(None, participant=request.participant_token.participant)
        payload = [
            {
                "value": i.value,
                "icon_name": i.icon_name,
                "display_name": i.display_name,
                "order": i.order,
                "is_expected_value": i.is_expected_value,
                "id": i.id,
                "question_id": i.question.id,
            }
            for i in responses
        ]
        return Response(payload, status=status.HTTP_200_OK)


class TotalSurveyCountView(APIView):
    """Count non-edited top-level observations overall and for the caller."""

    # Survey ids never counted by this view.
    # NOTE(review): 130 sits next to the 1330-1332 profile surveys -- confirm
    # it is intentional and not a typo for 1330.
    excluded_survey_ids = [130, 1330, 1331, 1332]

    def get(self, request, *args, **kwargs):
        """Return ``{"all_count": ..., "participant_count": ...}``.

        Query parameters: ``surveys[]`` (repeatable survey ids), ``start_date``
        and ``end_date``. When ``date`` is present, the counts are numbers of
        distinct entry days rather than observations, and the date range is
        not applied.
        """
        data = request.GET
        survey_id_list = data.getlist("surveys[]")
        start_date = data.get("start_date")
        end_date = data.get("end_date")
        participant_id = request.participant_token.participant.id

        if 'date' in request.GET:
            all_observations = ObservationApi()._filter(
                survey_id_list=survey_id_list, parent=False
            ).filter(is_edited=False).exclude(
                question_survey__survey__id__in=self.excluded_survey_ids)
            participant_observations = all_observations.filter(participant_id=participant_id)
            all_count = self._count_distinct_days(all_observations)
            p_count = self._count_distinct_days(participant_observations)
            return Response(response_standard({"all_count": all_count, "participant_count": p_count}),
                            status=status.HTTP_200_OK)

        all_observations = ObservationApi()._filter(
            survey_id_list=survey_id_list, start_date=start_date,
            end_date=end_date, parent=False
        ).filter(is_edited=False).exclude(
            question_survey__survey__id__in=self.excluded_survey_ids)
        participant_observations = all_observations.filter(participant_id=participant_id)
        return Response(response_standard({"all_count": all_observations.count(),
                                           "participant_count": participant_observations.count()}),
                        status=status.HTTP_200_OK)

    @staticmethod
    def _count_distinct_days(observations):
        """Return the number of distinct ``DATE(entry_date)`` values, or 0 on failure.

        Both counts go through here so they agree. Deduplicating on the raw
        value works whether the backend returns ``date`` objects or strings;
        calling ``strftime`` on a string silently zeroed the participant
        count. Failures are logged and reported as 0, as before.
        """
        try:
            rows = observations.extra(
                {'date_created': "DATE(entry_date)"}
            ).values('date_created').annotate(created_count=Count('id'))
            return len(set(row['date_created'] for row in rows))
        except Exception:
            logger.exception("TotalSurveyCountView: failed to count distinct entry days")
            return 0
