from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import re
import math
import logging
import threading
from pathlib import Path
from rapidfuzz import fuzz
from typing import List, Dict, Any, Optional, Tuple
from airagagent.config import (
    EMBEDDINGS_DIR,
    HYBRID_SEARCH_OVERSAMPLE,
    HYBRID_VECTOR_WEIGHT,
    HYBRID_KEYWORD_WEIGHT,
    HYBRID_MIN_KEYWORD_SCORE,
    HYBRID_BM25_WEIGHT,
    MAX_CHUNKS_PER_SOURCE,
    TECHNICAL_TERMS,
    SYNONYMS,
    TRADING_KEYWORD_TAXONOMY,
    TRADING_SYNONYMS,
    DOMAIN_KEYWORD_MAPPING
)
from airagagent.keyword_manager import get_keyword_manager
from airagagent.exceptions import (
    VectorStoreError,
    IndexLoadError,
    IndexSaveError,
    EmbeddingError,
    ValidationError,
    ResourceError
)

# Stopwords that are never promoted from the query into keyword patterns
QUERY_STOPWORDS = {
    'that', 'this', 'with', 'have', 'from', 'they', 'were', 'their', 'there',
    'which', 'about', 'into', 'through', 'after', 'before', 'where', 'when',
    'because', 'between', 'while', 'within', 'upon', 'these', 'those', 'many',
    'much', 'other', 'such', 'also', 'very', 'than', 'them', 'some', 'only',
    'even', 'most', 'more', 'like', 'been', 'being', 'shall', 'will', 'could',
    'would', 'should', 'might', 'your', 'onto', 'both', 'each',
    'summarize', 'summary', 'summaries', 'steps', 'guide', 'guides', 'information',
    'detail', 'details', 'explain', 'describe', 'please', 'help', 'answer',
    'question', 'regarding', 'related', 'tell', 'show', 'based', 'according',
    'provide', 'gimme', 'give', 'need', 'looking', 'want', 'make'
}

GENERIC_DOMAIN_TERMS = {
    'cannabis', 'marijuana', 'weed', 'hemp', 'pot', 'ganja',
    'plant', 'plants', 'grow', 'growing', 'cultivation', 'cultivate',
    'harvest', 'buds', 'bud', 'flower', 'flowers', 'guide', 'guides',
    'book', 'books', 'document', 'documents', 'manual', 'manuals', 'g', 'gram', 'grams'
}

BM25_K1 = 1.5
BM25_B = 0.75
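# Okapi BM25 parameters (used by _compute_bm25_score): K1 controls term-frequency
# saturation and B controls document-length normalization.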

CANONICAL_TOPIC_KEYWORDS = {
    'hydroponics': [
        'hydroponic', 'hydroponics', 'hydro',
        'soil-less', 'soil less', 'soilless',
        'water culture', 'nutrient solution', 'nutrient tank',
        'ebb and flow', 'reservoir', 'deep water culture', 'dwc',
        'drip system', 'flood and drain', 'soilless mix', 'soilless mixture'
    ],
    'soilless growing': [
        'soil-less', 'soil less', 'soilless mix', 'soilless mixture',
        'perlite mix', 'coco coir', 'rockwool', 'expanded clay', 'grow cube'
    ],
    'medical': [
        'medical', 'medicine', 'medicinal', 'patient', 'doctor', 'physician',
        'therapeutic', 'therapy', 'treatment', 'dosage', 'dose', 'clinical',
        'symptom', 'condition', 'diagnosis', 'prescription', 'wellness',
        'health', 'healthcare', 'pain relief', 'anxiety relief', 'anti-inflammatory'
    ],
    'finance': [
        'finance', 'financial', 'funding', 'capital', 'loan', 'credit', 'debt',
        'investment', 'investor', 'roi', 'return on investment', 'cash flow',
        'profit', 'profitability', 'revenue', 'income', 'margin', 'budget'
    ],
    'legal': [
        'legal', 'law', 'laws', 'statute', 'regulation', 'regulatory',
        'compliance', 'license', 'licensing', 'permit', 'permitting',
        'decriminalize', 'legalization', 'court', 'policy', 'ordinance'
    ],
    'business': [
        'business', 'entrepreneur', 'entrepreneurship', 'startup', 'company',
        'corporation', 'market', 'marketing', 'sales', 'branding', 'franchise',
        'supply chain', 'wholesale', 'retail', 'customer', 'client',
        'strategy', 'scaling', 'operations', 'management'
    ],
    'psychology': [
        'mental health', 'psychology', 'therapy', 'counseling', 'wellbeing',
        'stress', 'anxiety', 'ptsd', 'depression', 'mindfulness', 'meditation'
    ],
    'history': [
        'history', 'historical', 'chronicle', 'timeline', 'era', 'century', 'ancient',
        'modern history', 'historian', 'civilization', 'heritage', 'legacy',
        'anthropology', 'archaeology', 'culture', 'cultural history', 'records'
    ]
}

class VectorStore:
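    """Hybrid retrieval store combining FAISS dense vectors with keyword, fuzzy, and BM25 re-ranking."""
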
    def __init__(self, model_name=None):
        from airagagent.config import EMBEDDING_MODEL

        # Initialize keyword manager for advanced keyword processing
        self.keyword_manager = get_keyword_manager()
        # Logger used by the enhanced/keyword-focused search paths
        self.logger = logging.getLogger(__name__)
        if model_name is None:
            model_name = EMBEDDING_MODEL
        self.model = SentenceTransformer(model_name)
        self.index = None
        self.documents = []
        self.searchable_texts_raw = []
        self.searchable_texts_lower = []
        self.doc_lengths = []
        self.avg_doc_length = 0.0
        self.index_file = EMBEDDINGS_DIR / "faiss_index.index"
        self.metadata_file = EMBEDDINGS_DIR / "documents_metadata.json"

        # Thread safety: Lock for all write operations
        self._lock = threading.Lock()

    def load_existing_index(self):
        """Load existing FAISS index and metadata"""
        if not self.index_file.exists():
            return False

        if self.metadata_file.exists():
            try:
                self.index = faiss.read_index(str(self.index_file))
                with open(self.metadata_file, 'r', encoding='utf-8') as f:
                    self.documents = json.load(f)
                self._prepare_searchable_texts()
                print(f"✓ Loaded existing index with {len(self.documents)} documents")
                return True
            except json.JSONDecodeError as e:
                raise IndexLoadError(
                    str(self.metadata_file),
                    {"error_type": "json_decode", "original_error": str(e)}
                )
            except Exception as e:
                raise IndexLoadError(
                    str(self.index_file),
                    {"error_type": "faiss_load", "original_error": str(e)}
                )

        # Backward compatibility: try loading old pickle file
        old_pickle_file = EMBEDDINGS_DIR / "documents_metadata.pkl"
        if old_pickle_file.exists():
            try:
                print("Migrating from pickle to JSON format...")
                self.index = faiss.read_index(str(self.index_file))
                import pickle
                with open(old_pickle_file, 'rb') as f:
                    self.documents = pickle.load(f)
                # Save in new format
                self.save_metadata()
                self._prepare_searchable_texts()
                # Remove old file
                old_pickle_file.unlink()
                print(f"✓ Migrated index with {len(self.documents)} documents to JSON format")
                return True
            except Exception as e:
                raise IndexLoadError(
                    str(old_pickle_file),
                    {"error_type": "pickle_migration", "original_error": str(e)}
                )

        return False

    def add_documents(self, chunks):
        """Add new documents to the vector store"""
        with self._lock:
            if not chunks:
                return

            if not isinstance(chunks, list):
                raise ValidationError("Chunks must be a list", {"received_type": type(chunks)})

            # Validate chunk structure
            for i, chunk in enumerate(chunks):
                if not isinstance(chunk, dict):
                    raise ValidationError(f"Chunk {i} must be a dictionary", {"chunk_index": i, "chunk_type": type(chunk)})
                if 'content' not in chunk:
                    raise ValidationError(f"Chunk {i} missing 'content' key", {"chunk_index": i})

            contents = [chunk['content'] for chunk in chunks]

            try:
                embeddings = self.model.encode(
                    contents,
                    batch_size=32,
                    show_progress_bar=True,
                    convert_to_numpy=True,
                    normalize_embeddings=False
                )
            except Exception as e:
                # Get a preview of the problematic content
                problematic_content = (contents[0][:200] + "...") if contents else ""
                raise EmbeddingError(problematic_content, {"original_error": str(e), "batch_size": len(contents)})

            if self.index is None:
                # Create new index
                self.index = faiss.IndexFlatIP(embeddings.shape[1])
                self.documents = []

            # Add to index
            embeddings_np = np.array(embeddings).astype('float32')
            faiss.normalize_L2(embeddings_np)  # Normalize for cosine similarity
            self.index.add(embeddings_np)

            # Store document metadata
            self.documents.extend(chunks)
            # Update searchable texts cache with new documents
            for chunk in chunks:
                searchable_text = self._build_searchable_text(chunk)
                self.searchable_texts_raw.append(searchable_text)
                self.searchable_texts_lower.append(searchable_text.lower())

            print(f"✓ Added {len(chunks)} chunks to vector store")

    def save_index(self):
        """Save the FAISS index and metadata"""
        with self._lock:
            if self.index is None:
                raise IndexSaveError("No index to save - index is None", {"index_state": "none"})

            try:
                faiss.write_index(self.index, str(self.index_file))
                self.save_metadata()
                print("✓ Saved vector store index")
            except Exception as e:
                raise IndexSaveError(str(self.index_file), {"original_error": str(e)})

    def save_metadata(self):
        """Save document metadata to JSON file"""
        try:
            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)
        except Exception as e:
            raise IndexSaveError(str(self.metadata_file), {"error_type": "json_save", "original_error": str(e)})

    def replace_source_documents(self, source_name, new_chunks):
        """
        Safely replace all documents for a source with new chunks.
        This rebuilds the index to ensure consistency (prevents index corruption).

        Args:
            source_name: Name of the source file to replace
            new_chunks: List of new chunk dictionaries to replace old ones with
        """
        with self._lock:
            if not new_chunks:
                # If no new chunks, just remove old ones
                self._remove_source_from_index(source_name)
                return

            print(f"Replacing documents for source: {source_name} ({len(new_chunks)} new chunks)")

            # Step 1: Filter out old documents for this source
            kept_documents = [
                doc for doc in self.documents
                if doc.get('metadata', {}).get('source') != source_name
            ]

            print(f"  Keeping {len(kept_documents)} documents from other sources")

            # Step 2: Combine with new documents
            updated_documents = kept_documents + new_chunks

            # Step 3: Rebuild index from scratch to ensure consistency
            # This is the safest approach - prevents index/document mismatch
            print(f"  Rebuilding index with {len(updated_documents)} total documents...")

            # Clear existing index
            if self.index is not None:
                dimension = self.index.d
            else:
                # Get dimension from model
                sample_embedding = self.model.encode(["test"])
                dimension = sample_embedding.shape[1]

            # Create new index
            self.index = faiss.IndexFlatIP(dimension)

            # Clear searchable text caches; document length stats are rebuilt lazily on the next search
            self.searchable_texts_raw = []
            self.searchable_texts_lower = []
            self.doc_lengths = []

            # Re-embed all documents in batches
            contents = [doc['content'] for doc in updated_documents]
            batch_size = 32

            for i in range(0, len(contents), batch_size):
                batch = contents[i : i + batch_size]
                try:
                    embeddings = self.model.encode(
                        batch,
                        batch_size=batch_size,
                        show_progress_bar=False,
                        convert_to_numpy=True,
                        normalize_embeddings=False
                    )
                    embeddings_np = np.array(embeddings).astype('float32')
                    faiss.normalize_L2(embeddings_np)
                    self.index.add(embeddings_np)

                    # Update searchable texts cache
                    for doc in updated_documents[i : i + batch_size]:
                        searchable_text = self._build_searchable_text(doc)
                        self.searchable_texts_raw.append(searchable_text)
                        self.searchable_texts_lower.append(searchable_text.lower())

                except Exception as e:
                    raise EmbeddingError(
                        f"Error embedding batch starting at {i}",
                        {"original_error": str(e), "batch_start": i}
                    )

            # Update documents list
            self.documents = updated_documents

            print(f"✓ Successfully replaced source '{source_name}': {len(new_chunks)} chunks")
            print(f"  Total documents in index: {self.index.ntotal}")
    
    def _remove_source_from_index(self, source_name):
        """Remove all documents for a source by rebuilding index without them."""
        with self._lock:
            kept_documents = [
                doc for doc in self.documents
                if doc.get('metadata', {}).get('source') != source_name
            ]

            if len(kept_documents) == len(self.documents):
                print(f"No documents found for source: {source_name}")
                return

            print(f"Removing {len(self.documents) - len(kept_documents)} documents for source: {source_name}")

            # Rebuild index with kept documents only
            if self.index is not None:
                dimension = self.index.d
            else:
                sample_embedding = self.model.encode(["test"])
                dimension = sample_embedding.shape[1]

            self.index = faiss.IndexFlatIP(dimension)
            self.searchable_texts_raw = []
            self.searchable_texts_lower = []
            self.doc_lengths = []

            contents = [doc['content'] for doc in kept_documents]
            batch_size = 32

            for i in range(0, len(contents), batch_size):
                batch = contents[i : i + batch_size]
                embeddings = self.model.encode(
                    batch,
                    batch_size=batch_size,
                    show_progress_bar=False,
                    convert_to_numpy=True,
                    normalize_embeddings=False
                )
                embeddings_np = np.array(embeddings).astype('float32')
                faiss.normalize_L2(embeddings_np)
                self.index.add(embeddings_np)

                for doc in kept_documents[i : i + batch_size]:
                    searchable_text = self._build_searchable_text(doc)
                    self.searchable_texts_raw.append(searchable_text)
                    self.searchable_texts_lower.append(searchable_text.lower())

            self.documents = kept_documents
            print(f"✓ Removed source '{source_name}'. Total documents: {len(self.documents)}")
    
    def mark_source_as_deleted(self, source_name):
        """
        DEPRECATED: Use replace_source_documents instead.
        This method only marks documents as deleted but doesn't remove vectors from index,
        which can cause index corruption. Kept for backward compatibility but not recommended.
        """
        count = 0
        for doc in self.documents:
            if doc.get('metadata', {}).get('source') == source_name:
                doc.setdefault('metadata', {})['deleted'] = True
                count += 1
        # Only save if changes were made
        if count > 0:
            self.save_metadata()
            print(f"⚠ Marked {count} documents from {source_name} as deleted (deprecated method - may cause index mismatch)")
        return count

    def search(self, query, k=5):
        """Search for similar documents using hybrid vector + keyword re-ranking."""
        if self.index is None:
            raise VectorStoreError("Index not loaded", {"index_state": "none"})
        if len(self.documents) == 0:
            return []

        if not query:
            raise ValidationError("Query cannot be empty", {"query_length": 0})
        if not isinstance(query, str):
            raise ValidationError("Query must be a string", {"query_type": type(query).__name__})
        if len(query.strip()) == 0:
            raise ValidationError("Query cannot be only whitespace", {"query_content": repr(query)})

        query_lower = query.lower()
        query_words = re.findall(r'\b\w+\b', query_lower)
        
        domain_terms = []
        primary_domain_terms = []
        general_terms = []
        
        expanded_terms = set()
        for word in query_words:
            if word in SYNONYMS:
                expanded_terms.update(SYNONYMS[word])
        
        for term in TECHNICAL_TERMS:
            if term in query_lower and term not in domain_terms:
                domain_terms.append(term)
                if term not in GENERIC_DOMAIN_TERMS:
                    primary_domain_terms.append(term)
        
        for term in expanded_terms:
            if term not in domain_terms:
                domain_terms.append(term)
                if term not in GENERIC_DOMAIN_TERMS:
                    primary_domain_terms.append(term)
        
        for word in query_words:
            if len(word) <= 3:
                continue
            if word in QUERY_STOPWORDS:
                continue
            if word in domain_terms or word in general_terms:
                continue
            general_terms.append(word)

        keyword_patterns = []
        added_terms = set()
        for term in domain_terms + general_terms:
            if term in added_terms:
                continue
            keyword_patterns.append({
                'term': term,
                'pattern': re.compile(r'\b' + re.escape(term) + r'\b', re.IGNORECASE),
                'is_domain': term in domain_terms,
                'is_primary': term in primary_domain_terms
            })
            added_terms.add(term)
        self._ensure_searchable_texts()
        idf_table = self._compute_idf_table(keyword_patterns) if keyword_patterns else {}

        try:
            query_embedding = self.model.encode([query])
            query_embedding = np.array(query_embedding).astype('float32')
            faiss.normalize_L2(query_embedding)

            oversample = max(k * HYBRID_SEARCH_OVERSAMPLE, k)
            scores, indices = self.index.search(query_embedding, oversample)
        except Exception as e:
            print(f"Error during search: {e}")
            return []

        candidates = []
        retrieved_indices = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing results with -1; skip those and any out-of-range ids
            if idx < 0 or idx >= len(self.documents):
                continue
            doc = self.documents[idx]
            # Skip deleted documents
            if doc.get('metadata', {}).get('deleted'):
                continue

            retrieved_indices.append(idx)

            searchable_text_lower = self.searchable_texts_lower[idx]
            searchable_text_raw = self.searchable_texts_raw[idx]
            
            # Initial keyword score from fuzzy matching
            keyword_score = 0.0
            if searchable_text_lower:
                keyword_score = fuzz.partial_ratio(query_lower, searchable_text_lower) / 100.0
            
            # Check for exact keyword matches in searchable text (whole word matching)
            # This prevents false matches like "hydro" in "tetrahydrocannabinol"
            matched_terms = []
            has_keyword_match = False
            keyword_snippet = None
            matched_domain = False
            matched_primary = False
            require_domain_match = any(p['is_domain'] for p in keyword_patterns)
            require_primary_match = any(p['is_primary'] for p in keyword_patterns)

            if keyword_patterns:
                for pattern_info in keyword_patterns:
                    term_pattern = pattern_info['pattern']
                    if term_pattern.search(searchable_text_raw):
                        matched_terms.append(pattern_info['term'])
                        if pattern_info['is_domain']:
                            matched_domain = True
                            if pattern_info.get('is_primary'):
                                matched_primary = True
                        keyword_score = min(1.0, keyword_score + 0.5)
                        has_keyword_match = True

                if len(matched_terms) > 1:
                    keyword_score = min(1.0, keyword_score + 0.2 * min(len(matched_terms) - 1, 3))

                keyword_snippet = self._extract_keyword_snippet(searchable_text_raw, keyword_patterns)
                if not keyword_snippet and has_keyword_match:
                    keyword_snippet = searchable_text_raw[:240]

                if has_keyword_match and not matched_domain and not require_domain_match:
                    keyword_score = min(keyword_score, 0.4)
            
            # Also check for exact word matches from query (not just keywords)
            # This helps catch terms like "hydroponics" even if not in keyword list
            if not matched_terms:  # Only if no keyword matches found
                for word in query_words:
                    if len(word) > 3:  # Only check words longer than 3 chars
                        word_pattern = re.compile(r'\b' + re.escape(word) + r'\b', re.IGNORECASE)
                        if word_pattern.search(searchable_text_raw):
                            # Boost for exact word matches
                            keyword_score = min(1.0, keyword_score + 0.3)  # Increased from 0.15 to 0.3
                            break

            if keyword_score < HYBRID_MIN_KEYWORD_SCORE:
                keyword_score = 0.0

            if keyword_patterns:
                if require_primary_match and not matched_primary:
                    continue
                if not require_primary_match and require_domain_match and not matched_domain:
                    continue
                if not require_domain_match and not has_keyword_match:
                    continue

            bm25_score = self._compute_bm25_score(idx, keyword_patterns, idf_table)

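            # Blend the three signals using the configurable hybrid weights:
            #   combined = HYBRID_VECTOR_WEIGHT * cosine + HYBRID_KEYWORD_WEIGHT * keyword + HYBRID_BM25_WEIGHT * bm25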
            combined_score = (
                HYBRID_VECTOR_WEIGHT * float(score) +
                HYBRID_KEYWORD_WEIGHT * keyword_score +
                HYBRID_BM25_WEIGHT * bm25_score
            )

            candidate = {
                'content': doc['content'],
                'metadata': doc['metadata'],
                'score': combined_score,
                'vector_score': float(score),
                'keyword_score': keyword_score,
                'bm25_score': bm25_score,
                'has_keyword_match': has_keyword_match,
                'keyword_snippet': keyword_snippet,
                'doc_index': idx
            }
            candidates.append(candidate)

        # Count how many candidates currently have keyword matches
        keyword_hit_count = sum(
            1 for candidate in candidates
            if candidate.get('has_keyword_match')
        )

        # Fallback: if user query includes strong keywords but none of the vector results matched them,
        # scan the full document set to pull in direct keyword hits.
        fallback_patterns = [p for p in keyword_patterns if p.get('is_primary')]
        if not fallback_patterns:
            fallback_patterns = [p for p in keyword_patterns if p['is_domain']]
        if not fallback_patterns:
            fallback_patterns = keyword_patterns

        if fallback_patterns and keyword_hit_count < k:
            needed = k - keyword_hit_count
            fallback_candidates = self._keyword_fallback_search(
                fallback_patterns,
                needed,
                exclude_indices=set(retrieved_indices),
                idf_table=idf_table
            )
            candidates.extend(fallback_candidates)

        candidates.sort(key=lambda item: item['score'], reverse=True)
        balanced_candidates = []
        source_counts = {}
        used_indices = set()
        for candidate in candidates:
            source = candidate.get('metadata', {}).get('source', 'Unknown')
            count = source_counts.get(source, 0)
            if count >= MAX_CHUNKS_PER_SOURCE:
                continue
            balanced_candidates.append(candidate)
            source_counts[source] = count + 1
            doc_idx = candidate.get('doc_index')
            if doc_idx is not None:
                used_indices.add(doc_idx)
            if len(balanced_candidates) >= k:
                break

        if fallback_patterns:
            existing_sources = {c.get('metadata', {}).get('source', 'Unknown') for c in balanced_candidates}
            needed_sources = max(0, k - len(existing_sources))
            if needed_sources > 0:
                extra_candidates = self._keyword_fallback_search(
                    fallback_patterns,
                    needed_sources,
                    exclude_indices=used_indices,
                    exclude_sources=existing_sources,
                    idf_table=idf_table
                )
                for extra_candidate in extra_candidates:
                    source = extra_candidate.get('metadata', {}).get('source', 'Unknown')
                    if source in existing_sources:
                        continue
                    balanced_candidates.append(extra_candidate)
                    existing_sources.add(source)
                    source_counts[source] = source_counts.get(source, 0) + 1
                    doc_idx = extra_candidate.get('doc_index')
                    if doc_idx is not None:
                        used_indices.add(doc_idx)
                    if len(balanced_candidates) >= k:
                        break

        if len(balanced_candidates) < k:
            # If we still need more, fill with remaining candidates regardless of source limit
            for candidate in candidates:
                if candidate in balanced_candidates:
                    continue
                balanced_candidates.append(candidate)
                if len(balanced_candidates) >= k:
                    break

        return balanced_candidates

    def enhanced_search(self, query: str, k: int = 5, domain: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Enhanced search using advanced keyword management and domain filtering.

        Args:
            query: Search query
            k: Number of results to return
            domain: Optional domain filter (sports, crypto, stocks, forex)

        Returns:
            List of search results with enhanced keyword analysis
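
        Example (illustrative; assumes `store` is a loaded VectorStore):
            results = store.enhanced_search("exchange listing requirements", k=5, domain="crypto")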
        """
        # Extract keywords from query using advanced keyword manager
        extracted_keywords = self.keyword_manager.extract_keywords(query, max_keywords=15)
        self.logger.info(f"Extracted {len(extracted_keywords)} keywords from query: {[kw['keyword'] for kw in extracted_keywords[:5]]}")

        # Expand query with synonyms and related terms
        expanded_terms = self.keyword_manager.expand_query(query, domain)
        self.logger.info(f"Expanded query to {len(expanded_terms)} terms")

        # Filter by domain if specified
        if domain:
            expanded_terms = self.keyword_manager.filter_by_domain(expanded_terms, domain)
            self.logger.info(f"Filtered to {len(expanded_terms)} domain-specific terms for {domain}")

        # Perform base search
        base_results = self.search(query, k * 2)  # Get more candidates for re-ranking

        # Enhance results with keyword analysis
        enhanced_results = []
        for result in base_results:
            content = result.get('content', '')
            metadata = result.get('metadata', {})

            # Extract keywords from this document
            doc_keywords = self.keyword_manager.extract_keywords(content, max_keywords=10)

            # Calculate keyword relevance scores
            keyword_scores = []
            for kw in extracted_keywords:
                # Check if keyword appears in document
                kw_lower = kw['keyword'].lower()
                if kw_lower in content.lower():
                    # Calculate position and context scores
                    position_score = self._calculate_keyword_position_score(kw_lower, content)
                    context_score = self._calculate_keyword_context_score(kw_lower, content, expanded_terms)

                    total_kw_score = (kw['score'] * 0.4 + position_score * 0.3 + context_score * 0.3)
                    keyword_scores.append({
                        'keyword': kw['keyword'],
                        'score': total_kw_score,
                        'domain': kw['domain'],
                        'frequency': content.lower().count(kw_lower)
                    })

            # Calculate document-level keyword score
            doc_keyword_score = sum(ks['score'] for ks in keyword_scores) / len(keyword_scores) if keyword_scores else 0

            # Boost original score with keyword relevance
            original_score = result.get('score', 0)
            enhanced_score = original_score * 0.7 + doc_keyword_score * 0.3

            # Add keyword analysis to result
            enhanced_result = result.copy()
            enhanced_result.update({
                'score': enhanced_score,
                'keyword_analysis': {
                    'matched_keywords': keyword_scores,
                    'doc_keywords': doc_keywords,
                    'keyword_score': doc_keyword_score,
                    'total_matches': len(keyword_scores)
                },
                'query_expansion': expanded_terms[:10],  # Top 10 expanded terms
                'domain_relevance': self._calculate_domain_relevance(content, domain) if domain else 1.0
            })

            enhanced_results.append(enhanced_result)

        # Re-sort by enhanced score
        enhanced_results.sort(key=lambda x: x['score'], reverse=True)

        # Apply domain filtering if specified
        if domain:
            enhanced_results = [r for r in enhanced_results if r.get('domain_relevance', 0) > 0.3]

        return enhanced_results[:k]

    def keyword_focused_search(self, keywords: List[str], k: int = 5, domain: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Search focused specifically on keyword matching with advanced ranking.

        Args:
            keywords: List of keywords to search for
            k: Number of results to return
            domain: Optional domain filter

        Returns:
            Keyword-focused search results
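
        Example (illustrative; assumes `store` is a loaded VectorStore):
            results = store.keyword_focused_search(["nutrient solution", "reservoir"], k=3)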
        """
        # Rank keywords by relevance
        ranked_keywords = self.keyword_manager.rank_keywords_by_relevance(keywords, ' '.join(keywords))
        self.logger.info(f"Ranked keywords: {[kw for kw, score in ranked_keywords[:5]]}")

        # Find similar keywords for expansion
        expanded_keywords = set(keywords)
        for keyword, _ in ranked_keywords[:5]:  # Top 5 keywords
            similar = self.keyword_manager.find_similar_keywords(keyword, top_k=3)
            expanded_keywords.update([kw for kw, score in similar if score > 0.7])

        # Build keyword patterns for search
        keyword_patterns = []
        for keyword in expanded_keywords:
            pattern = {
                'term': keyword,
                'pattern': re.compile(r'\b' + re.escape(keyword) + r'\b', re.IGNORECASE),
                'is_domain': (keyword in self.keyword_manager.get_domain_keywords(domain)) if domain else False,
                'is_primary': keyword in [kw for kw, _ in ranked_keywords[:3]]
            }
            keyword_patterns.append(pattern)

        # Use keyword fallback search as the primary retrieval method
        self._ensure_searchable_texts()
        idf_table = self._compute_idf_table(keyword_patterns)
        results = self._keyword_fallback_search(keyword_patterns, k * 2, idf_table=idf_table)

        # Enhance with keyword analysis
        enhanced_results = []
        for result in results:
            content = result.get('content', '')

            # Detailed keyword matching analysis
            keyword_matches = []
            for keyword in expanded_keywords:
                if keyword.lower() in content.lower():
                    matches = len(re.findall(r'\b' + re.escape(keyword) + r'\b', content, re.IGNORECASE))
                    if matches > 0:
                        relevance_score = self.keyword_manager.rank_keywords_by_relevance([keyword], content)[0][1]
                        keyword_matches.append({
                            'keyword': keyword,
                            'matches': matches,
                            'relevance': relevance_score
                        })

            # Calculate comprehensive keyword score
            keyword_score = sum(km['relevance'] * km['matches'] for km in keyword_matches) / len(keyword_matches) if keyword_matches else 0

            enhanced_result = result.copy()
            enhanced_result.update({
                'keyword_score': keyword_score,
                'keyword_matches': keyword_matches,
                'total_keyword_matches': len(keyword_matches),
                'expanded_keywords': list(expanded_keywords)
            })

            enhanced_results.append(enhanced_result)

        # Sort by keyword relevance
        enhanced_results.sort(key=lambda x: x.get('keyword_score', 0), reverse=True)

        return enhanced_results[:k]

    def _calculate_keyword_position_score(self, keyword: str, content: str) -> float:
        """Calculate score based on keyword position in document."""
        content_lower = content.lower()
        keyword_pos = content_lower.find(keyword)

        if keyword_pos < 0:
            return 0

        # Earlier keywords get higher scores
        position_score = 1.0 - (keyword_pos / len(content))

        # Bonus for keywords in first 20% of document
        if keyword_pos < len(content) * 0.2:
            position_score *= 1.5

        return min(position_score, 1.0)

    def _calculate_keyword_context_score(self, keyword: str, content: str, context_terms: List[str]) -> float:
        """Calculate score based on surrounding context terms."""
        content_lower = content.lower()
        keyword_pos = content_lower.find(keyword)

        if keyword_pos < 0:
            return 0

        # Extract context window around keyword
        window_size = 200
        start = max(0, keyword_pos - window_size)
        end = min(len(content_lower), keyword_pos + len(keyword) + window_size)
        context = content_lower[start:end]

        # Count context terms in the window
        context_matches = sum(1 for term in context_terms if term in context and term != keyword)

        # Normalize score
        context_score = min(context_matches / len(context_terms), 1.0) if context_terms else 0

        return context_score

    def _calculate_domain_relevance(self, content: str, domain: str) -> float:
        """Calculate how relevant a document is to a specific domain."""
        if not domain:
            return 1.0

        domain_keywords = self.keyword_manager.get_domain_keywords(domain)
        if not domain_keywords:
            return 0.5

        content_lower = content.lower()
        matches = sum(1 for kw in domain_keywords if kw.lower() in content_lower)

        # Calculate relevance score
        relevance = matches / len(domain_keywords)

        # Boost for multiple matches
        if matches >= 3:
            relevance *= 1.2

        return min(relevance, 1.0)

    def _build_searchable_text(self, doc):
        """Compose searchable text using enriched metadata + raw content."""
        metadata = doc.get('metadata', {})

        searchable_text_parts = []
        if metadata.get('summary'):
            searchable_text_parts.append(str(metadata.get('summary')))
        if metadata.get('key_points'):
            searchable_text_parts.extend([str(point) for point in metadata.get('key_points', [])])
        if metadata.get('themes'):
            searchable_text_parts.extend([str(theme) for theme in metadata.get('themes', [])])
        if metadata.get('clean_excerpt'):
            searchable_text_parts.append(str(metadata.get('clean_excerpt')))

        searchable_text_parts.append(str(doc.get('content', '')))
        base_text = ' '.join(searchable_text_parts)
        lower_text = base_text.lower()

        topics = metadata.setdefault('topics', [])
        for canonical, variants in CANONICAL_TOPIC_KEYWORDS.items():
            if canonical in topics:
                continue
            if any(variant in lower_text for variant in variants):
                topics.append(canonical)
                searchable_text_parts.append(canonical)

        if topics:
            distinct_topics = sorted(set(topics))
            metadata['topic_label'] = ', '.join(distinct_topics)
            if not metadata.get('_topic_tagged'):
                summary = metadata.get('summary', '')
                topic_prefix = f"[Topics: {metadata['topic_label']}] "
                if not summary.startswith(topic_prefix):
                    metadata['summary'] = f"{topic_prefix}{summary}".strip()
                metadata['_topic_tagged'] = True

        return ' '.join(searchable_text_parts)

    def _prepare_searchable_texts(self):
        """Build cached searchable text for all documents."""
        self.searchable_texts_raw = [
            self._build_searchable_text(doc)
            for doc in self.documents
        ]
        self.searchable_texts_lower = [
            text.lower() for text in self.searchable_texts_raw
        ]
        self.doc_lengths = [
            len(text.split()) for text in self.searchable_texts_raw
        ]
        if self.doc_lengths:
            self.avg_doc_length = sum(self.doc_lengths) / len(self.doc_lengths)
        else:
            self.avg_doc_length = 0.0

    def _ensure_searchable_texts(self):
        """Ensure searchable text caches and document length stats are in sync with documents."""
        if (
            len(self.searchable_texts_raw) != len(self.documents) or
            len(self.searchable_texts_lower) != len(self.documents) or
            len(self.doc_lengths) != len(self.documents)
        ):
            self._prepare_searchable_texts()

    def _compute_idf_table(self, keyword_patterns):
        """Compute IDF values for query terms using BM25 formula."""
        idf_table = {}
        total_docs = len(self.searchable_texts_lower)
        if total_docs == 0:
            return idf_table
        for info in keyword_patterns:
            term = info['term']
            if term in idf_table:
                continue
            pattern = info['pattern']
            df = 0
            for text in self.searchable_texts_lower:
                if pattern.search(text):
                    df += 1
            idf = math.log((total_docs - df + 0.5) / (df + 0.5) + 1.0)
            idf_table[term] = max(idf, 0.0)
        return idf_table

    def _compute_bm25_score(self, doc_index, keyword_patterns, idf_table):
        """Compute BM25 score for a document given keyword patterns."""
        if not keyword_patterns or not idf_table:
            return 0.0
        if doc_index >= len(self.searchable_texts_lower):
            return 0.0
        doc_text = self.searchable_texts_lower[doc_index]
        doc_len = self.doc_lengths[doc_index] if doc_index < len(self.doc_lengths) else len(doc_text.split())
        avg_len = self.avg_doc_length if self.avg_doc_length > 0 else doc_len or 1
        bm25 = 0.0
        for info in keyword_patterns:
            term = info['term']
            pattern = info['pattern']
            tf = len(pattern.findall(doc_text))
            if tf == 0:
                continue
            idf = idf_table.get(term, 0.0)
            numerator = tf * (BM25_K1 + 1.0)
            denominator = tf + BM25_K1 * (1.0 - BM25_B + BM25_B * (doc_len / avg_len))
            bm25 += idf * (numerator / denominator)
        return bm25

    def _keyword_fallback_search(self, keyword_patterns, needed, exclude_indices=None, exclude_sources=None, idf_table=None):
        """Scan all documents for keyword matches when vector search misses them."""
        fallback_candidates = []
        exclude_indices = exclude_indices or set()
        local_exclude_sources = set(exclude_sources) if exclude_sources else set()

        for idx, searchable_text_raw in enumerate(self.searchable_texts_raw):
            if idx in exclude_indices:
                continue

            doc = self.documents[idx]
            if doc.get('metadata', {}).get('deleted'):
                continue

            if any(info['pattern'].search(searchable_text_raw) for info in keyword_patterns):
                source = doc.get('metadata', {}).get('source', 'Unknown')
                if source in local_exclude_sources:
                    continue
                snippet = self._extract_keyword_snippet(searchable_text_raw, keyword_patterns)
                if not snippet:
                    continue
                keyword_score = 1.0
                bm25 = self._compute_bm25_score(idx, keyword_patterns, idf_table or {})
                combined = (
                    HYBRID_KEYWORD_WEIGHT * keyword_score +
                    HYBRID_BM25_WEIGHT * bm25
                )
                candidate = {
                    'content': doc['content'],
                    'metadata': doc['metadata'],
                    'score': combined,
                    'vector_score': 0.0,
                    'keyword_score': keyword_score,
                    'bm25_score': bm25,
                    'has_keyword_match': True,
                    'keyword_snippet': snippet,
                    'doc_index': idx
                }
                fallback_candidates.append(candidate)
                exclude_indices.add(idx)
                local_exclude_sources.add(source)

            if len(fallback_candidates) >= needed:
                break

        return fallback_candidates

    def _extract_keyword_snippet(self, raw_text, keyword_patterns, window: int = 240):
        """Extract a concise snippet around the first keyword match."""
        if not raw_text or not keyword_patterns:
            return None

        for info in keyword_patterns:
            match = info['pattern'].search(raw_text)
            if match:
                start = max(0, match.start() - window // 2)
                end = min(len(raw_text), match.end() + window // 2)
                snippet = raw_text[start:end]
                snippet = re.sub(r'\s+', ' ', snippet).strip()
                return snippet
        return None

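
# Usage sketch (illustrative only; assumes the airagagent package, its config values,
# and a writable EMBEDDINGS_DIR are available in this environment).
if __name__ == "__main__":
    store = VectorStore()
    if not store.load_existing_index():
        # Index a single hypothetical chunk so the search below has something to hit.
        store.add_documents([
            {
                "content": "Hydroponic systems suspend roots in an aerated nutrient solution.",
                "metadata": {"source": "example_hydro_guide.pdf"},
            }
        ])
        store.save_index()
    for hit in store.search("hydroponic nutrient solution", k=3):
        print(hit["metadata"].get("source", "Unknown"), round(hit["score"], 3))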