from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import re
import math
from pathlib import Path
from rapidfuzz import fuzz
from config import (
    EMBEDDINGS_DIR,
    HYBRID_SEARCH_OVERSAMPLE,
    HYBRID_VECTOR_WEIGHT,
    HYBRID_KEYWORD_WEIGHT,
    HYBRID_MIN_KEYWORD_SCORE,
    HYBRID_BM25_WEIGHT,
    MAX_CHUNKS_PER_SOURCE,
    TECHNICAL_TERMS,
    SYNONYMS
)
from exceptions import (
    VectorStoreError,
    IndexLoadError,
    IndexSaveError,
    EmbeddingError,
    ValidationError,
    ResourceError
)

# Stopword list for promoting query terms to keyword patterns: query words in
# this set are never turned into "general" keyword terms by VectorStore.search.
QUERY_STOPWORDS = {
    'that', 'this', 'with', 'have', 'from', 'they', 'were', 'their', 'there',
    'which', 'about', 'into', 'through', 'after', 'before', 'where', 'when',
    'because', 'between', 'while', 'within', 'upon', 'these', 'those', 'many',
    'much', 'other', 'such', 'also', 'very', 'than', 'them', 'some', 'only',
    'even', 'most', 'more', 'like', 'been', 'being', 'shall', 'will', 'could',
    'would', 'should', 'might', 'your', 'onto', 'both', 'each',
    'summarize', 'summary', 'summaries', 'steps', 'guide', 'guides', 'information',
    'detail', 'details', 'explain', 'describe', 'please', 'help', 'answer',
    'question', 'regarding', 'related', 'tell', 'show', 'based', 'according',
    'provide', 'gimme', 'give', 'need', 'looking', 'want', 'make'
}

# Domain terms too generic to act as "primary" terms: they still count as
# domain terms in VectorStore.search, but never force a primary-term match.
GENERIC_DOMAIN_TERMS = {
    'cannabis', 'marijuana', 'weed', 'hemp', 'pot', 'ganja',
    'plant', 'plants', 'grow', 'growing', 'cultivation', 'cultivate',
    'harvest', 'buds', 'bud', 'flower', 'flowers', 'guide', 'guides',
    'book', 'books', 'document', 'documents', 'manual', 'manuals', 'g', 'gram', 'grams'
}

# BM25 parameters: K1 controls term-frequency saturation, B controls
# document-length normalisation (commonly used defaults).
BM25_K1 = 1.5
BM25_B = 0.75

# Canonical topic -> variant phrases. A document whose searchable text mentions
# any variant is tagged with the canonical topic (VectorStore._build_searchable_text).
CANONICAL_TOPIC_KEYWORDS = {
    'hydroponics': [
        'hydroponic', 'hydroponics', 'hydro',
        'soil-less', 'soil less', 'soilless',
        'water culture', 'nutrient solution', 'nutrient tank',
        'ebb and flow', 'reservoir', 'deep water culture', 'dwc',
        'drip system', 'flood and drain', 'soilless mix', 'soilless mixture'
    ],
    'soilless growing': [
        'soil-less', 'soil less', 'soilless mix', 'soilless mixture',
        'perlite mix', 'coco coir', 'rockwool', 'expanded clay', 'grow cube'
    ],
    'medical': [
        'medical', 'medicine', 'medicinal', 'patient', 'doctor', 'physician',
        'therapeutic', 'therapy', 'treatment', 'dosage', 'dose', 'clinical',
        'symptom', 'condition', 'diagnosis', 'prescription', 'wellness',
        'health', 'healthcare', 'pain relief', 'anxiety relief', 'anti-inflammatory'
    ],
    'finance': [
        'finance', 'financial', 'funding', 'capital', 'loan', 'credit', 'debt',
        'investment', 'investor', 'roi', 'return on investment', 'cash flow',
        'profit', 'profitability', 'revenue', 'income', 'margin', 'budget'
    ],
    'legal': [
        'legal', 'law', 'laws', 'statute', 'regulation', 'regulatory',
        'compliance', 'license', 'licensing', 'permit', 'permitting',
        'decriminalize', 'legalization', 'court', 'policy', 'ordinance'
    ],
    'business': [
        'business', 'entrepreneur', 'entrepreneurship', 'startup', 'company',
        'corporation', 'market', 'marketing', 'sales', 'branding', 'franchise',
        'supply chain', 'wholesale', 'retail', 'customer', 'client',
        'strategy', 'scaling', 'operations', 'management'
    ],
    'psychology': [
        'mental health', 'psychology', 'therapy', 'counseling', 'wellbeing',
        'stress', 'anxiety', 'ptsd', 'depression', 'mindfulness', 'meditation'
    ],
    'history': [
        'history', 'historical', 'chronicle', 'timeline', 'era', 'century', 'ancient',
        'modern history', 'historian', 'civilization', 'heritage', 'legacy',
        'anthropology', 'archaeology', 'culture', 'cultural history', 'records'
    ]
}

