#!/usr/bin/env python3
"""
Test script to compare embedding models before full rebuild.
Tests BAAI/bge-base-en-v1.5 vs all-MiniLM-L6-v2 on sample queries.
"""

import time

import numpy as np
from sentence_transformers import SentenceTransformer

def test_embedding_model(model_name, test_queries, test_chunks):
    """Test an embedding model on sample queries"""
    print(f"\n{'='*60}")
    print(f"Testing: {model_name}")
    print(f"{'='*60}")
    
    try:
        start_time = time.time()
        model = SentenceTransformer(model_name)
        load_time = time.time() - start_time
        print(f"✓ Model loaded in {load_time:.2f}s")
        print(f"  Max sequence length: {model.max_seq_length}")
        print(f"  Embedding dimension: {model.get_sentence_embedding_dimension()}")
        
        # Embed queries
        query_start = time.time()
        query_embeddings = model.encode(test_queries, show_progress_bar=False)
        query_time = time.time() - query_start
        
        # Embed chunks
        chunk_start = time.time()
        chunk_embeddings = model.encode(test_chunks, show_progress_bar=False)
        chunk_time = time.time() - chunk_start
        
        # Cosine similarity: normalize the chunk matrix once, outside the query
        # loop, instead of re-normalizing it for every query.
        # (encode() also accepts normalize_embeddings=True, which returns
        # unit-length vectors so a plain dot product is already cosine similarity.)
        chunk_embs_norm = chunk_embeddings / np.linalg.norm(chunk_embeddings, axis=1, keepdims=True)
        similarities = []
        for i, query_emb in enumerate(query_embeddings):
            query_emb_norm = query_emb / np.linalg.norm(query_emb)
            sims = chunk_embs_norm @ query_emb_norm
            similarities.append(sims)
            
            # Find best match
            best_idx = np.argmax(sims)
            print(f"\nQuery {i+1}: '{test_queries[i]}'")
            print(f"  Best match score: {sims[best_idx]:.4f}")
            print(f"  Best chunk: {test_chunks[best_idx][:100]}...")
        
        total_time = time.time() - start_time
        print(f"\n✓ Total test time (load + encode): {total_time:.2f}s")
        print(f"  Query encoding: {query_time:.2f}s")
        print(f"  Chunk encoding: {chunk_time:.2f}s")
        
        return {
            'model': model_name,
            'load_time': load_time,
            'query_time': query_time,
            'chunk_time': chunk_time,
            'total_time': total_time,
            'dimension': model.get_sentence_embedding_dimension(),
            'similarities': similarities
        }
        
    except Exception as e:
        print(f"❌ Error testing {model_name}: {e}")
        return None
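
# BGE-family models are typically used with an instruction prefixed to the
# query (not the passages) for retrieval, so embedding raw queries as above may
# understate bge's retrieval quality. A minimal, illustrative sketch of applying
# that prefix follows; the instruction string is taken from the
# BAAI/bge-base-en-v1.5 model card (verify it against the exact model version),
# and the helper name is ours, not part of any library API.
BGE_QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "

def with_bge_instruction(queries):
    """Return copies of the queries with the BGE retrieval instruction prepended."""
    return [BGE_QUERY_INSTRUCTION + q for q in queries]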

def main():
    print("="*60)
    print("Embedding Model Comparison Test")
    print("="*60)
    
    # Test queries
    test_queries = [
        "ph level cannabis",
        "ideal temperature for growing",
        "nutrient requirements"
    ]
    
    # Test chunks (samples from the actual corpus; the last two are
    # off-topic distractors)
    test_chunks = [
        "Cannabis grows best in a 6.5 to 8 pH range. pH tester: electronic instrument or chemical used to measure the acid or alkaline balance.",
        "The optimal temperature for cannabis cultivation ranges from 70-85°F during the day and 60-70°F at night.",
        "Cannabis requires nitrogen, phosphorus, and potassium (NPK) in varying ratios throughout its growth cycle.",
        "This book is written for the purpose of supplying information to the public. The publisher and the author do not advocate breaking the law.",
        "Introduction: Cannabis, commonly known in the United States as marijuana, is a wondrous plant and an ancient plant."
    ]
    
    # Test old model
    print("\n📊 Testing OLD model...")
    old_result = test_embedding_model("sentence-transformers/all-MiniLM-L6-v2", test_queries, test_chunks)
    
    # Test new model
    print("\n📊 Testing NEW model...")
    new_result = test_embedding_model("BAAI/bge-base-en-v1.5", test_queries, test_chunks)
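    # Raw queries are deliberately used for both models so the inputs are
    # identical; see with_bge_instruction above for the BGE-recommended prefix.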
    
    # Compare results
    if old_result and new_result:
        print("\n" + "="*60)
        print("COMPARISON RESULTS")
        print("="*60)
        
        print(f"\nModel Dimensions:")
        print(f"  Old: {old_result['dimension']}D")
        print(f"  New: {new_result['dimension']}D")
        
        print(f"\nPerformance:")
        print(f"  Old load time: {old_result['load_time']:.2f}s")
        print(f"  New load time: {new_result['load_time']:.2f}s")
        print(f"  Old query time: {old_result['query_time']:.2f}s")
        print(f"  New query time: {new_result['query_time']:.2f}s")
        
        print(f"\nSimilarity Scores (higher is better):")
        for i, query in enumerate(test_queries):
            old_best = np.max(old_result['similarities'][i])
            new_best = np.max(new_result['similarities'][i])
            # abs() keeps the sign of the change meaningful if the baseline
            # cosine score happens to be negative
            improvement = ((new_best - old_best) / abs(old_best)) * 100
            print(f"\n  Query: '{query}'")
            print(f"    Old model: {old_best:.4f}")
            print(f"    New model: {new_best:.4f}")
            print(f"    Improvement: {improvement:+.1f}%")
        
        # Overall recommendation (rough heuristic; see caveat below)
        avg_improvement = np.mean([
            ((np.max(new_result['similarities'][i]) - np.max(old_result['similarities'][i])) /
             abs(np.max(old_result['similarities'][i]))) * 100
            for i in range(len(test_queries))
        ])
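        # Caveat: absolute cosine scores are not directly comparable across
        # embedding models (each has its own score distribution), so treat the
        # average as a rough signal; retrieval ranking quality on the full
        # corpus is the more reliable test.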
        
        print(f"\n{'='*60}")
        print(f"Average Improvement: {avg_improvement:+.1f}%")
        if avg_improvement > 5:
            print("✅ RECOMMENDATION: Upgrade to new model")
        elif avg_improvement > 0:
            print("⚠️  RECOMMENDATION: Small improvement, upgrade optional")
        else:
            print("❌ RECOMMENDATION: Keep old model")
        print("="*60)

if __name__ == "__main__":
    main()

