import json
import time
from pathlib import Path
from pdf_processor import PDFProcessor
from vector_store import VectorStore
from mistral_integration import RAGAgent
from config import PDF_DIR, DOCUMENTS_DIR, DEFAULT_RETRIEVAL_K, MIN_CHUNK_SCORE
from answer_synthesis import enhance_retrieved_context, synthesize_answer, strip_reasoning
from exceptions import (
    RAGSystemError,
    VectorStoreError,
    ModelError,
    DocumentProcessingError,
    ValidationError
)
from logging_config import get_logger, log_performance, log_error_with_context

class RAGSystem:
    """Retrieval-augmented question answering over locally processed PDFs.

    Wires together a ``PDFProcessor`` (produces document chunks from PDFs), a
    ``VectorStore`` (indexes and searches chunks) and a lazily created
    ``RAGAgent`` (formats context, generates answers, summarizes sources).
    """

    # How many top results to keep when no retrieved chunk reaches
    # MIN_CHUNK_SCORE, so a weak-but-nonempty retrieval still gets an answer.
    _FALLBACK_CHUNK_COUNT = 3

    def __init__(self):
        self.logger = get_logger(__name__)
        self.pdf_processor = PDFProcessor()
        self.vector_store = VectorStore()
        self.rag_agent = None  # Created on demand by initialize_model()
        self._initialized = False  # Track initialization state

    def initialize_model(self):
        """Create the RAG agent on the first call; later calls are no-ops."""
        # Per the original note, RAGAgent calls a hosted (Grok) API instead of
        # loading a local model - TODO confirm against mistral_integration.
        # Creation is still deferred until an answer or rebuild needs it.
        if self.rag_agent is None:
            self.rag_agent = RAGAgent()

    def setup_system(self):
        """Load the persisted index, ingest any new PDFs and report the size.

        New chunks are indexed raw; enrichment into knowledge cards is
        triggered separately (via the UI button).

        Returns:
            int: number of documents held by the vector store after setup.

        Raises:
            Exception: any failure is logged with context and re-raised.
        """
        self.logger.info("RAG System initialization started")
        start_time = time.time()

        try:
            self.vector_store.load_existing_index()
            # May replace self.vector_store with one rebuilt from stored chunks.
            self.ensure_enriched_index()

            new_chunks = self.pdf_processor.process_all_new()
            new_chunk_count = len(new_chunks) if new_chunks else 0

            if new_chunks:
                self.logger.info(f"Processing {new_chunk_count} new document chunks")
                self.initialize_model()
                self.vector_store.add_documents(new_chunks)
                self.vector_store.save_index()
                self.logger.info(f"Successfully processed {new_chunk_count} raw chunks (enrichment available via UI button)")

            total_docs = len(self.vector_store.documents) if self.vector_store.documents else 0

            setup_duration = time.time() - start_time
            log_performance(self.logger, "system_setup", setup_duration,
                            total_documents=total_docs,
                            new_chunks_processed=new_chunk_count)

            self.logger.info(f"RAG System initialized successfully with {total_docs} documents")
            return total_docs

        except Exception as e:
            log_error_with_context(self.logger, e, {"operation": "setup_system"}, "Failed to initialize RAG system")
            raise

    @staticmethod
    def _reply(answer):
        """Build the response dict used for early exits and errors (no sources)."""
        return {'answer': answer, 'sources': []}

    @staticmethod
    def _as_float(value, default=0.0):
        """Coerce a score-like value to float; None or non-numeric gives *default*.

        Used so that malformed scores in search results cannot make a logging
        or filtering line raise and abort an otherwise successful query.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @classmethod
    def _select_chunks(cls, results, min_score, fallback_count=3):
        """Keep the results whose ``'score'`` is at least *min_score*.

        Returns ``(selected, met_threshold)``. When no result reaches the
        threshold, the first *fallback_count* results are returned with
        ``met_threshold`` False (the results are presumably ranked best-first
        by the vector store). Empty input gives ``([], False)``. A missing or
        non-numeric score counts as 0.0.
        """
        passing = [r for r in results if cls._as_float(r.get('score')) >= min_score]
        if passing:
            return passing, True
        return results[:fallback_count], False

    def _log_retrieved_chunks(self, search_results, limit=10):
        """Log source and score breakdown for the first *limit* retrieved chunks."""
        self.logger.debug(f"Found {len(search_results)} search results")
        self.logger.info(f"Retrieved {len(search_results)} chunks with scores:")
        for i, result in enumerate(search_results[:limit], 1):
            metadata = result.get('metadata') or {}
            source = str(metadata.get('source', 'Unknown'))
            score = self._as_float(result.get('score'))
            vector_score = self._as_float(result.get('vector_score'))
            keyword_score = self._as_float(result.get('keyword_score'))
            self.logger.info(f"  Chunk {i}: {source[:50]} - Combined: {score:.3f} (Vector: {vector_score:.3f}, Keyword: {keyword_score:.3f})")

    def _synthesize(self, question, search_results, answer):
        """Run the 3-layer synthesis pass, falling back to the agent's answer.

        Returns:
            tuple: ``(final_answer, enhanced)`` where *enhanced* is True only
            when synthesis succeeded and produced a non-empty answer.
        """
        fallback = strip_reasoning(answer) if answer else answer
        try:
            self.logger.info("Applying 3-layer answer synthesis...")
            enhanced_context = enhance_retrieved_context(question, search_results)
            synthesized_answer = synthesize_answer(question, enhanced_context)
            # Strip any remaining reasoning from the answer
            final_answer = strip_reasoning(synthesized_answer)
        except Exception as e:
            self.logger.warning(f"Synthesis failed, using original answer: {e}")
            return fallback, False

        if not final_answer:
            self.logger.warning("Synthesis returned an empty answer, using original answer")
            return fallback, False

        # Kept outside the try: a None agent answer must not turn a successful
        # synthesis into a fallback (len(None) used to raise here).
        original_length = len(answer) if answer else 0
        self.logger.info(f"Synthesis complete. Original: {original_length} chars, Enhanced: {len(final_answer)} chars")
        return final_answer, True

    def search_and_answer(self, question, k=None):
        """Retrieve relevant chunks for *question* and generate an answer.

        Args:
            question: the user's question; must be a non-blank ``str``.
            k: number of chunks to retrieve; defaults to ``DEFAULT_RETRIEVAL_K``.

        Returns:
            dict: always contains ``'answer'`` and ``'sources'``. After a
            successful generation it also contains ``'enhanced'``: True when
            the synthesis pass produced the answer, False when it failed and
            the agent's own answer was returned instead.

        Errors are not raised to the caller. They are logged and reported as
        an explanatory ``'answer'`` with empty ``'sources'``.
        """
        start_time = time.time()
        question_preview = str(question)[:100] if question is not None else "None"
        self.logger.info(f"Processing question: {question_preview}{'...' if question and len(str(question)) > 100 else ''}")

        # Validate the query before touching the store or the model.
        if not question:
            self.logger.warning("Empty question provided")
            return self._reply("Please provide a question.")
        if not isinstance(question, str):
            self.logger.warning(f"Invalid question type: {type(question)}")
            return self._reply("Question must be a string.")
        question = question.strip()
        if not question:
            self.logger.warning("Question is only whitespace")
            return self._reply("Please provide a meaningful question.")

        if not self.vector_store.documents:
            self.logger.warning("Search attempted but no documents available")
            return self._reply("No documents available in the system. Please add some PDFs first.")

        try:
            self.initialize_model()
        except Exception as e:
            log_error_with_context(self.logger, e, {"operation": "model_initialization", "question": question[:100]})
            return self._reply(f"Error initializing model: {str(e)}")

        if k is None:
            k = DEFAULT_RETRIEVAL_K

        # Search, log what came back, then drop chunks below the quality bar.
        try:
            # A None result from the store is treated as "no matches".
            search_results = self.vector_store.search(question, k=k) or []
            self._log_retrieved_chunks(search_results)

            selected, met_threshold = self._select_chunks(
                search_results, MIN_CHUNK_SCORE, self._FALLBACK_CHUNK_COUNT)
            if met_threshold:
                self.logger.info(f"Filtered to {len(selected)} high-quality chunks (score >= {MIN_CHUNK_SCORE})")
                score_list = [f"{self._as_float(r.get('score')):.3f}" for r in selected[:5]]
                self.logger.info(f"Filtered chunk scores: {score_list}")
            elif selected:
                # Every chunk is below threshold: keep the top few anyway but warn.
                self.logger.warning(f"All chunks below threshold {MIN_CHUNK_SCORE}, using top {self._FALLBACK_CHUNK_COUNT} anyway")
                top_scores = [f"{self._as_float(r.get('score')):.3f}" for r in selected]
                self.logger.warning(f"Top {self._FALLBACK_CHUNK_COUNT} scores: {top_scores}")
            search_results = selected
        except ValidationError as e:
            self.logger.warning(f"Query validation failed: {e.message}")
            return self._reply(f"Invalid search query: {e.message}")
        except VectorStoreError as e:
            log_error_with_context(self.logger, e, {"operation": "vector_search", "question": question[:100]})
            return self._reply(f"Vector store error: {e.message}")
        except Exception as e:
            log_error_with_context(self.logger, e, {"operation": "search", "question": question[:100]})
            return self._reply(f"Unexpected search error: {str(e)}")

        if not search_results:
            self.logger.info(f"No relevant documents found for question: {question[:100]}")
            return self._reply("No relevant documents found for your question.")

        # Format context, generate the answer, then refine it via synthesis.
        try:
            self.logger.debug(f"Formatting context from {len(search_results)} search results")
            context = self.rag_agent.format_context(search_results, question=question)
            self.logger.debug(f"Context length: {len(context)} characters")
            answer = self.rag_agent.generate_answer(context, question, search_results)
            sources = self.rag_agent.prepare_source_summaries(search_results)
            self.logger.info(f"Generated answer with {len(sources)} sources")
            for i, source in enumerate(sources[:5], 1):
                source_name = str(source.get('source', 'Unknown'))
                self.logger.info(f"  Source {i}: {source_name[:50]} - Score: {self._as_float(source.get('score')):.3f}")

            final_answer, enhanced = self._synthesize(question, search_results, answer)

            # Measured after synthesis so the reported time covers the full answer.
            duration = time.time() - start_time
            log_performance(self.logger, "question_answer", duration,
                            question_length=len(question),
                            results_count=len(search_results),
                            answer_length=len(answer) if answer else 0)

            self.logger.info(f"Successfully answered question in {duration:.2f}s")
            return {
                'answer': final_answer,
                'sources': sources,
                'enhanced': enhanced
            }

        except Exception as e:
            log_error_with_context(self.logger, e, {"operation": "answer_generation", "question": question[:100]})
            return self._reply(f"Error generating answer: {str(e)}")

    def ensure_enriched_index(self):
        """Rebuild the vector store from stored ``*_chunks.json`` files when empty.

        Despite the name, this never enriches chunks itself. A non-empty store
        (enriched or not) is left untouched for faster startup, and a rebuild
        indexes raw chunks. Enrichment is triggered separately (via the UI).

        Raises:
            VectorStoreError, Exception: re-raised from a failed rebuild.
        """
        documents = self.vector_store.documents
        if documents:
            # Enriched knowledge cards carry 'summary' or 'card_id' in metadata;
            # only the first document is sampled.
            try:
                sample_metadata = documents[0].get('metadata') or {}
            except (IndexError, KeyError, AttributeError, TypeError) as e:
                self.logger.debug(f"Could not access sample metadata: {e}")
                sample_metadata = {}
            if not isinstance(sample_metadata, dict):
                sample_metadata = {}
            if not (sample_metadata.get('summary') or sample_metadata.get('card_id')):
                # Raw chunks work fine; enrichment is optional and skipped here.
                self.logger.info(f"Vector store has {len(documents)} documents (not enriched). Skipping enrichment for faster startup.")
            return

        # Only rebuild if the vector store is empty. Sorted for a deterministic
        # index order regardless of filesystem listing order.
        chunk_files = sorted(Path(DOCUMENTS_DIR).glob("*_chunks.json"))
        if not chunk_files:
            return

        print("Rebuilding knowledge index from stored chunks...")
        try:
            self._rebuild_vector_store_from_chunks(chunk_files)
        except VectorStoreError as e:
            print(f"Vector store rebuild failed: {e.message}")
            raise
        except Exception as e:
            print(f"Unexpected error during rebuild: {e}")
            raise

    def _rebuild_vector_store_from_chunks(self, chunk_files):
        """Replace the vector store with one built from the raw chunks in *chunk_files*.

        Each file is expected to hold a JSON list of chunks. Unreadable files
        and non-list payloads are skipped with a warning. Nothing changes when
        no chunks were loaded. Indexing errors are printed and re-raised.
        """
        all_chunks = []
        for chunk_file in chunk_files:
            try:
                # Explicit UTF-8 rather than the platform default encoding;
                # assumes the chunk files are written as UTF-8 (or ASCII-escaped
                # JSON) - TODO confirm against PDFProcessor.
                with open(chunk_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except Exception as e:
                print(f"Warning: Failed to load chunk file {chunk_file}: {e}")
                continue
            if isinstance(data, list):
                all_chunks.extend(data)
            else:
                print(f"Warning: Skipping chunk file {chunk_file}: expected a JSON list, got {type(data).__name__}")

        if not all_chunks:
            print("No chunk data available for rebuilding.")
            return

        print(f"   • Loaded {len(all_chunks)} stored chunks")

        # Reset the vector store and index RAW chunks only - enrichment
        # happens manually via the UI button.
        try:
            self.vector_store = VectorStore()
            self.initialize_model()
            self.vector_store.add_documents(all_chunks)
            self.vector_store.save_index()
            print(f"   • Rebuilt vector store with {len(all_chunks)} raw chunks (enrichment available via UI button)")
        except VectorStoreError as e:
            print(f"   • Vector store error: {e.message}")
            raise
        except ModelError as e:
            print(f"   • Model error: {e.message}")
            raise
        except DocumentProcessingError as e:
            print(f"   • Document processing error: {e.message}")
            raise
        except Exception as e:
            print(f"   • Unexpected error rebuilding vector store: {e}")
            raise

def main():
    """Run an interactive command-line Q&A loop for testing the RAG system.

    Sets the system up, then repeatedly reads a question from stdin and prints
    the answer and its sources. The loop ends when the user types
    quit/exit/q, presses Ctrl-C, or stdin reaches end-of-file.
    """
    rag_system = RAGSystem()

    # Setup system
    doc_count = rag_system.setup_system()

    if doc_count == 0:
        print("No documents available. Please add PDFs to the pdf_directory folder.")
        return

    print(f"\nTotal documents in vector store: {doc_count}")

    # Interactive Q&A loop
    while True:
        try:
            question = input("\nEnter your question (or 'quit' to exit): ").strip()

            if question.lower() in ('quit', 'exit', 'q'):
                break

            if not question:
                continue

            print("Searching and generating answer...")
            result = rag_system.search_and_answer(question)

            print(f"\nAnswer: {result['answer']}\n")
            if result['sources']:
                print("Sources:")
                for i, source in enumerate(result['sources'], 1):
                    # .get: a source summary missing a field should not hide
                    # the remaining sources behind a generic error.
                    print(f"{i}. {source.get('source', 'Unknown')} (Score: {source.get('score', 0.0):.3f})")
                    print(f"   Preview: {source.get('preview', '')}\n")

        except (KeyboardInterrupt, EOFError):
            # EOFError (Ctrl-D or exhausted piped stdin) must end the loop. It
            # is an Exception subclass, so the handler below would otherwise
            # catch it and input() would fail again forever.
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")

if __name__ == "__main__":
    # Script entry point: start the interactive command-line loop.
    main()
