import json
import time
import numpy as np
import faiss
from pathlib import Path
from sentence_transformers import SentenceTransformer
from config import (
    EMBEDDINGS_DIR,
    DOCUMENTS_DIR,
    NOTES_DIR,
    EMBEDDING_MODEL
)

# Simple timestamped console logging helper
def log(msg):
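    """Print a timestamped status message to stdout."""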
    print(f"[{time.strftime('%H:%M:%S')}] {msg}")

def compose_card_text(card, metadata):
    """Compose a readable knowledge card string."""
    lines = [
        f"[Source: {metadata.get('source', 'Unknown')} | Chunk {metadata.get('chunk_id', '?')}]",
        f"Summary: {card.get('summary', '')}"
    ]

    if card.get('key_points'):
        lines.append("Key Points:")
        for point in card['key_points']:
            lines.append(f"- {point}")

    if card.get('themes'):
        lines.append("Themes: " + ', '.join(card['themes']))

    if card.get('clean_excerpt'):
        lines.append("Excerpt: " + card['clean_excerpt'])

    return '\n'.join(lines)

def main():
    target_file = "OceanofPDF.com_The_Invisible_Master_Leo_LyonZagami.pdf"
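    # NOTE: this must match the 'source' field stored in chunk metadata exactly,
    # since step 5 removes the old chunks for this file by comparing that field.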
    log(f"Starting vector store repair. Target file to enrich: {target_file}")

    # 1. Load current metadata
    metadata_file = EMBEDDINGS_DIR / "documents_metadata.json"
    if not metadata_file.exists():
        log("Error: Metadata file not found!")
        return
    
    with open(metadata_file, 'r', encoding='utf-8') as f:
        all_docs = json.load(f)
    log(f"Loaded {len(all_docs)} existing documents from metadata.")

    # 2. Load Raw Chunks for Target
    chunk_file = DOCUMENTS_DIR / f"{Path(target_file).stem}_chunks.json"
    if not chunk_file.exists():
        log("Error: Raw chunks file not found!")
        return
    
    with open(chunk_file, 'r', encoding='utf-8') as f:
        raw_chunks = json.load(f)
    log(f"Loaded {len(raw_chunks)} raw chunks for target.")

    # 3. Load Cards for Target
    card_file = NOTES_DIR / f"{Path(target_file).stem}_cards.json"
    if not card_file.exists():
        log("Error: Cards file not found!")
        return
    
    with open(card_file, 'r', encoding='utf-8') as f:
        cards_list = json.load(f)
    log(f"Loaded {len(cards_list)} cards for target.")

    # Map cards by chunk_id for easy lookup
    cards_map = {c['chunk_id']: c for c in cards_list}
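    # Only chunk_ids present in raw_chunks drive the loop below; a card whose
    # chunk_id has no matching raw chunk is silently skipped.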

    # 4. Prepare New Enriched Chunks
    new_enriched_chunks = []
    for chunk in raw_chunks:
        chunk_id = chunk.get('metadata', {}).get('chunk_id')
        if chunk_id in cards_map:
            card = cards_map[chunk_id]
            metadata = chunk.get('metadata', {}).copy()
            
            # Update metadata with card info
            metadata.update({
                'summary': card.get('summary'),
                'key_points': card.get('key_points'),
                'themes': card.get('themes'),
                'card_id': card.get('card_id'),
                'clean_excerpt': card.get('clean_excerpt')
            })
            
            # Compose new content
            new_content = compose_card_text(card, metadata)
            
            new_enriched_chunks.append({
                'content': new_content,
                'metadata': metadata
            })
        else:
            log(f"Warning: No card found for chunk {chunk_id}, using raw chunk.")
            new_enriched_chunks.append(chunk)
    
    log(f"Prepared {len(new_enriched_chunks)} enriched chunks.")

    # 5. Build Final Document List
    # Filter out OLD chunks for target file
    final_docs = [d for d in all_docs if d.get('metadata', {}).get('source') != target_file]
    log(f"Retained {len(final_docs)} documents from other sources.")
    
    # Add NEW enriched chunks
    final_docs.extend(new_enriched_chunks)
    log(f"Final document count: {len(final_docs)}")

    # 6. Re-Embed Everything (The Heavy Lifting)
    log("Loading embedding model...")
    model = SentenceTransformer(EMBEDDING_MODEL)
    
    # Create new FAISS index
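    # IndexFlatIP scores by inner product; because the vectors are L2-normalized
    # before insertion, inner product is equivalent to cosine similarity.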
    sample_embedding = model.encode(["test"])
    dimension = sample_embedding.shape[1]
    index = faiss.IndexFlatIP(dimension)
    
    log(f"Created new FAISS index (dim={dimension}). Starting embedding process...")

    batch_size = 32
    total_docs = len(final_docs)
    contents = [doc['content'] for doc in final_docs]
    
    # Embed in batches to save memory and show progress
    all_embeddings = []
    
    for i in range(0, total_docs, batch_size):
        batch = contents[i : i + batch_size]
        try:
            embeddings = model.encode(batch, convert_to_numpy=True, normalize_embeddings=False)
            # Normalize for Cosine Similarity (Inner Product)
            faiss.normalize_L2(embeddings)
            all_embeddings.append(embeddings)
            
            # Log progress roughly every 1000 chunks
            if i % 1000 < batch_size:
                log(f"Processed {min(i + batch_size, total_docs)}/{total_docs} chunks...")
        except Exception as e:
            log(f"Error embedding batch starting at {i}: {e}")
            return

    # Concatenate all embeddings
    final_embeddings = np.vstack(all_embeddings)
    
    # Add to index
    log("Adding vectors to FAISS index...")
    index.add(final_embeddings)
    
    # 7. Save
    log("Saving new index and metadata...")
    
    # Save Index
    index_file = EMBEDDINGS_DIR / "faiss_index.index"
    faiss.write_index(index, str(index_file))
    
    # Save Metadata
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(final_docs, f, ensure_ascii=False, indent=2)
        
    log("SUCCESS: Vector store repaired and enriched chunks applied!")
    log(f"Final Index Total: {index.ntotal}")
    log(f"Final Metadata Total: {len(final_docs)}")

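# Optional sanity check -- a minimal sketch, not part of the original repair flow.
# It reloads the rebuilt index and metadata and runs one query to confirm that
# nearest-neighbor search returns chunks with the expected metadata. The query
# string and k below are placeholder values.
def verify_index(query="Invisible Master", k=3):
    model = SentenceTransformer(EMBEDDING_MODEL)
    index = faiss.read_index(str(EMBEDDINGS_DIR / "faiss_index.index"))
    with open(EMBEDDINGS_DIR / "documents_metadata.json", 'r', encoding='utf-8') as f:
        docs = json.load(f)
    # Embed and normalize the query the same way the documents were embedded
    query_vec = model.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(query_vec)
    scores, ids = index.search(query_vec, k)
    for score, idx in zip(scores[0], ids[0]):
        if idx < 0:  # FAISS returns -1 when fewer than k results exist
            continue
        meta = docs[idx].get('metadata', {})
        log(f"score={score:.3f} source={meta.get('source')} chunk={meta.get('chunk_id')}")
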
if __name__ == "__main__":
    main()
