"""
Modular RAG Orchestrator using LangGraph

This implements a multi-agent RAG system where specialized sub-agents handle
different aspects of the retrieval and generation pipeline. The orchestrator
coordinates these agents using LangGraph for complex workflows.

Architecture Overview:
- QueryAnalyzer: Classifies questions, extracts entities, determines strategy
- RetrieverAgent: Handles multi-strategy retrieval (semantic, keyword, hybrid)
- KnowledgeSynthesizer: Combines retrieved info with existing knowledge
- AnswerGenerator: Uses LLM to generate coherent responses
- FactChecker: Validates claims against sources
- MemoryManager: Maintains conversation context
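- EntityGraphAgent (planned; not yet a separate class - entity hints are currently obtained from RAGAgent by KnowledgeSynthesizer): Manages entity relationships and multi-hop queries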

The orchestrator routes queries through appropriate agent combinations based on
question type and complexity.
"""

import json
import re
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional, TypedDict, Annotated

# LangGraph imports (to be installed)
try:
    from langgraph.graph import StateGraph, END
    from langchain.schema import BaseMessage
    LANGGRAPH_AVAILABLE = True
except ImportError:
    LANGGRAPH_AVAILABLE = False
    print("Warning: LangGraph not available. Install with: pip install langgraph")

from mistral_integration import RAGAgent
from vector_store import VectorStore


class AgentState(TypedDict):
    """State passed between agents in the workflow.

    ModularRAGOrchestrator.query() builds the initial instance with every key
    filled in (empty strings/lists, 0.0, False); each graph node returns a
    copy of the state with the keys it is responsible for updated.
    """
    # User question; set once by query() and read by every node.
    question: str
    # Label from RAGAgent.analyze_question(); values this module branches on
    # include 'factual', 'comparative' and 'source_lookup'.
    question_type: str
    # Candidate entity strings from QueryAnalyzer._extract_entities().
    entities: List[str]
    # Records returned by the vector store; FactChecker reads their
    # 'content' and 'source' keys.
    search_results: List[Dict[str, Any]]
    # Output of RAGAgent.collect_distilled_insights(); summary code reads the
    # 'title' and 'overview' keys of each entry.
    distilled_insights: List[Dict[str, Any]]
    # Prompt context string built by RAGAgent.format_context().
    context: str
    # Answer text produced by the generate_answer node.
    answer: str
    # Source summaries from RAGAgent.prepare_source_summaries(), set in finalize.
    sources: List[Dict[str, Any]]
    # Initialised to [] by query(); no node in this module writes to it.
    memory: List[str]
    # From RAGAgent._format_entity_graph_hints().
    # NOTE(review): annotated List[str] here, but
    # KnowledgeSynthesizer._create_synthesis_summary types it as str - confirm.
    entity_hints: List[str]
    # FactChecker's mean claim-match score (0.0-1.0; 0.5 when no claims found).
    confidence: float
    # True when confidence < 0.7; set by fact_check but not read by any node here.
    needs_fact_check: bool
    # Copy of `answer` published by finalize and returned by query().
    final_answer: str


class QueryAnalyzer:
    """Analyzes questions to determine processing strategy.

    Combines RAGAgent.analyze_question() (the question type) with local
    heuristics for entity extraction and complexity, and turns them into a
    list of required agents and a retrieval strategy.
    """

    # Capitalised words that usually open a question or an instruction. They
    # are capitalised because of their sentence position, not because they
    # name something, so the capitalisation heuristic skips them.
    _NON_ENTITY_WORDS = frozenset({
        'what', 'when', 'where', 'which', 'whom', 'whose', 'does', 'were',
        'could', 'should', 'would', 'have', 'explain', 'describe', 'compare',
        'list', 'tell', 'give', 'summarize', 'summarise', 'there', 'this',
        'that', 'these', 'those', 'according', 'please',
    })

    # Punctuation stripped from both ends of a token before it is tested,
    # e.g. "Napoleon?" -> "Napoleon", "(Paris," -> "Paris".
    _EDGE_PUNCTUATION = '.,;:!?"\'()[]{}<>'

    def __init__(self, rag_agent: RAGAgent):
        self.rag_agent = rag_agent

    def analyze(self, question: str) -> Dict[str, Any]:
        """Analyze question and return processing recommendations.

        Returns:
            Dict with keys 'question_type' (from the RAGAgent), 'entities',
            'complexity' ('low' | 'medium' | 'high'), 'required_agents' and
            'retrieval_strategy' ('semantic' | 'keyword' | 'hybrid').
        """
        question_type = self.rag_agent.analyze_question(question)

        # Extract key entities and topics
        entities = self._extract_entities(question)

        # Determine complexity and required agents
        complexity = self._assess_complexity(question, entities)

        return {
            'question_type': question_type,
            'entities': entities,
            'complexity': complexity,
            'required_agents': self._get_required_agents(question_type, complexity),
            'retrieval_strategy': self._choose_retrieval_strategy(question_type, entities)
        }

    def _extract_entities(self, question: str) -> List[str]:
        """Extract candidate named entities from the question.

        Heuristic (could be enhanced with spaCy/NER): a token counts as an
        entity when, after stripping surrounding punctuation, it starts with
        an uppercase letter, is longer than three characters and is not a
        common question/instruction word.

        The question is deliberately NOT lowercased: the previous version
        lowercased first and then tested for an uppercase initial, so it
        never found any entity.

        Returns:
            Unique candidates in order of first appearance.
        """
        entities: List[str] = []
        for raw_token in question.split():
            token = raw_token.strip(self._EDGE_PUNCTUATION)
            if (len(token) > 3
                    and token[0].isupper()
                    and token.lower() not in self._NON_ENTITY_WORDS):
                entities.append(token)

        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (set() ordering can change between runs under hash randomisation).
        return list(dict.fromkeys(entities))

    def _assess_complexity(self, question: str, entities: List[str]) -> str:
        """Classify complexity as 'high', 'medium' or 'low'.

        'high': more than 20 words or more than 3 entities;
        'medium': more than 10 words or more than 1 entity; otherwise 'low'.
        """
        word_count = len(question.split())

        if word_count > 20 or len(entities) > 3:
            return 'high'
        elif word_count > 10 or len(entities) > 1:
            return 'medium'
        else:
            return 'low'

    def _get_required_agents(self, question_type: str, complexity: str) -> List[str]:
        """Determine which agents are needed.

        Retrieval and answer generation are always required; high complexity
        adds synthesis and fact checking, otherwise factual questions add fact
        checking and comparative questions add synthesis.
        """
        base_agents = ['retriever', 'answer_generator']

        if complexity == 'high':
            base_agents.extend(['knowledge_synthesizer', 'fact_checker'])
        elif question_type == 'factual':
            base_agents.append('fact_checker')
        elif question_type == 'comparative':
            base_agents.append('knowledge_synthesizer')

        return base_agents

    def _choose_retrieval_strategy(self, question_type: str, entities: List[str]) -> str:
        """Choose optimal retrieval approach.

        'keyword' for source lookups (precise matching), 'hybrid' when more
        than two entities are present (balance precision and recall), and
        'semantic' otherwise (broad semantic matching).
        """
        if question_type == 'source_lookup':
            return 'keyword'
        elif len(entities) > 2:
            return 'hybrid'
        else:
            return 'semantic'


class RetrieverAgent:
    """Runs retrieval against the vector store using a named strategy.

    Recognised strategies are 'semantic', 'keyword' and 'hybrid'; any other
    value falls back to plain semantic search.
    """

    # Strategy name -> name of the VectorStore method that implements it.
    # Looked up lazily with getattr so only the selected method must exist.
    _STRATEGY_METHODS = {
        'semantic': 'search',
        'keyword': 'keyword_search',
        'hybrid': 'hybrid_search',
    }

    def __init__(self, vector_store: VectorStore, rag_agent: RAGAgent):
        self.vector_store = vector_store
        # Stored for parity with the other agents; retrieve() does not use it.
        self.rag_agent = rag_agent

    def retrieve(self, query: str, strategy: str = 'semantic', k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``k`` results for ``query`` using ``strategy``."""
        method_name = self._STRATEGY_METHODS.get(strategy, 'search')
        search_fn = getattr(self.vector_store, method_name)
        return search_fn(query, k=k)


class KnowledgeSynthesizer:
    """Combines retrieved information with existing knowledge.

    Collects distilled insights (digests/cards) for the retrieved results and
    entity-graph hints for the question, both via the RAGAgent.
    """

    def __init__(self, rag_agent: RAGAgent):
        self.rag_agent = rag_agent

    def synthesize(self, search_results: List[Dict[str, Any]], question: str) -> Dict[str, Any]:
        """Synthesize knowledge from multiple sources.

        Returns:
            Dict with 'distilled_insights', 'entity_hints' and a one-line
            'synthesis_summary'.
        """
        # Collect insights from digests and cards
        distilled_insights = self.rag_agent.collect_distilled_insights(search_results)

        # Generate entity hints.
        # NOTE(review): relies on a private RAGAgent helper; consider a public API.
        entity_hints = self.rag_agent._format_entity_graph_hints(question)

        return {
            'distilled_insights': distilled_insights,
            'entity_hints': entity_hints,
            'synthesis_summary': self._create_synthesis_summary(distilled_insights, entity_hints)
        }

    def _create_synthesis_summary(self, insights: List[Dict[str, Any]], entity_hints: Any) -> str:
        """Create a concise ' | '-separated synthesis of available knowledge.

        Uses at most the first three insights (overview truncated to 100
        characters) plus up to 200 characters of entity hints.

        Args:
            insights: Insight dicts; 'title' and 'overview' keys are read.
            entity_hints: Hints as a string or a list of strings.
        """
        if not insights:
            return "No distilled insights available."

        summary_parts = []
        for insight in insights[:3]:
            # `or` also covers keys that are present but None/empty: a None
            # overview previously crashed on slicing, a None title printed "None".
            title = insight.get('title') or 'Unknown source'
            overview = (insight.get('overview') or '')[:100]
            if overview:
                summary_parts.append(f"{title}: {overview}")

        hints_text = self._hints_to_text(entity_hints)
        if hints_text:
            summary_parts.append(f"Entity connections: {hints_text[:200]}")

        return ' | '.join(summary_parts)

    @staticmethod
    def _hints_to_text(entity_hints: Any) -> str:
        """Normalize entity hints (a string or a list of strings) to one string."""
        if not entity_hints:
            return ''
        if isinstance(entity_hints, (list, tuple)):
            # Joining avoids embedding a Python list repr in the summary.
            return '; '.join(str(hint) for hint in entity_hints if hint)
        return str(entity_hints)


class AnswerGenerator:
    """Generates final answers using LLM with enhanced prompting."""

    def __init__(self, rag_agent: RAGAgent):
        self.rag_agent = rag_agent

    def generate(self, context: str, question: str, insights: Dict[str, Any],
                question_type: str, search_results: List[Dict[str, Any]]) -> str:
        """Generate an answer via RAGAgent.generate_extraction_answer().

        ``insights`` used to be accepted but dropped, so the knowledge-
        synthesis step had no effect on the answer. Any distilled insights or
        entity hints are now appended to the context as a separate, labelled
        section. With no insights the context is passed through unchanged.

        Args:
            context: Formatted retrieval context.
            question: The user question.
            insights: May hold 'distilled_insights' (dicts with 'title' /
                'overview') and 'entity_hints' (a string or list of strings).
            question_type: Currently unused; kept for interface stability.
            search_results: Raw retrieval results, forwarded unchanged.

        Returns:
            The answer string produced by the RAGAgent.
        """
        return self.rag_agent.generate_extraction_answer(
            context=self._merge_insights(context, insights),
            question=question,
            search_results=search_results
        )

    @staticmethod
    def _merge_insights(context: str, insights: Optional[Dict[str, Any]]) -> str:
        """Return ``context`` with a 'Synthesized insights' section appended, if any."""
        if not insights:
            return context

        lines = []
        for insight in insights.get('distilled_insights') or []:
            title = insight.get('title') or 'Unknown source'
            overview = insight.get('overview') or ''
            if overview:
                lines.append(f"- {title}: {overview}")

        hints = insights.get('entity_hints')
        if isinstance(hints, (list, tuple)):
            hints = '\n'.join(str(hint) for hint in hints if hint)
        if hints:
            lines.append(f"Entity connections:\n{hints}")

        if not lines:
            return context
        section = "Synthesized insights:\n" + '\n'.join(lines)
        return f"{context}\n\n{section}" if context else section


class FactChecker:
    """Validates claims against source material using lexical matching.

    Each sentence of the answer is treated as a claim. A claim found verbatim
    (case-insensitively) in a source's 'content' scores 0.9. Otherwise the
    best rapidfuzz ``partial_ratio`` across sources, scaled to 0-1, is used.
    """

    # Split after sentence-ending punctuation followed by whitespace, or at
    # line breaks. Unlike str.split('.'), this keeps decimals ("3.5") and
    # dotted tokens ("U.S.A") intact and also handles '!' and '?'.
    _CLAIM_SPLIT = re.compile(r'(?<=[.!?])\s+|\n+')

    def __init__(self, rag_agent: RAGAgent):
        self.rag_agent = rag_agent

    def check_claims(self, answer: str, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Verify factual claims in the answer.

        Returns:
            Dict with 'verification_results' (one entry per claim),
            'overall_confidence' (mean claim confidence, 0.5 when there are
            no claims) and 'recommendations'.
        """
        claims = self._extract_claims(answer)

        verification_results = []
        for claim in claims:
            verification = self._verify_claim(claim, sources)
            verification_results.append({
                'claim': claim,
                'verified': verification['verified'],
                'confidence': verification['confidence'],
                'supporting_sources': verification['sources']
            })

        if verification_results:
            overall_confidence = sum(r['confidence'] for r in verification_results) / len(verification_results)
        else:
            overall_confidence = 0.5

        return {
            'verification_results': verification_results,
            'overall_confidence': overall_confidence,
            'recommendations': self._generate_recommendations(verification_results)
        }

    def _extract_claims(self, answer: str) -> List[str]:
        """Split the answer into sentence-level claims (without end punctuation)."""
        pieces = self._CLAIM_SPLIT.split(answer)
        cleaned = (piece.strip().rstrip('.!?').strip() for piece in pieces)
        return [claim for claim in cleaned if claim]

    def _verify_claim(self, claim: str, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Verify a single claim against sources.

        Returns:
            Dict with 'verified' (confidence > 0.7), 'confidence' (0-1) and
            'sources' (the best-matching source id, if any).
        """
        claim_lower = claim.lower()
        best_score = 0.0
        best_source = None

        for source in sources:
            content = (source.get('content') or '').lower()

            # Verbatim support: stop searching, but never replace a better
            # score already found in an earlier source with the fixed 0.9.
            if claim_lower in content:
                if best_score < 0.9:
                    best_score, best_source = 0.9, source.get('source')
                break

            # Imported lazily so rapidfuzz is only needed once fuzzy matching
            # is actually reached (repeat imports are a sys.modules lookup).
            from rapidfuzz import fuzz
            score = fuzz.partial_ratio(claim_lower, content) / 100.0
            if score > best_score:
                best_score, best_source = score, source.get('source')

        return {
            'verified': best_score > 0.7,
            'confidence': best_score,
            'sources': [best_source] if best_source else []
        }

    def _generate_recommendations(self, verification_results: List[Dict[str, Any]]) -> List[str]:
        """Generate recommendations based on verification results."""
        recommendations = []

        low_confidence_claims = [r for r in verification_results if r['confidence'] < 0.6]
        if low_confidence_claims:
            recommendations.append(f"Consider additional research for {len(low_confidence_claims)} claims with low confidence")

        unverified_claims = [r for r in verification_results if not r['verified']]
        if unverified_claims:
            recommendations.append(f"Unable to verify {len(unverified_claims)} claims against available sources")

        return recommendations


class MemoryManager:
    """Manages conversation context as a bounded list of short notes.

    Notes are kept oldest-first; at most ``max_memory`` are retained.
    """

    def __init__(self, max_memory: int = 4):
        self.short_term_memory: List[str] = []
        # Upper bound on retained notes; 0 (or less) retains nothing.
        self.max_memory = max_memory

    def add_to_memory(self, item: str):
        """Append ``item``, dropping the oldest notes beyond ``max_memory``."""
        self.short_term_memory.append(item)
        # Compute the overflow explicitly: the previous slice
        # [-self.max_memory:] returns the WHOLE list when max_memory == 0,
        # so the cap was silently ignored and memory grew without bound.
        overflow = len(self.short_term_memory) - max(self.max_memory, 0)
        if overflow > 0:
            del self.short_term_memory[:overflow]

    def get_memory_context(self) -> str:
        """Return the notes as '- note' lines, or "None recorded." if empty."""
        if not self.short_term_memory:
            return "None recorded."
        return '\n'.join(f"- {note}" for note in self.short_term_memory)

    def clear_memory(self):
        """Clear all memory"""
        self.short_term_memory = []


class ModularRAGOrchestrator:
    """Main orchestrator that coordinates specialized agents.

    When LangGraph is importable, queries run through a compiled state graph::

        analyze_query -> retrieve_info -> [synthesize_knowledge]
                      -> generate_answer -> fact_check -> finalize

    ``synthesize_knowledge`` runs only for high-complexity questions.
    Without LangGraph, :meth:`query` falls back to :meth:`_simple_query`, a
    linear analyze -> retrieve -> generate pipeline.
    """

    def __init__(self, rag_agent: RAGAgent, vector_store: VectorStore):
        self.rag_agent = rag_agent
        self.vector_store = vector_store

        # Initialize specialized agents
        self.query_analyzer = QueryAnalyzer(rag_agent)
        self.retriever = RetrieverAgent(vector_store, rag_agent)
        self.knowledge_synthesizer = KnowledgeSynthesizer(rag_agent)
        self.answer_generator = AnswerGenerator(rag_agent)
        self.fact_checker = FactChecker(rag_agent)
        self.memory_manager = MemoryManager()

        # Build the LangGraph workflow if available
        self.workflow = self._build_workflow() if LANGGRAPH_AVAILABLE else None

    def _build_workflow(self) -> Any:
        """Build and compile the agent workflow using LangGraph.

        Only called when LANGGRAPH_AVAILABLE is true. The return annotation is
        deliberately not ``StateGraph``. Annotations are evaluated when the
        class body runs, so naming ``StateGraph`` raised NameError at import
        time whenever LangGraph was missing, which made the non-LangGraph
        fallback unreachable. The value returned is the result of
        ``StateGraph.compile()``.
        """
        workflow = StateGraph(AgentState)

        # Add nodes for each agent
        workflow.add_node("analyze_query", self._analyze_query_node)
        workflow.add_node("retrieve_info", self._retrieve_info_node)
        workflow.add_node("synthesize_knowledge", self._synthesize_knowledge_node)
        workflow.add_node("generate_answer", self._generate_answer_node)
        workflow.add_node("fact_check", self._fact_check_node)
        workflow.add_node("finalize", self._finalize_node)

        workflow.set_entry_point("analyze_query")

        # Every route needs retrieved context. Previously the 'complex' route
        # skipped retrieval, and the 'factual' route skipped both retrieval and
        # generation, so factual questions always got an empty answer.
        workflow.add_edge("analyze_query", "retrieve_info")

        # Only high-complexity questions get the extra synthesis step. Both
        # 'factual' and 'simple' go straight to generation; the unconditional
        # generate_answer -> fact_check edge below verifies every answer.
        workflow.add_conditional_edges(
            "retrieve_info",
            self._route_after_analysis,
            {
                "simple": "generate_answer",
                "complex": "synthesize_knowledge",
                "factual": "generate_answer"
            }
        )

        workflow.add_edge("synthesize_knowledge", "generate_answer")
        workflow.add_edge("generate_answer", "fact_check")
        workflow.add_edge("fact_check", "finalize")

        workflow.set_finish_point("finalize")

        return workflow.compile()

    def _analyze_query_node(self, state: AgentState) -> AgentState:
        """Analyze the incoming query and record its type and entities."""
        analysis = self.query_analyzer.analyze(state['question'])
        return {
            **state,
            'question_type': analysis['question_type'],
            'entities': analysis['entities']
        }

    def _retrieve_info_node(self, state: AgentState) -> AgentState:
        """Retrieve relevant information and format it as context."""
        # Derive the strategy from the analysis already stored in the state
        # instead of re-running analyze(), which repeats the RAGAgent
        # question-classification call.
        strategy = self.query_analyzer._choose_retrieval_strategy(
            state['question_type'],
            state.get('entities', [])
        )
        search_results = self.retriever.retrieve(state['question'], strategy=strategy)

        context = self.rag_agent.format_context(search_results, question=state['question'])

        return {
            **state,
            'search_results': search_results,
            'context': context
        }

    def _synthesize_knowledge_node(self, state: AgentState) -> AgentState:
        """Synthesize knowledge from the retrieved results."""
        synthesis = self.knowledge_synthesizer.synthesize(
            state['search_results'],
            state['question']
        )

        return {
            **state,
            'distilled_insights': synthesis['distilled_insights'],
            'entity_hints': synthesis['entity_hints']
        }

    def _generate_answer_node(self, state: AgentState) -> AgentState:
        """Generate the answer from context and any synthesized insights."""
        answer = self.answer_generator.generate(
            context=state['context'],
            question=state['question'],
            insights={
                'distilled_insights': state.get('distilled_insights', []),
                'entity_hints': state.get('entity_hints', [])
            },
            question_type=state['question_type'],
            search_results=state['search_results']
        )

        return {
            **state,
            'answer': answer
        }

    def _fact_check_node(self, state: AgentState) -> AgentState:
        """Verify the generated answer and record a confidence score."""
        verification = self.fact_checker.check_claims(
            state['answer'],
            state['search_results']
        )

        return {
            **state,
            'confidence': verification['overall_confidence'],
            'needs_fact_check': verification['overall_confidence'] < 0.7
        }

    def _finalize_node(self, state: AgentState) -> AgentState:
        """Attach source summaries, record a memory note, publish the answer."""
        sources = self.rag_agent.prepare_source_summaries(state['search_results'])

        # Add to memory for future context
        self.memory_manager.add_to_memory(f"Q: {state['question'][:50]}... A: {state['answer'][:50]}...")

        return {
            **state,
            'sources': sources,
            'final_answer': state['answer']
        }

    def _route_after_analysis(self, state: AgentState) -> str:
        """Pick the next step ('complex' / 'factual' / 'simple') after retrieval.

        Uses the question type and entities stored by analyze_query.
        """
        complexity = self.query_analyzer._assess_complexity(
            state['question'],
            state.get('entities', [])
        )

        if complexity == 'high':
            return "complex"
        elif state['question_type'] == 'factual':
            return "factual"
        else:
            return "simple"

    def query(self, question: str) -> Dict[str, Any]:
        """Main entry point for queries.

        Returns:
            Dict with 'answer', 'sources' and 'confidence'.
        """
        if self.workflow:
            # Use LangGraph workflow
            initial_state = AgentState(
                question=question,
                question_type="",
                entities=[],
                search_results=[],
                distilled_insights=[],
                context="",
                answer="",
                sources=[],
                memory=[],
                entity_hints=[],
                confidence=0.0,
                needs_fact_check=False,
                final_answer=""
            )

            result = self.workflow.invoke(initial_state)
            return {
                'answer': result['final_answer'],
                'sources': result['sources'],
                'confidence': result['confidence']
            }
        else:
            # Fallback to simplified pipeline
            return self._simple_query(question)

    def _simple_query(self, question: str) -> Dict[str, Any]:
        """Simplified query processing without LangGraph."""
        # Analyze query
        analysis = self.query_analyzer.analyze(question)

        # Retrieve information
        search_results = self.retriever.retrieve(question, strategy=analysis['retrieval_strategy'])
        context = self.rag_agent.format_context(search_results, question=question)

        # Generate answer
        answer = self.rag_agent.generate_extraction_answer(context, question, search_results)
        sources = self.rag_agent.prepare_source_summaries(search_results)

        return {
            'answer': answer,
            'sources': sources,
            'confidence': 0.8  # Placeholder: this path does not fact-check
        }


# Integration with existing RAGSystem
def create_modular_orchestrator(rag_agent: RAGAgent, vector_store: VectorStore) -> ModularRAGOrchestrator:
    """Build a ModularRAGOrchestrator wired to ``rag_agent`` and ``vector_store``.

    Thin factory used as the integration point for the existing RAGSystem.
    """
    orchestrator = ModularRAGOrchestrator(rag_agent, vector_store)
    return orchestrator


if __name__ == "__main__":
    # Script entry point: print a short description of the module's design.
    _banner_lines = (
        "Modular RAG Orchestrator Design",
        "=" * 40,
        "This module provides a blueprint for breaking the monolithic RAGAgent",
        "into specialized sub-agents that can be orchestrated using LangGraph.",
        "\nKey Components:",
        "- QueryAnalyzer: Classifies questions and determines processing strategy",
        "- RetrieverAgent: Handles semantic/keyword/hybrid retrieval",
        "- KnowledgeSynthesizer: Combines insights from multiple sources",
        "- AnswerGenerator: Uses LLM with enhanced prompting",
        "- FactChecker: Validates claims against sources",
        "- MemoryManager: Maintains conversation context",
        "\nTo use: Install langgraph, then integrate with existing RAGSystem",
    )
    for _line in _banner_lines:
        print(_line)
