"""
Memory Service - Central coordinator for all memory operations
Manages user profiles, conversations, and intelligent memory retrieval
"""

import json
import pickle
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
from pathlib import Path
import logging

# Module-scoped logger; every MemoryService error path reports through it.
logger: logging.Logger = logging.getLogger(__name__)


class MemoryService:
    """Central service for managing user memory and personalization.

    Every public operation dispatches to a database backend when
    ``db_connection`` is truthy, otherwise to a JSON/JSONL file backend rooted
    at the relative directory ``memory/`` (resolved against the current
    working directory).
    """

    def __init__(self, db_connection=None, vector_store=None, grok_client=None):
        """Wire up storage backends and the memory helper components.

        Args:
            db_connection: Object exposing ``fetch_one``, ``fetch_all`` and
                ``execute``; when falsy, file storage is used instead.
            vector_store: Handed to ``MemoryEnhancedRAG``; when falsy, RAG is
                disabled and ``enhanced_query`` becomes a passthrough.
            grok_client: Handed to the entity extractor and the summarizer.
        """
        self.db = db_connection
        self.vector_store = vector_store
        self.grok_client = grok_client

        # Sibling components are imported here rather than at module level —
        # presumably to avoid an import cycle (MemoryEnhancedRAG receives
        # `self`); TODO confirm.
        from .entity_extractor import EntityExtractor
        from .conversation_summarizer import ConversationSummarizer
        from .memory_rag import MemoryEnhancedRAG

        self.entity_extractor = EntityExtractor(grok_client)
        self.summarizer = ConversationSummarizer(grok_client)
        self.memory_rag = MemoryEnhancedRAG(vector_store, self, self.entity_extractor) if vector_store else None

    # ------------------------------------------------------------------ paths

    @staticmethod
    def _safe_memory_path(relative: str) -> Path:
        """Return ``Path(relative)``, refusing any path with a ``..`` component.

        ``user_id`` / ``session_id`` are interpolated into file paths, so a
        value such as ``"../../etc/x"`` would otherwise escape ``memory/``.

        Raises:
            ValueError: if the path contains a parent-directory component.
        """
        path = Path(relative)
        if '..' in path.parts:
            raise ValueError(f"Refusing memory path outside storage root: {relative!r}")
        return path

    # --------------------------------------------------------------- profiles

    def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored profile for ``user_id``.

        Returns ``None`` both when no profile exists and when the read fails
        (the error is logged), so callers cannot tell the two apart.
        """
        try:
            if self.db:
                return self._get_user_profile_from_db(user_id)
            else:
                return self._get_user_profile_from_file(user_id)
        except Exception as e:
            logger.error(f"Error getting user profile: {e}")
            return None

    def _get_user_profile_from_db(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Load a profile row from ``user_profiles``, decoding its JSON columns.

        Empty/NULL JSON columns decode to ``{}`` (``[]`` for ``interests``).
        Returns ``None`` when no row matches.
        """
        query = """
        SELECT user_id, demographics, interests, survey_preferences,
               expertise_level, monetization_profile, personality_traits,
               created_at, updated_at
        FROM user_profiles
        WHERE user_id = %s
        """

        row = self.db.fetch_one(query, (user_id,))
        if not row:
            return None
        return {
            'user_id': row[0],
            'demographics': json.loads(row[1]) if row[1] else {},
            'interests': json.loads(row[2]) if row[2] else [],
            'survey_preferences': json.loads(row[3]) if row[3] else {},
            'expertise_level': row[4],
            'monetization_profile': json.loads(row[5]) if row[5] else {},
            'personality_traits': json.loads(row[6]) if row[6] else {},
            'created_at': row[7],
            'updated_at': row[8]
        }

    def _get_user_profile_from_file(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Fallback file-based storage for development.

        Reads ``memory/profiles/<user_id>.json``; returns ``None`` if absent.
        Raises on unreadable/corrupt JSON (callers decide how to handle it).
        """
        profile_path = self._safe_memory_path(f"memory/profiles/{user_id}.json")
        if profile_path.exists():
            with open(profile_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        return None

    def update_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool:
        """Create or replace the profile for ``user_id``.

        Returns:
            ``True`` on success; ``False`` if the backend write failed (logged).
        """
        try:
            if self.db:
                return self._update_user_profile_db(user_id, profile_data)
            else:
                return self._update_user_profile_file(user_id, profile_data)
        except Exception as e:
            logger.error(f"Error updating user profile: {e}")
            return False

    def _update_user_profile_db(self, user_id: str, profile_data: Dict[str, Any]) -> bool:
        """Upsert a profile row; dict/list fields are stored as JSON text.

        Uses MySQL-style ``ON DUPLICATE KEY UPDATE``; returns whatever
        ``self.db.execute`` returns.
        """
        demographics = json.dumps(profile_data.get('demographics', {}))
        interests = json.dumps(profile_data.get('interests', []))
        survey_preferences = json.dumps(profile_data.get('survey_preferences', {}))
        expertise_level = profile_data.get('expertise_level')
        monetization_profile = json.dumps(profile_data.get('monetization_profile', {}))
        personality_traits = json.dumps(profile_data.get('personality_traits', {}))

        query = """
        INSERT INTO user_profiles
        (user_id, demographics, interests, survey_preferences, expertise_level,
         monetization_profile, personality_traits, updated_at)
        VALUES (%s, %s, %s, %s, %s, %s, %s, NOW())
        ON DUPLICATE KEY UPDATE
        demographics = VALUES(demographics),
        interests = VALUES(interests),
        survey_preferences = VALUES(survey_preferences),
        expertise_level = VALUES(expertise_level),
        monetization_profile = VALUES(monetization_profile),
        personality_traits = VALUES(personality_traits),
        updated_at = NOW()
        """

        return self.db.execute(query, (user_id, demographics, interests, survey_preferences,
                                       expertise_level, monetization_profile, personality_traits))

    def _update_user_profile_file(self, user_id: str, profile_data: Dict[str, Any]) -> bool:
        """Write the profile to ``memory/profiles/<user_id>.json`` atomically.

        The JSON is written to a sibling ``.tmp`` file and then moved over the
        target, so a failed serialization never leaves a truncated profile.
        """
        profile_path = self._safe_memory_path(f"memory/profiles/{user_id}.json")
        profile_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = profile_path.with_name(profile_path.name + '.tmp')

        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(profile_data, f, indent=2)
            tmp_path.replace(profile_path)
            return True
        except Exception as e:
            logger.error(f"Error saving profile to file: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass  # best-effort cleanup; the original error is already logged
            return False

    # ---------------------------------------------------- conversation turns

    def process_conversation_turn(self, user_id: str, session_id: str,
                                  user_message: str, assistant_response: str) -> Dict[str, Any]:
        """Record one user/assistant exchange and update derived memory.

        Steps: extract entities, store both messages, merge entities into the
        profile, and summarize the session when the summarizer asks for it.

        Returns:
            Dict with ``entities_extracted`` (extractor output),
            ``profile_updated`` (True only if the profile write succeeded) and
            ``summary_created`` (True only if a summary was actually stored).
        """
        entities = self.entity_extractor.extract_entities(user_message, assistant_response)

        self._store_conversation_turn(user_id, session_id, user_message, assistant_response, entities)

        profile_updated = False
        if entities:
            profile_updated = self._apply_entities_to_profile(user_id, entities)

        conversation_history = self.get_recent_conversation(user_id, session_id, limit=50)
        summary_created = False
        if self.summarizer.should_summarize(len(conversation_history)):
            summary_created = self._create_conversation_summary(user_id, session_id, conversation_history)

        return {
            'entities_extracted': entities,
            'profile_updated': profile_updated,
            'summary_created': summary_created
        }

    def _apply_entities_to_profile(self, user_id: str, entities: Dict[str, Any]) -> bool:
        """Merge ``entities`` into the stored profile and persist the result.

        Reads through the backend directly (not ``get_user_profile``) so that a
        failed read is distinguishable from "no profile yet": on a read error
        the update is skipped instead of overwriting the stored profile with
        only the new entities.
        """
        try:
            if self.db:
                current_profile = self._get_user_profile_from_db(user_id)
            else:
                current_profile = self._get_user_profile_from_file(user_id)
        except Exception as e:
            logger.error(f"Error getting user profile: {e}")
            return False

        updated_profile = self._merge_profiles(current_profile or {}, entities)
        return self.update_user_profile(user_id, updated_profile)

    def _merge_profiles(self, current_profile: Dict[str, Any], new_entities: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``current_profile`` with ``new_entities`` merged in.

        - dict fields (demographics, survey_preferences, personality_traits):
          shallow merge, new keys win;
        - ``interests``: de-duplicated union, existing order first;
        - ``expertise_level``: replaced only when the new value is truthy.
        Missing or ``None`` fields are treated as empty. Other keys of
        ``current_profile`` are carried over unchanged.
        """
        merged = dict(current_profile)

        merged['demographics'] = {**(merged.get('demographics') or {}),
                                  **(new_entities.get('demographics') or {})}

        # dict.fromkeys keeps first-seen order, unlike a set union.
        merged['interests'] = list(dict.fromkeys(
            [*(merged.get('interests') or []), *(new_entities.get('interests') or [])]
        ))

        merged['survey_preferences'] = {**(merged.get('survey_preferences') or {}),
                                        **(new_entities.get('survey_preferences') or {})}

        if new_entities.get('expertise_level'):
            merged['expertise_level'] = new_entities['expertise_level']

        merged['personality_traits'] = {**(merged.get('personality_traits') or {}),
                                        **(new_entities.get('personality_traits') or {})}

        return merged

    def _store_conversation_turn(self, user_id: str, session_id: str,
                                 user_message: str, assistant_response: str,
                                 entities: Dict[str, Any]):
        """Persist both messages of a turn; errors are logged, not raised."""
        try:
            if self.db:
                self._store_conversation_turn_db(user_id, session_id, user_message, assistant_response, entities)
            else:
                self._store_conversation_turn_file(user_id, session_id, user_message, assistant_response, entities)
        except Exception as e:
            logger.error(f"Error storing conversation turn: {e}")

    def _store_conversation_turn_db(self, user_id: str, session_id: str,
                                    user_message: str, assistant_response: str,
                                    entities: Dict[str, Any]):
        """Insert the user message (with its entities) and the assistant reply.

        Both rows use ``memory_type = 'message'``; the role lives in the JSON
        ``metadata`` column.
        """
        query = """
        INSERT INTO conversation_memory
        (user_id, session_id, memory_type, content, metadata, created_at)
        VALUES (%s, %s, 'message', %s, %s, NOW())
        """

        user_metadata = json.dumps({
            'role': 'user',
            'entities': entities,
            'type': 'user_input'
        })
        self.db.execute(query, (user_id, session_id, user_message, user_metadata))

        assistant_metadata = json.dumps({
            'role': 'assistant',
            'type': 'response'
        })
        self.db.execute(query, (user_id, session_id, assistant_response, assistant_metadata))

    def _store_conversation_turn_file(self, user_id: str, session_id: str,
                                      user_message: str, assistant_response: str,
                                      entities: Dict[str, Any]):
        """Append both messages as JSON lines to the session's ``.jsonl`` file."""
        memory_path = self._safe_memory_path(f"memory/conversations/{user_id}/{session_id}.jsonl")
        memory_path.parent.mkdir(parents=True, exist_ok=True)

        messages = [
            {
                'timestamp': datetime.now().isoformat(),
                'role': 'user',
                'content': user_message,
                'entities': entities
            },
            {
                'timestamp': datetime.now().isoformat(),
                'role': 'assistant',
                'content': assistant_response
            }
        ]

        with open(memory_path, 'a', encoding='utf-8') as f:
            for message in messages:
                f.write(json.dumps(message) + '\n')

    # ---------------------------------------------------------- history read

    def get_recent_conversation(self, user_id: str, session_id: str = None, limit: int = 10) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent messages, oldest first.

        Without ``session_id`` the DB backend spans all of the user's
        sessions, while the file backend uses only the most recently modified
        session file. Returns ``[]`` on error (logged).
        """
        try:
            if self.db:
                return self._get_recent_conversation_db(user_id, session_id, limit)
            else:
                return self._get_recent_conversation_file(user_id, session_id, limit)
        except Exception as e:
            logger.error(f"Error getting recent conversation: {e}")
            return []

    def _get_recent_conversation_db(self, user_id: str, session_id: str = None, limit: int = 10) -> List[Dict[str, Any]]:
        """Fetch the newest ``limit`` message rows and return them oldest first.

        Each item has ``content``, ``role`` (from metadata, default
        ``'unknown'``), ``timestamp`` (ISO string when possible) and the
        decoded ``metadata``.
        """
        # NOTE(review): a turn's two rows share NOW(), so ties on created_at
        # may come back in either order; a secondary key (e.g. an auto-increment
        # id, if the table has one) would make ordering stable — TODO confirm.
        if session_id:
            query = """
            SELECT content, metadata, created_at
            FROM conversation_memory
            WHERE user_id = %s AND session_id = %s AND memory_type = 'message'
            ORDER BY created_at DESC
            LIMIT %s
            """
            results = self.db.fetch_all(query, (user_id, session_id, limit))
        else:
            query = """
            SELECT content, metadata, created_at
            FROM conversation_memory
            WHERE user_id = %s AND memory_type = 'message'
            ORDER BY created_at DESC
            LIMIT %s
            """
            results = self.db.fetch_all(query, (user_id, limit))

        conversation = []
        for content, raw_metadata, created_at in reversed(results):  # DESC -> chronological
            metadata = json.loads(raw_metadata) if raw_metadata else {}
            conversation.append({
                'content': content,
                'role': metadata.get('role', 'unknown'),
                'timestamp': created_at.isoformat() if hasattr(created_at, 'isoformat') else str(created_at),
                'metadata': metadata
            })

        return conversation

    def _get_recent_conversation_file(self, user_id: str, session_id: str = None, limit: int = 10) -> List[Dict[str, Any]]:
        """Return the last ``limit`` stored messages of a session file.

        Items are the raw JSON-line dicts as written by
        ``_store_conversation_turn_file`` (``timestamp``, ``role``,
        ``content`` and, for user messages, ``entities``). A non-positive
        ``limit`` returns ``[]``, matching SQL ``LIMIT 0``.
        """
        # Without this guard conversation[-0:] would return the whole history.
        if limit <= 0:
            return []

        if session_id:
            memory_path = self._safe_memory_path(f"memory/conversations/{user_id}/{session_id}.jsonl")
        else:
            # No session given: pick the most recently modified session file.
            conversations_dir = self._safe_memory_path(f"memory/conversations/{user_id}")
            if not conversations_dir.exists():
                return []
            sessions = list(conversations_dir.glob("*.jsonl"))
            if not sessions:
                return []
            memory_path = max(sessions, key=lambda x: x.stat().st_mtime)

        if not memory_path.exists():
            return []

        conversation = []
        with open(memory_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if stripped:
                    conversation.append(json.loads(stripped))

        return conversation[-limit:]

    # -------------------------------------------------------------- summaries

    def _create_conversation_summary(self, user_id: str, session_id: str, conversation_history: List[Dict[str, Any]]) -> bool:
        """Summarize ``conversation_history`` and store the result.

        Returns:
            ``True`` if a summary was produced and stored; ``False`` if any
            step raised (the error is logged).
        """
        try:
            summary = self.summarizer.create_summary(conversation_history)

            if self.db:
                self._store_summary_db(user_id, session_id, summary, len(conversation_history))
            else:
                self._store_summary_file(user_id, session_id, summary, len(conversation_history))
            return True
        except Exception as e:
            logger.error(f"Error creating conversation summary: {e}")
            return False

    def _store_summary_db(self, user_id: str, session_id: str, summary: str, message_count: int):
        """Insert a ``memory_type = 'summary'`` row for the session."""
        query = """
        INSERT INTO conversation_memory
        (user_id, session_id, memory_type, content, metadata, created_at)
        VALUES (%s, %s, 'summary', %s, %s, NOW())
        """

        metadata = json.dumps({
            'type': 'conversation_summary',
            'message_count': message_count,
            'summary_type': 'periodic'
        })

        self.db.execute(query, (user_id, session_id, summary, metadata))

    def _store_summary_file(self, user_id: str, session_id: str, summary: str, message_count: int):
        """Write the session summary file, replacing any previous summary."""
        summary_path = self._safe_memory_path(f"memory/summaries/{user_id}/{session_id}.json")
        summary_path.parent.mkdir(parents=True, exist_ok=True)

        summary_data = {
            'timestamp': datetime.now().isoformat(),
            'summary': summary,
            'message_count': message_count,
            'session_id': session_id
        }

        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary_data, f, indent=2)

    # ---------------------------------------------------------------- context

    def get_memory_context(self, user_id: str, session_id: str = None) -> Dict[str, Any]:
        """Bundle profile, last 20 messages and survey progress for prompting."""
        user_profile = self.get_user_profile(user_id)
        conversation_history = self.get_recent_conversation(user_id, session_id, limit=20)
        survey_progress = self._get_survey_progress(user_id)

        return {
            'user_profile': user_profile,
            'conversation_buffer': conversation_history,
            'survey_progress': survey_progress,
            'session_id': session_id
        }

    def _get_survey_progress(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the user's most recently active survey progress row.

        Only the DB backend tracks survey progress; with file storage, with no
        matching row, or on error (logged) this returns ``None``.
        """
        try:
            if self.db:
                query = """
                SELECT survey_id, current_section, completed_questions, last_activity
                FROM survey_progress
                WHERE user_id = %s
                ORDER BY last_activity DESC
                LIMIT 1
                """
                result = self.db.fetch_one(query, (user_id,))
                if result:
                    return {
                        'survey_id': result[0],
                        'current_section': result[1],
                        'completed_questions': json.loads(result[2]) if result[2] else [],
                        'last_activity': result[3]
                    }
        except Exception as e:
            logger.error(f"Error getting survey progress: {e}")

        return None

    def enhanced_query(self, query: str, user_id: str, session_id: str = None) -> Dict[str, Any]:
        """Run ``query`` through memory-aware RAG when a vector store exists.

        Without RAG, returns ``{'enhanced_query': query, 'rag_context': []}``.
        """
        if self.memory_rag:
            return self.memory_rag.query_with_memory(query, user_id, session_id,
                                                     self.get_user_profile(user_id),
                                                     self.get_recent_conversation(user_id, session_id, limit=10))
        else:
            return {'enhanced_query': query, 'rag_context': []}