class VectorStore:
    """FAISS-backed document store with hybrid vector + keyword + BM25 search.

    Each document is a dict with a ``'content'`` string and (usually) a
    ``'metadata'`` dict. Row ``i`` of ``self.index`` holds the embedding of
    ``self.documents[i]``. ``searchable_texts_raw``, ``searchable_texts_lower``
    and ``doc_lengths`` are per-document caches kept parallel to
    ``self.documents``; ``avg_doc_length`` is the mean of ``doc_lengths``.
    """

    # Number of texts embedded per ``model.encode`` call.
    _EMBED_BATCH_SIZE = 32

    # Lazily compiled topic patterns, built on first use by _get_topic_patterns.
    _topic_pattern_cache = None

    def __init__(self, model_name=None):
        """Create an empty store backed by a SentenceTransformer model.

        Args:
            model_name: model to load; defaults to ``config.EMBEDDING_MODEL``.
        """
        from config import EMBEDDING_MODEL
        if model_name is None:
            model_name = EMBEDDING_MODEL
        self.model = SentenceTransformer(model_name)
        self.index = None
        self.documents = []
        self.searchable_texts_raw = []
        self.searchable_texts_lower = []
        self.doc_lengths = []
        self.avg_doc_length = 0.0
        self.index_file = EMBEDDINGS_DIR / "faiss_index.index"
        self.metadata_file = EMBEDDINGS_DIR / "documents_metadata.json"

    def load_existing_index(self):
        """Load the FAISS index and its document metadata from disk.

        JSON metadata is preferred. If only the legacy pickle metadata file
        (same name with a ``.pkl`` suffix) exists, it is loaded, re-saved as
        JSON and then deleted.

        Returns:
            True if an index was loaded; False if the index file is missing or
            neither metadata file exists.

        Raises:
            IndexLoadError: if the index or metadata cannot be read or migrated.
        """
        if not self.index_file.exists():
            return False

        if self.metadata_file.exists():
            try:
                index = faiss.read_index(str(self.index_file))
                with open(self.metadata_file, 'r', encoding='utf-8') as f:
                    documents = json.load(f)
                # Commit state only after both files were read successfully.
                self.index = index
                self.documents = documents
                self._prepare_searchable_texts()
                print(f"✓ Loaded existing index with {len(self.documents)} documents")
                return True
            except json.JSONDecodeError as e:
                raise IndexLoadError(
                    str(self.metadata_file),
                    {"error_type": "json_decode", "original_error": str(e)}
                ) from e
            except Exception as e:
                raise IndexLoadError(
                    str(self.index_file),
                    {"error_type": "faiss_load", "original_error": str(e)}
                ) from e

        # Backward compatibility: migrate metadata from the old pickle format.
        old_pickle_file = self.metadata_file.with_suffix('.pkl')
        if old_pickle_file.exists():
            try:
                print("Migrating from pickle to JSON format...")
                index = faiss.read_index(str(self.index_file))
                import pickle
                # SECURITY: pickle.load can execute arbitrary code. Only ever
                # migrate a metadata file written by this application itself.
                with open(old_pickle_file, 'rb') as f:
                    documents = pickle.load(f)
                self.index = index
                self.documents = documents
                # Save in the new JSON format, then drop the pickle file.
                self.save_metadata()
                self._prepare_searchable_texts()
                old_pickle_file.unlink()
                print(f"✓ Migrated index with {len(self.documents)} documents to JSON format")
                return True
            except Exception as e:
                raise IndexLoadError(
                    str(old_pickle_file),
                    {"error_type": "pickle_migration", "original_error": str(e)}
                ) from e

        return False

    @staticmethod
    def _validate_chunks(chunks):
        """Raise ValidationError unless every chunk is a dict with a 'content' key."""
        for i, chunk in enumerate(chunks):
            if not isinstance(chunk, dict):
                raise ValidationError(f"Chunk {i} must be a dictionary", {"chunk_index": i, "chunk_type": type(chunk)})
            if 'content' not in chunk:
                raise ValidationError(f"Chunk {i} missing 'content' key", {"chunk_index": i})

    def add_documents(self, chunks):
        """Embed ``chunks`` and append them to the index and search caches.

        Args:
            chunks: list of dicts, each with at least a ``'content'`` key.
                An empty/falsy value is a no-op.

        Raises:
            ValidationError: if ``chunks`` is not a list or a chunk is malformed.
            EmbeddingError: if the embedding model fails.
        """
        if not chunks:
            return

        if not isinstance(chunks, list):
            raise ValidationError("Chunks must be a list", {"received_type": type(chunks)})
        self._validate_chunks(chunks)

        contents = [chunk['content'] for chunk in chunks]

        try:
            embeddings = self.model.encode(
                contents,
                batch_size=self._EMBED_BATCH_SIZE,
                show_progress_bar=True,
                convert_to_numpy=True,
                normalize_embeddings=False
            )
        except Exception as e:
            # Preview of the first chunk to help identify the problematic input.
            problematic_content = str(contents[0])[:200] + "..."
            raise EmbeddingError(
                problematic_content,
                {"original_error": str(e), "batch_size": len(contents)}
            ) from e

        embeddings_np = np.array(embeddings).astype('float32')
        if self.index is None:
            # First batch: create the index and reset every parallel cache.
            self.index = faiss.IndexFlatIP(embeddings_np.shape[1])
            self.documents = []
            self.searchable_texts_raw = []
            self.searchable_texts_lower = []
            self.doc_lengths = []

        faiss.normalize_L2(embeddings_np)  # Normalize for cosine similarity
        self.index.add(embeddings_np)

        self.documents.extend(chunks)
        for chunk in chunks:
            self._append_searchable_text(chunk)
        # Keep the BM25 length statistics in step with the new documents.
        self._update_avg_doc_length()

        print(f"✓ Added {len(chunks)} chunks to vector store")

    def save_index(self):
        """Write the FAISS index and the JSON metadata to disk.

        Raises:
            IndexSaveError: if there is no index, or if either file cannot be
                written (the error names the file that failed).
        """
        if self.index is None:
            raise IndexSaveError("No index to save - index is None", {"index_state": "none"})

        try:
            faiss.write_index(self.index, str(self.index_file))
        except Exception as e:
            raise IndexSaveError(str(self.index_file), {"original_error": str(e)}) from e
        # save_metadata raises its own IndexSaveError naming the metadata file;
        # it is deliberately not re-wrapped under the index file's name.
        self.save_metadata()
        print("✓ Saved vector store index")

    def save_metadata(self):
        """Write ``self.documents`` to the JSON metadata file.

        The JSON is written to a sibling ``.tmp`` file and then moved over the
        target, so an interrupted write leaves the previous metadata intact.

        Raises:
            IndexSaveError: if the file cannot be written.
        """
        tmp_file = self.metadata_file.with_name(self.metadata_file.name + '.tmp')
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)
            tmp_file.replace(self.metadata_file)
        except Exception as e:
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup; the original failure is reported below
            raise IndexSaveError(
                str(self.metadata_file),
                {"error_type": "json_save", "original_error": str(e)}
            ) from e

    def replace_source_documents(self, source_name, new_chunks):
        """
        Safely replace all documents for a source with new chunks.

        The whole index is rebuilt so vectors and documents stay aligned. The
        new state is committed only after every document was embedded; on
        failure the previous index, documents and caches are left untouched.

        Args:
            source_name: Name of the source file to replace
            new_chunks: List of new chunk dictionaries to replace old ones with.
                If empty, the source's documents are simply removed.

        Raises:
            ValidationError: if a new chunk is not a dict with a 'content' key.
            EmbeddingError: if embedding any batch fails.
        """
        if not new_chunks:
            # If no new chunks, just remove old ones
            self._remove_source_from_index(source_name)
            return

        self._validate_chunks(new_chunks)

        print(f"Replacing documents for source: {source_name} ({len(new_chunks)} new chunks)")

        kept_documents = self._documents_excluding_source(source_name)
        print(f"  Keeping {len(kept_documents)} documents from other sources")

        updated_documents = kept_documents + list(new_chunks)
        print(f"  Rebuilding index with {len(updated_documents)} total documents...")
        self._rebuild_index(updated_documents)

        print(f"✓ Successfully replaced source '{source_name}': {len(new_chunks)} chunks")
        print(f"  Total documents in index: {self.index.ntotal}")

    def _remove_source_from_index(self, source_name):
        """Remove all documents for a source by rebuilding the index without them.

        Raises:
            EmbeddingError: if re-embedding the kept documents fails (the
                previous state is left untouched).
        """
        kept_documents = self._documents_excluding_source(source_name)
        removed = len(self.documents) - len(kept_documents)

        if removed == 0:
            print(f"No documents found for source: {source_name}")
            return

        print(f"Removing {removed} documents for source: {source_name}")
        self._rebuild_index(kept_documents)
        print(f"✓ Removed source '{source_name}'. Total documents: {len(self.documents)}")

    def _documents_excluding_source(self, source_name):
        """Return the documents whose metadata 'source' differs from ``source_name``."""
        return [
            doc for doc in self.documents
            if doc.get('metadata', {}).get('source') != source_name
        ]

    def _embedding_dimension(self):
        """Return the embedding width: from the current index, else by encoding a probe text."""
        if self.index is not None:
            return self.index.d
        return self.model.encode(["test"]).shape[1]

    def _rebuild_index(self, documents):
        """Re-embed ``documents`` into a fresh index and make it current.

        Embeddings are accumulated into a new local index; ``self.index``,
        ``self.documents`` and the search caches are replaced only after every
        batch succeeded.

        Raises:
            EmbeddingError: if embedding a batch fails.
        """
        new_index = faiss.IndexFlatIP(self._embedding_dimension())
        contents = [doc['content'] for doc in documents]
        batch_size = self._EMBED_BATCH_SIZE

        for start in range(0, len(contents), batch_size):
            batch = contents[start:start + batch_size]
            try:
                embeddings = self.model.encode(
                    batch,
                    batch_size=batch_size,
                    show_progress_bar=False,
                    convert_to_numpy=True,
                    normalize_embeddings=False
                )
            except Exception as e:
                raise EmbeddingError(
                    f"Error embedding batch starting at {start}",
                    {"original_error": str(e), "batch_start": start}
                ) from e
            embeddings_np = np.array(embeddings).astype('float32')
            faiss.normalize_L2(embeddings_np)
            new_index.add(embeddings_np)

        self.index = new_index
        self.documents = documents
        # Rebuilds text caches AND doc_lengths/avg_doc_length used by BM25.
        self._prepare_searchable_texts()

    def mark_source_as_deleted(self, source_name):
        """
        DEPRECATED: Use replace_source_documents instead.
        This method only marks documents as deleted but doesn't remove vectors from index,
        which can cause index corruption. Kept for backward compatibility but not recommended.

        Returns:
            Number of documents newly flagged with ``metadata['deleted'] = True``;
            the metadata file is re-saved only when this is non-zero.
        """
        count = 0
        for doc in self.documents:
            if doc.get('metadata', {}).get('source') == source_name:
                doc.setdefault('metadata', {})['deleted'] = True
                count += 1
        # Only save if changes were made
        if count > 0:
            self.save_metadata()
            print(f"⚠ Marked {count} documents from {source_name} as deleted (deprecated method - may cause index mismatch)")
        return count

    def search(self, query, k=5):
        """Search for similar documents using hybrid vector + keyword re-ranking.

        Vector neighbours are re-scored with fuzzy, whole-word keyword and BM25
        signals, filtered by the query's domain terms, topped up by a full
        keyword scan when too few results contain the query keywords, and
        balanced across sources (at most MAX_CHUNKS_PER_SOURCE per source while
        alternatives exist).

        Args:
            query: non-empty, non-whitespace query string.
            k: maximum number of results to return.

        Returns:
            Up to ``k`` result dicts with keys 'content', 'metadata', 'score',
            'vector_score', 'keyword_score', 'bm25_score', 'has_keyword_match',
            'keyword_snippet' and 'doc_index'. Returns [] when the store is
            empty, ``k`` <= 0, or the embedding/FAISS lookup fails.

        Raises:
            VectorStoreError: if no index is loaded.
            ValidationError: if ``query`` is empty, not a string, or whitespace.
        """
        if self.index is None:
            raise VectorStoreError("Index not loaded", {"index_state": "none"})
        if len(self.documents) == 0:
            return []

        if not query:
            raise ValidationError("Query cannot be empty", {"query_length": 0})
        if not isinstance(query, str):
            raise ValidationError("Query must be a string", {"query_type": type(query).__name__})
        if len(query.strip()) == 0:
            raise ValidationError("Query cannot be only whitespace", {"query_content": repr(query)})
        if k <= 0:
            return []

        query_lower = query.lower()
        query_words = re.findall(r'\b\w+\b', query_lower)
        keyword_patterns = self._build_keyword_patterns(query_lower, query_words)

        self._ensure_searchable_texts()
        idf_table = self._compute_idf_table(keyword_patterns) if keyword_patterns else {}

        try:
            query_embedding = self.model.encode([query])
            query_embedding = np.array(query_embedding).astype('float32')
            faiss.normalize_L2(query_embedding)

            oversample = max(k * HYBRID_SEARCH_OVERSAMPLE, k)
            scores, indices = self.index.search(query_embedding, oversample)
        except Exception as e:
            print(f"Error during search: {e}")
            return []

        # Query-level filters; identical for every candidate.
        require_domain_match = any(p['is_domain'] for p in keyword_patterns)
        require_primary_match = any(p['is_primary'] for p in keyword_patterns)
        # Whole-word patterns for query words longer than 3 chars, used as a
        # weaker boost when no keyword pattern matched a document.
        query_word_patterns = [
            re.compile(r'\b' + re.escape(word) + r'\b', re.IGNORECASE)
            for word in query_words
            if len(word) > 3
        ]

        candidates = []
        retrieved_indices = set()
        for score, idx in zip(scores[0], indices[0]):
            idx = int(idx)
            # FAISS reports missing neighbours as -1; without the lower-bound
            # check, -1 would silently select the last document (repeatedly).
            if idx < 0 or idx >= len(self.documents):
                continue
            doc = self.documents[idx]
            metadata = doc.get('metadata', {})
            # Skip deleted documents
            if metadata.get('deleted'):
                continue

            retrieved_indices.add(idx)

            searchable_text_lower = self.searchable_texts_lower[idx]
            searchable_text_raw = self.searchable_texts_raw[idx]

            # Initial keyword score from fuzzy matching
            keyword_score = 0.0
            if searchable_text_lower:
                keyword_score = fuzz.partial_ratio(query_lower, searchable_text_lower) / 100.0

            # Exact whole-word keyword matches; prevents false matches like
            # "hydro" in "tetrahydrocannabinol".
            matched_terms = []
            has_keyword_match = False
            keyword_snippet = None
            matched_domain = False
            matched_primary = False

            if keyword_patterns:
                for pattern_info in keyword_patterns:
                    if pattern_info['pattern'].search(searchable_text_raw):
                        matched_terms.append(pattern_info['term'])
                        if pattern_info['is_domain']:
                            matched_domain = True
                            if pattern_info['is_primary']:
                                matched_primary = True
                        keyword_score = min(1.0, keyword_score + 0.5)
                        has_keyword_match = True

                if len(matched_terms) > 1:
                    keyword_score = min(1.0, keyword_score + 0.2 * min(len(matched_terms) - 1, 3))

                keyword_snippet = self._extract_keyword_snippet(searchable_text_raw, keyword_patterns)
                if not keyword_snippet and has_keyword_match:
                    keyword_snippet = searchable_text_raw[:240]

                # Only general (non-domain) terms matched: cap their influence.
                if has_keyword_match and not matched_domain and not require_domain_match:
                    keyword_score = min(keyword_score, 0.4)

            # Weaker boost for any exact query word when no keyword pattern matched.
            if not matched_terms and any(p.search(searchable_text_raw) for p in query_word_patterns):
                keyword_score = min(1.0, keyword_score + 0.3)

            if keyword_score < HYBRID_MIN_KEYWORD_SCORE:
                keyword_score = 0.0

            # Filter: documents must match the strongest class of query term present.
            if keyword_patterns:
                if require_primary_match and not matched_primary:
                    continue
                if not require_primary_match and require_domain_match and not matched_domain:
                    continue
                if not require_domain_match and not has_keyword_match:
                    continue

            bm25_score = self._compute_bm25_score(idx, keyword_patterns, idf_table)

            combined_score = (
                HYBRID_VECTOR_WEIGHT * float(score) +
                HYBRID_KEYWORD_WEIGHT * keyword_score +
                HYBRID_BM25_WEIGHT * bm25_score
            )

            candidates.append({
                'content': doc['content'],
                'metadata': metadata,
                'score': combined_score,
                'vector_score': float(score),
                'keyword_score': keyword_score,
                'bm25_score': bm25_score,
                'has_keyword_match': has_keyword_match,
                'keyword_snippet': keyword_snippet,
                'doc_index': idx
            })

        keyword_hit_count = sum(1 for candidate in candidates if candidate['has_keyword_match'])

        # Fallback: if too few vector results contain the query keywords, scan
        # the full document set to pull in direct keyword hits.
        fallback_patterns = self._select_fallback_patterns(keyword_patterns)
        if fallback_patterns and keyword_hit_count < k:
            candidates.extend(self._keyword_fallback_search(
                fallback_patterns,
                k - keyword_hit_count,
                exclude_indices=retrieved_indices,
                idf_table=idf_table
            ))

        candidates.sort(key=lambda item: item['score'], reverse=True)
        return self._balance_results(candidates, k, fallback_patterns, idf_table)

    @staticmethod
    def _build_keyword_patterns(query_lower, query_words):
        """Derive whole-word keyword patterns from a lower-cased query.

        Domain terms are TECHNICAL_TERMS occurring in the query plus SYNONYMS
        expansions of query words; domain terms not in GENERIC_DOMAIN_TERMS are
        "primary". Remaining query words longer than 3 characters that are not
        in QUERY_STOPWORDS become general terms.

        Returns:
            List of dicts with keys 'term', 'pattern' (compiled, whole-word,
            case-insensitive), 'is_domain' and 'is_primary'; domain terms first.
        """
        # Insertion-ordered dict rather than a set, so pattern order -- and thus
        # the snippet picked by _extract_keyword_snippet -- is deterministic.
        expanded_terms = {}
        for word in query_words:
            if word in SYNONYMS:
                expanded_terms.update(dict.fromkeys(SYNONYMS[word]))

        # NOTE(review): TECHNICAL_TERMS are matched as substrings of the query,
        # unlike the whole-word matching used against documents.
        domain_candidates = [term for term in TECHNICAL_TERMS if term in query_lower]
        domain_candidates.extend(expanded_terms)

        domain_terms = []
        primary_terms = set()
        for term in domain_candidates:
            if term in domain_terms:
                continue
            domain_terms.append(term)
            if term not in GENERIC_DOMAIN_TERMS:
                primary_terms.add(term)

        general_terms = []
        for word in query_words:
            if len(word) <= 3 or word in QUERY_STOPWORDS:
                continue
            if word in domain_terms or word in general_terms:
                continue
            general_terms.append(word)

        # domain_terms and general_terms are each unique and mutually disjoint.
        return [
            {
                'term': term,
                'pattern': re.compile(r'\b' + re.escape(term) + r'\b', re.IGNORECASE),
                'is_domain': term in domain_terms,
                'is_primary': term in primary_terms,
            }
            for term in domain_terms + general_terms
        ]

    @staticmethod
    def _select_fallback_patterns(keyword_patterns):
        """Return the primary patterns, else the domain patterns, else all patterns."""
        primary = [p for p in keyword_patterns if p['is_primary']]
        if primary:
            return primary
        domain = [p for p in keyword_patterns if p['is_domain']]
        if domain:
            return domain
        return keyword_patterns

    def _balance_results(self, candidates, k, fallback_patterns, idf_table):
        """Pick at most ``k`` results from score-sorted ``candidates``.

        Pass 1 takes the best candidates with at most MAX_CHUNKS_PER_SOURCE per
        source. Pass 2 (only while slots remain) adds keyword hits from sources
        not yet represented. Pass 3 fills any remaining slots from the rest of
        ``candidates`` regardless of source. A document never appears twice.
        """
        balanced = []
        source_counts = {}
        used_indices = set()

        for candidate in candidates:
            source = candidate.get('metadata', {}).get('source', 'Unknown')
            count = source_counts.get(source, 0)
            if count >= MAX_CHUNKS_PER_SOURCE:
                continue
            balanced.append(candidate)
            source_counts[source] = count + 1
            doc_idx = candidate.get('doc_index')
            if doc_idx is not None:
                used_indices.add(doc_idx)
            if len(balanced) >= k:
                break

        if fallback_patterns and len(balanced) < k:
            existing_sources = {c.get('metadata', {}).get('source', 'Unknown') for c in balanced}
            needed_sources = k - len(existing_sources)
            extra_candidates = self._keyword_fallback_search(
                fallback_patterns,
                needed_sources,
                exclude_indices=used_indices,
                exclude_sources=existing_sources,
                idf_table=idf_table
            )
            for extra_candidate in extra_candidates:
                source = extra_candidate.get('metadata', {}).get('source', 'Unknown')
                if source in existing_sources:
                    continue
                balanced.append(extra_candidate)
                existing_sources.add(source)
                doc_idx = extra_candidate.get('doc_index')
                if doc_idx is not None:
                    used_indices.add(doc_idx)
                if len(balanced) >= k:
                    break

        # Fill remaining slots regardless of the per-source limit; dedup by
        # document index (a pass-2 extra may duplicate a vector candidate).
        for candidate in candidates:
            if len(balanced) >= k:
                break
            doc_idx = candidate.get('doc_index')
            if doc_idx in used_indices:
                continue
            balanced.append(candidate)
            if doc_idx is not None:
                used_indices.add(doc_idx)

        return balanced

    @classmethod
    def _get_topic_patterns(cls):
        """Return {canonical topic: [compiled variant patterns]}, compiled once.

        Variants match as whole words with an optional plural 's'/'es' suffix,
        so e.g. "era" no longer matches inside "temperature" and "hydro" no
        longer matches inside "tetrahydrocannabinol".
        """
        if cls._topic_pattern_cache is None:
            cls._topic_pattern_cache = {
                canonical: [
                    re.compile(r'\b' + re.escape(variant) + r'(?:s|es)?\b')
                    for variant in variants
                ]
                for canonical, variants in CANONICAL_TOPIC_KEYWORDS.items()
            }
        return cls._topic_pattern_cache

    def _build_searchable_text(self, doc):
        """Compose searchable text using enriched metadata + raw content.

        Side effects on ``doc['metadata']`` (when present): appends detected
        canonical topics to 'topics', sets 'topic_label', and once per document
        (tracked by '_topic_tagged') prefixes 'summary' with "[Topics: ...]".
        """
        metadata = doc.get('metadata', {})

        searchable_text_parts = []
        if metadata.get('summary'):
            searchable_text_parts.append(str(metadata['summary']))
        if metadata.get('key_points'):
            searchable_text_parts.extend(str(point) for point in metadata['key_points'])
        if metadata.get('themes'):
            searchable_text_parts.extend(str(theme) for theme in metadata['themes'])
        if metadata.get('clean_excerpt'):
            searchable_text_parts.append(str(metadata['clean_excerpt']))

        searchable_text_parts.append(str(doc.get('content', '')))
        lower_text = ' '.join(searchable_text_parts).lower()

        topics = metadata.get('topics')
        if topics is None:
            # Covers both a missing key and an explicit null from JSON.
            topics = metadata['topics'] = []
        for canonical, patterns in self._get_topic_patterns().items():
            if canonical in topics:
                continue
            if any(pattern.search(lower_text) for pattern in patterns):
                topics.append(canonical)
                searchable_text_parts.append(canonical)

        if topics:
            metadata['topic_label'] = ', '.join(sorted(set(topics)))
            if not metadata.get('_topic_tagged'):
                summary = str(metadata.get('summary') or '')
                topic_prefix = f"[Topics: {metadata['topic_label']}] "
                if not summary.startswith(topic_prefix):
                    metadata['summary'] = f"{topic_prefix}{summary}".strip()
                metadata['_topic_tagged'] = True

        return ' '.join(searchable_text_parts)

    def _prepare_searchable_texts(self):
        """Rebuild every per-document cache (texts and BM25 lengths) from ``self.documents``."""
        self.searchable_texts_raw = [self._build_searchable_text(doc) for doc in self.documents]
        self.searchable_texts_lower = [text.lower() for text in self.searchable_texts_raw]
        self.doc_lengths = [len(text.split()) for text in self.searchable_texts_raw]
        self._update_avg_doc_length()

    def _append_searchable_text(self, doc):
        """Append one document's entries to the per-document caches."""
        searchable_text = self._build_searchable_text(doc)
        self.searchable_texts_raw.append(searchable_text)
        self.searchable_texts_lower.append(searchable_text.lower())
        self.doc_lengths.append(len(searchable_text.split()))

    def _update_avg_doc_length(self):
        """Recompute ``avg_doc_length`` from ``doc_lengths`` (0.0 when empty)."""
        if self.doc_lengths:
            self.avg_doc_length = sum(self.doc_lengths) / len(self.doc_lengths)
        else:
            self.avg_doc_length = 0.0

    def _ensure_searchable_texts(self):
        """Rebuild the per-document caches if any is out of step with documents."""
        doc_count = len(self.documents)
        if (
            len(self.searchable_texts_raw) != doc_count or
            len(self.searchable_texts_lower) != doc_count or
            len(self.doc_lengths) != doc_count
        ):
            self._prepare_searchable_texts()

    def _compute_idf_table(self, keyword_patterns):
        """Compute BM25 IDF values for each query term over the cached texts.

        Uses ``log((N - df + 0.5) / (df + 0.5) + 1)``, which is never negative.
        """
        idf_table = {}
        total_docs = len(self.searchable_texts_lower)
        if total_docs == 0:
            return idf_table
        for info in keyword_patterns:
            term = info['term']
            if term in idf_table:
                continue
            pattern = info['pattern']
            df = sum(1 for text in self.searchable_texts_lower if pattern.search(text))
            idf = math.log((total_docs - df + 0.5) / (df + 0.5) + 1.0)
            idf_table[term] = max(idf, 0.0)
        return idf_table

    def _compute_bm25_score(self, doc_index, keyword_patterns, idf_table):
        """Compute the BM25 score of one cached document for the keyword patterns."""
        if not keyword_patterns or not idf_table:
            return 0.0
        if doc_index >= len(self.searchable_texts_lower):
            return 0.0
        doc_text = self.searchable_texts_lower[doc_index]
        doc_len = self.doc_lengths[doc_index] if doc_index < len(self.doc_lengths) else len(doc_text.split())
        avg_len = self.avg_doc_length if self.avg_doc_length > 0 else doc_len or 1
        bm25 = 0.0
        for info in keyword_patterns:
            tf = len(info['pattern'].findall(doc_text))
            if tf == 0:
                continue
            idf = idf_table.get(info['term'], 0.0)
            numerator = tf * (BM25_K1 + 1.0)
            denominator = tf + BM25_K1 * (1.0 - BM25_B + BM25_B * (doc_len / avg_len))
            bm25 += idf * (numerator / denominator)
        return bm25

    def _keyword_fallback_search(self, keyword_patterns, needed, exclude_indices=None, exclude_sources=None, idf_table=None):
        """Scan all documents for keyword matches when vector search misses them.

        Returns at most ``needed`` candidates (same dict shape as ``search``
        results, with ``vector_score`` 0.0), at most one per source, skipping
        deleted documents, ``exclude_indices`` and ``exclude_sources``. The
        caller's ``exclude_indices`` set is not modified.
        """
        if needed <= 0 or not keyword_patterns:
            return []

        excluded_indices = exclude_indices or set()  # read-only
        seen_sources = set(exclude_sources) if exclude_sources else set()
        idf_table = idf_table or {}
        fallback_candidates = []

        for idx, searchable_text_raw in enumerate(self.searchable_texts_raw):
            if idx in excluded_indices:
                continue

            doc = self.documents[idx]
            metadata = doc.get('metadata', {})
            if metadata.get('deleted'):
                continue

            if not any(info['pattern'].search(searchable_text_raw) for info in keyword_patterns):
                continue
            source = metadata.get('source', 'Unknown')
            if source in seen_sources:
                continue
            snippet = self._extract_keyword_snippet(searchable_text_raw, keyword_patterns)
            if not snippet:
                continue

            keyword_score = 1.0
            bm25 = self._compute_bm25_score(idx, keyword_patterns, idf_table)
            fallback_candidates.append({
                'content': doc['content'],
                'metadata': metadata,
                'score': HYBRID_KEYWORD_WEIGHT * keyword_score + HYBRID_BM25_WEIGHT * bm25,
                'vector_score': 0.0,
                'keyword_score': keyword_score,
                'bm25_score': bm25,
                'has_keyword_match': True,
                'keyword_snippet': snippet,
                'doc_index': idx
            })
            seen_sources.add(source)

            if len(fallback_candidates) >= needed:
                break

        return fallback_candidates

    def _extract_keyword_snippet(self, raw_text, keyword_patterns, window: int = 240):
        """Extract a whitespace-collapsed snippet around the first matching pattern.

        Patterns are tried in order; the snippet spans ``window // 2`` characters
        on each side of the first match. Returns None if nothing matches.
        """
        if not raw_text or not keyword_patterns:
            return None

        for info in keyword_patterns:
            match = info['pattern'].search(raw_text)
            if match:
                start = max(0, match.start() - window // 2)
                end = min(len(raw_text), match.end() + window // 2)
                return re.sub(r'\s+', ' ', raw_text[start:end]).strip()
        return None
