"""
Performance Benchmark Tool
Measures query performance before/after optimization
"""

import time
import faiss
import numpy as np
import json
from pathlib import Path
from typing import List, Dict
import statistics


class PerformanceBenchmark:
    """Benchmark FAISS index performance"""
    
    def __init__(self, index_path: Path, metadata_path: Path):
        self.index_path = Path(index_path)
        self.metadata_path = Path(metadata_path)
        self.index = None
        self.metadata = None
        
    def load_index(self):
        """Load index and metadata"""
        self.index = faiss.read_index(str(self.index_path))
        with open(self.metadata_path, 'r') as f:
            self.metadata = json.load(f)
        print(f"✓ Loaded index: {self.index.ntotal} vectors, dim={self.index.d}")
        
    def generate_test_queries(self, num_queries: int = 20) -> np.ndarray:
        """Generate random query vectors for testing"""
        # Random queries with same dimension as index
        queries = np.random.randn(num_queries, self.index.d).astype('float32')
        # Normalize for IP (inner product) search
        faiss.normalize_L2(queries)
        return queries
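
    def sample_stored_queries(self, num_queries: int = 20) -> np.ndarray:
        """Sketch: sample vectors already stored in the index as queries.

        Random Gaussian queries measure raw search speed but not a realistic
        workload, since they rarely resemble the indexed data. Assumes the
        index supports reconstruction: flat indexes do directly, IVF indexes
        need make_direct_map() first, and PQ-compressed indexes only return
        approximate reconstructions.
        """
        num_queries = min(num_queries, self.index.ntotal)
        if hasattr(self.index, 'make_direct_map'):
            self.index.make_direct_map()  # build the id -> vector map for IVF
        queries = self.index.reconstruct_n(0, num_queries)
        queries = np.ascontiguousarray(queries, dtype='float32')
        faiss.normalize_L2(queries)  # match the IP normalization used above
        return queries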
    
    def benchmark_queries(self, queries: np.ndarray, k: int = 5) -> Dict:
        """Benchmark query performance"""
        times = []
        
        print(f"\nRunning {len(queries)} test queries (k={k})...")
        
        for i, query in enumerate(queries):
            start = time.perf_counter()
            distances, indices = self.index.search(query.reshape(1, -1), k)
            elapsed = time.perf_counter() - start
            times.append(elapsed * 1000)  # Convert to milliseconds
            
            if (i + 1) % 10 == 0:
                print(f"  {i + 1}/{len(queries)} queries completed...")
        
        results = {
            'total_queries': len(queries),
            'k': k,
            'times_ms': times,
            'mean_ms': statistics.mean(times),
            'median_ms': statistics.median(times),
            'min_ms': min(times),
            'max_ms': max(times),
            'stdev_ms': statistics.stdev(times) if len(times) > 1 else 0,
            'total_time_sec': sum(times) / 1000
        }
        
        return results
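
    def benchmark_batch(self, queries: np.ndarray, k: int = 5) -> Dict:
        """Sketch: time one batched search over all queries.

        FAISS amortizes per-call overhead and can parallelize across a query
        matrix, so per-query latency in a batch is usually lower than the
        one-at-a-time numbers above. Useful as a second data point, not a
        replacement for the per-query benchmark.
        """
        start = time.perf_counter()
        self.index.search(queries, k)
        elapsed_ms = (time.perf_counter() - start) * 1000
        return {
            'total_queries': len(queries),
            'k': k,
            'batch_total_ms': elapsed_ms,
            'batch_mean_ms': elapsed_ms / len(queries)
        }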
    
    def get_index_stats(self) -> Dict:
        """Get index statistics"""
        stats = {
            'type': type(self.index).__name__,
            'total_vectors': self.index.ntotal,
            'dimension': self.index.d,
            'trained': self.index.is_trained
        }
        
        if hasattr(self.index, 'nlist'):
            stats['nlist'] = self.index.nlist
        if hasattr(self.index, 'nprobe'):
            stats['nprobe'] = self.index.nprobe
        
        # Estimate memory from the serialized index size;
        # sys.getsizeof() only measures the Python wrapper, not the index data
        stats['memory_mb'] = faiss.serialize_index(self.index).nbytes / (1024 * 1024)
        
        return stats
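
    def sweep_nprobe(self, queries: np.ndarray, k: int = 5,
                     nprobe_values=(1, 4, 16, 64)) -> Dict:
        """Sketch: mean per-query latency across a range of nprobe settings.

        Only meaningful for IVF-style indexes; returns an empty dict if the
        index has no nprobe attribute. Larger nprobe scans more clusters,
        trading speed for recall, so this shows the latency side of that
        tuning curve. The candidate values are arbitrary defaults.
        """
        if not hasattr(self.index, 'nprobe'):
            return {}
        original_nprobe = self.index.nprobe
        latencies = {}
        for nprobe in nprobe_values:
            self.index.nprobe = nprobe
            start = time.perf_counter()
            self.index.search(queries, k)
            latencies[nprobe] = (time.perf_counter() - start) * 1000 / len(queries)
        self.index.nprobe = original_nprobe  # restore the configured value
        return latencies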
    
    def print_results(self, results: Dict, stats: Dict):
        """Print benchmark results"""
        print("\n" + "=" * 60)
        print("BENCHMARK RESULTS")
        print("=" * 60)
        
        print(f"\nIndex Configuration:")
        print(f"  Type: {stats['type']}")
        print(f"  Vectors: {stats['total_vectors']:,}")
        print(f"  Dimension: {stats['dimension']}")
        if 'nlist' in stats:
            print(f"  Clusters (nlist): {stats['nlist']}")
        if 'nprobe' in stats:
            print(f"  Search Clusters (nprobe): {stats['nprobe']}")
        
        print(f"\nQuery Performance ({results['total_queries']} queries, k={results['k']}):")
        print(f"  Mean:   {results['mean_ms']:.2f} ms")
        print(f"  Median: {results['median_ms']:.2f} ms")
        print(f"  Min:    {results['min_ms']:.2f} ms")
        print(f"  Max:    {results['max_ms']:.2f} ms")
        print(f"  StdDev: {results['stdev_ms']:.2f} ms")
        print(f"  Total:  {results['total_time_sec']:.2f} sec")
        
        # Performance rating
        mean_ms = results['mean_ms']
        if mean_ms < 1:
            rating = "⚡ EXCELLENT"
        elif mean_ms < 5:
            rating = "✓ GOOD"
        elif mean_ms < 20:
            rating = "~ ACCEPTABLE"
        else:
            rating = "⚠️  SLOW"
        
        print(f"\nPerformance Rating: {rating}")
        
    def run_benchmark(self, num_queries: int = 20, k: int = 5):
        """Run full benchmark"""
        print("=" * 60)
        print("FAISS Performance Benchmark")
        print("=" * 60)
        
        self.load_index()
        stats = self.get_index_stats()
        
        queries = self.generate_test_queries(num_queries)
        results = self.benchmark_queries(queries, k)
        
        self.print_results(results, stats)
        
        return results, stats
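

def measure_recall(index: faiss.Index, queries: np.ndarray, k: int = 5) -> float:
    """Sketch: recall@k of an approximate index against exact flat search.

    Speed comparisons alone can hide accuracy loss (e.g. a very low nprobe is
    fast but misses neighbors). Assumes the index supports reconstruction
    (flat directly, IVF after make_direct_map()) and uses an inner-product
    metric, matching the normalization in generate_test_queries.
    """
    if hasattr(index, 'make_direct_map'):
        index.make_direct_map()
    vectors = index.reconstruct_n(0, index.ntotal)  # may be large in memory
    exact = faiss.IndexFlatIP(index.d)
    exact.add(vectors)
    _, exact_ids = exact.search(queries, k)
    _, approx_ids = index.search(queries, k)
    hits = sum(len(set(e) & set(a)) for e, a in zip(exact_ids, approx_ids))
    return hits / (len(queries) * k)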


def compare_benchmarks(before: Dict, after: Dict):
    """Compare two benchmark results"""
    print("\n" + "=" * 60)
    print("PERFORMANCE COMPARISON")
    print("=" * 60)
    
    before_mean = before['mean_ms']
    after_mean = after['mean_ms']
    
    speedup = before_mean / after_mean
    improvement_pct = ((before_mean - after_mean) / before_mean) * 100
    
    print(f"\nQuery Speed:")
    print(f"  Before: {before_mean:.2f} ms")
    print(f"  After:  {after_mean:.2f} ms")
    print(f"  Speedup: {speedup:.2f}x")
    print(f"  Improvement: {improvement_pct:+.1f}%")
    
    if speedup > 1.5:
        print("\n✓ Significant improvement!")
    elif speedup > 1.1:
        print("\n✓ Noticeable improvement")
    elif speedup > 0.9:
        print("\n~ Similar performance")
    else:
        print("\n⚠️  Performance degraded")

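
def percentile_summary(results: Dict) -> Dict:
    """Sketch: tail-latency view of a benchmark run.

    Mean latency hides outliers such as first-query warm-up, so p95/p99 are
    often the numbers that matter for interactive search. Assumes the results
    dict comes from benchmark_queries and holds at least two timing samples.
    """
    cuts = statistics.quantiles(results['times_ms'], n=100)  # 99 cut points
    return {
        'p50_ms': cuts[49],
        'p95_ms': cuts[94],
        'p99_ms': cuts[98]
    }
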

if __name__ == "__main__":
    import sys
    
    # Parse arguments
    if len(sys.argv) < 2:
        print("Usage: python benchmark.py <processed_dir> [--baseline] [--compare baseline.json]")
        sys.exit(1)
    
    processed_dir = Path(sys.argv[1])
    index_path = processed_dir / "embeddings" / "faiss_index.index"
    metadata_path = processed_dir / "embeddings" / "documents_metadata.json"
    
    benchmark = PerformanceBenchmark(index_path, metadata_path)
    results, stats = benchmark.run_benchmark(num_queries=50, k=5)
    
    # Save baseline
    if '--baseline' in sys.argv:
        baseline_file = processed_dir / "embeddings" / "benchmark_baseline.json"
        with open(baseline_file, 'w') as f:
            json.dump({'results': results, 'stats': stats}, f, indent=2)
        print(f"\n✓ Saved baseline to {baseline_file}")
    
    # Compare with baseline
    if '--compare' in sys.argv:
        compare_idx = sys.argv.index('--compare') + 1
        if compare_idx < len(sys.argv):
            baseline_file = Path(sys.argv[compare_idx])
            with open(baseline_file, 'r') as f:
                baseline_data = json.load(f)
            compare_benchmarks(baseline_data['results'], results)
        else:
            print("⚠️  --compare requires a path to a baseline JSON file")
