"""
FAISS Index Optimizer
Rebuilds existing index with optimized structure for better performance
"""

import faiss
import numpy as np
import json
from pathlib import Path
import time
import shutil
from typing import Tuple, Optional

from optimizations.hardware_detector import HardwareDetector


class FAISSOptimizer:
    """Optimize an existing FAISS index for better performance.

    The workflow (see ``optimize``) backs up the on-disk index and metadata,
    reconstructs the stored vectors, builds a new index of the type chosen by
    the hardware configuration, trains/populates it and writes it back in
    place of the original.

    Files used under ``processed_dir``::

        embeddings/faiss_index.index
        embeddings/documents_metadata.json
        embeddings/backup/            (created on demand)
    """

    def __init__(self, processed_dir: Path):
        """Resolve file locations and load the hardware-derived configuration.

        Args:
            processed_dir: Directory containing the ``embeddings`` folder.
        """
        self.processed_dir = Path(processed_dir)
        self.embeddings_dir = self.processed_dir / "embeddings"
        self.index_path = self.embeddings_dir / "faiss_index.index"
        self.metadata_path = self.embeddings_dir / "documents_metadata.json"
        self.backup_dir = self.embeddings_dir / "backup"

        # Index type, nlist and nprobe all come from the detected hardware config.
        detector = HardwareDetector()
        self.config = detector.config

    def backup_current_index(self) -> bool:
        """Copy the current index and metadata into ``backup_dir``.

        Files are suffixed with the current Unix timestamp (whole seconds).
        Missing source files are skipped.

        Returns:
            True on success, False if any filesystem operation failed.
        """
        try:
            self.backup_dir.mkdir(parents=True, exist_ok=True)
            timestamp = int(time.time())

            if self.index_path.exists():
                backup_index = self.backup_dir / f"faiss_index_{timestamp}.index"
                shutil.copy2(self.index_path, backup_index)
                print(f"✓ Backed up index to {backup_index}")

            if self.metadata_path.exists():
                backup_meta = self.backup_dir / f"documents_metadata_{timestamp}.json"
                shutil.copy2(self.metadata_path, backup_meta)
                print(f"✓ Backed up metadata to {backup_meta}")

            return True
        except Exception as e:
            print(f"✗ Backup failed: {e}")
            return False

    def load_current_index(self) -> Tuple[Optional[faiss.Index], Optional[np.ndarray], Optional[list]]:
        """Load the current index, its metadata and its stored vectors.

        Returns:
            ``(index, vectors, metadata)``:

            * ``(None, None, None)`` if the index file is missing or loading fails;
            * ``(index, None, None)`` if the metadata file is missing;
            * ``vectors`` is ``None`` when the index cannot reconstruct its
              stored vectors, otherwise an ``(ntotal, d)`` float32 array.
        """
        try:
            if not self.index_path.exists():
                print(f"✗ Index not found at {self.index_path}")
                return None, None, None

            index = faiss.read_index(str(self.index_path))
            print(f"✓ Loaded existing index: {index.ntotal} vectors, dim={index.d}")

            if not self.metadata_path.exists():
                print("✗ Metadata not found")
                return index, None, None

            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            print(f"✓ Loaded metadata for {len(metadata)} documents")

            vectors = self._reconstruct_vectors(index)
            if vectors is None:
                print("⚠️  Cannot extract vectors from this index type")
            else:
                print(f"✓ Extracted {len(vectors)} vectors from index")

            return index, vectors, metadata

        except Exception as e:
            print(f"✗ Failed to load index: {e}")
            return None, None, None

    @staticmethod
    def _reconstruct_vectors(index: faiss.Index) -> Optional[np.ndarray]:
        """Return all vectors stored in *index* as float32, or None if unsupported."""
        if index.ntotal == 0:
            return np.zeros((0, index.d), dtype='float32')
        try:
            # One bulk call instead of a Python loop over reconstruct(i).
            return np.asarray(index.reconstruct_n(0, index.ntotal), dtype='float32')
        except RuntimeError:
            # IVF indexes can only reconstruct after a direct map (id -> list
            # position) is built; index types without one cannot reconstruct.
            if not hasattr(index, 'make_direct_map'):
                return None
            try:
                index.make_direct_map()
                return np.asarray(index.reconstruct_n(0, index.ntotal), dtype='float32')
            except RuntimeError:
                return None

    @staticmethod
    def _flat_index(dimension: int, metric_type: int) -> faiss.Index:
        """Build an exact (flat) index for *metric_type*, using the named subclass when one exists."""
        if metric_type == faiss.METRIC_INNER_PRODUCT:
            return faiss.IndexFlatIP(dimension)
        if metric_type == faiss.METRIC_L2:
            return faiss.IndexFlatL2(dimension)
        return faiss.IndexFlat(dimension, metric_type)

    def create_optimized_index(self, dimension: int, num_vectors: int,
                               metric_type: Optional[int] = None) -> faiss.Index:
        """Create an empty index of the configured type.

        Args:
            dimension: Vector dimensionality.
            num_vectors: Number of vectors that will be used for training.
                ``nlist`` is capped at this value because FAISS k-means needs at
                least one training point per cluster. With zero vectors a flat
                index is used instead of IVF.
            metric_type: FAISS metric constant. Defaults to
                ``faiss.METRIC_INNER_PRODUCT``; pass the original index's
                ``metric_type`` to preserve its ranking semantics.

        Returns:
            An untrained (for IVF) or ready-to-fill (for flat) FAISS index.
        """
        if metric_type is None:
            metric_type = faiss.METRIC_INNER_PRODUCT

        print(f"\n=== Creating Optimized Index ===")
        print(f"Vectors: {num_vectors}, Dimension: {dimension}")
        print(f"Config: {self.config.faiss_index_type}, nlist={self.config.faiss_nlist}, nprobe={self.config.faiss_nprobe}")

        index_type = self.config.faiss_index_type
        nlist = self.config.faiss_nlist
        if index_type in ('IVF_SQ8', 'IVF_FLAT'):
            if num_vectors < 1:
                print("⚠️  No vectors available for IVF training, using flat index")
                index_type = None
            elif num_vectors < nlist:
                # FAISS raises if there are fewer training points than clusters.
                nlist = num_vectors
                print(f"⚠️  Only {num_vectors} vectors; reducing nlist to {nlist}")

        # The metric is passed explicitly: the IVF constructors default to
        # METRIC_L2 regardless of the quantizer's metric.
        if index_type == 'IVF_SQ8':
            # IVF with Scalar Quantization (best for CPU, saves memory)
            quantizer = self._flat_index(dimension, metric_type)
            index = faiss.IndexIVFScalarQuantizer(
                quantizer,
                dimension,
                nlist,
                faiss.ScalarQuantizer.QT_8bit,
                metric_type,
            )
            print("✓ Created IVF with Scalar Quantization (CPU-optimized)")

        elif index_type == 'IVF_FLAT':
            # IVF with full vectors (best for GPU or high accuracy)
            quantizer = self._flat_index(dimension, metric_type)
            index = faiss.IndexIVFFlat(
                quantizer,
                dimension,
                nlist,
                metric_type,
            )
            print("✓ Created IVF with full precision")

        else:
            # Fallback to simple flat index
            index = self._flat_index(dimension, metric_type)
            print("✓ Created simple flat index (fallback)")

        if hasattr(index, 'nprobe'):
            # Probing more lists than exist is meaningless; cap so stats are accurate.
            index.nprobe = min(self.config.faiss_nprobe, nlist)
            print(f"✓ Set nprobe = {index.nprobe}")

        return index

    def train_and_populate(self, index: faiss.Index, vectors: np.ndarray) -> faiss.Index:
        """Train *index* if it requires training, then add all *vectors*.

        FAISS requires C-contiguous float32 input, so *vectors* is converted
        first (a no-op if it already satisfies that).

        Returns:
            The same index object, now populated.
        """
        print(f"\n=== Training and Populating Index ===")
        vectors = np.ascontiguousarray(vectors, dtype='float32')

        if not index.is_trained:
            print(f"Training index on {len(vectors)} vectors...")
            start = time.time()
            index.train(vectors)
            print(f"✓ Training complete in {time.time() - start:.2f}s")
        else:
            print("✓ Index doesn't need training")

        print(f"Adding {len(vectors)} vectors to index...")
        start = time.time()
        index.add(vectors)
        print(f"✓ Added vectors in {time.time() - start:.2f}s")

        return index

    def save_optimized_index(self, index: faiss.Index) -> bool:
        """Write *index* to ``index_path``, replacing the existing file.

        The index is written to a sibling temporary file first and then moved
        into place, so a failed write never leaves a truncated index behind.

        Returns:
            True on success, False on failure (the original file is untouched).
        """
        tmp_path = self.index_path.with_name(self.index_path.name + ".tmp")
        try:
            print(f"\n=== Saving Optimized Index ===")
            faiss.write_index(index, str(tmp_path))
            tmp_path.replace(self.index_path)
            print(f"✓ Saved to {self.index_path}")
            return True
        except Exception as e:
            print(f"✗ Save failed: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass  # temp file was never created or is already gone
            return False

    def get_index_stats(self, index: faiss.Index) -> dict:
        """Summarize *index*: type, size, training state, IVF params and estimated memory.

        ``memory_mb`` is a rough estimate of stored codes only. IVF indexes
        also store an 8-byte id per vector; flat indexes store no ids. The
        coarse quantizer is not counted.
        """
        stats = {
            'type': type(index).__name__,
            'total_vectors': index.ntotal,
            'dimension': index.d,
            'trained': index.is_trained
        }

        if hasattr(index, 'nlist'):
            stats['nlist'] = index.nlist
        if hasattr(index, 'nprobe'):
            stats['nprobe'] = index.nprobe

        id_bytes = 8 if hasattr(index, 'nlist') else 0
        if hasattr(index, 'code_size'):
            bytes_per_vector = index.code_size + id_bytes
        else:
            bytes_per_vector = index.d * 4 + id_bytes  # float32 * dim

        stats['memory_mb'] = (index.ntotal * bytes_per_vector) / (1024 * 1024)

        return stats

    def optimize(self, dry_run: bool = False) -> bool:
        """Run the backup → load → build → train → save workflow.

        Args:
            dry_run: If True, stop after loading and report what would be done.

        Returns:
            True on success (or a successful dry run), False on any failure.
            The original index is only overwritten in the final save step.
        """
        print("=" * 60)
        print("FAISS Index Optimization")
        print("=" * 60)

        # Step 1: Backup
        print("\n[1/5] Backing up current index...")
        if not self.backup_current_index():
            print("✗ Backup failed! Aborting.")
            return False

        # Step 2: Load current
        print("\n[2/5] Loading current index...")
        old_index, vectors, metadata = self.load_current_index()

        if old_index is None or vectors is None:
            print("✗ Cannot load current index. Aborting.")
            return False

        if metadata is not None and len(metadata) != len(vectors):
            print(f"⚠️  Metadata has {len(metadata)} entries but index has {len(vectors)} vectors")

        if dry_run:
            print("\n🔍 DRY RUN MODE - Not making changes")
            print(f"Would optimize index with {len(vectors)} vectors")
            return True

        # Steps 3-4 only touch memory, so a failure here leaves the file intact.
        try:
            print("\n[3/5] Creating optimized index...")
            new_index = self.create_optimized_index(
                old_index.d, len(vectors), getattr(old_index, 'metric_type', None)
            )

            print("\n[4/5] Training and populating...")
            new_index = self.train_and_populate(new_index, vectors)
        except RuntimeError as e:
            print(f"✗ Building optimized index failed: {e}. Original index left unchanged.")
            return False

        # Step 5: Save
        print("\n[5/5] Saving optimized index...")
        if not self.save_optimized_index(new_index):
            print("✗ Save failed! Original index is backed up.")
            return False

        print("\n" + "=" * 60)
        print("OPTIMIZATION COMPLETE")
        print("=" * 60)

        old_stats = self.get_index_stats(old_index)
        new_stats = self.get_index_stats(new_index)

        print(f"\nOLD INDEX:")
        print(f"  Type: {old_stats['type']}")
        print(f"  Memory: {old_stats['memory_mb']:.2f} MB")

        print(f"\nNEW INDEX:")
        print(f"  Type: {new_stats['type']}")
        print(f"  Clusters (nlist): {new_stats.get('nlist', 'N/A')}")
        print(f"  Search (nprobe): {new_stats.get('nprobe', 'N/A')}")
        print(f"  Memory: {new_stats['memory_mb']:.2f} MB")

        # An empty original index would make the ratio undefined.
        if old_stats['memory_mb'] > 0:
            memory_savings = (1 - new_stats['memory_mb'] / old_stats['memory_mb']) * 100
            print(f"\nMEMORY SAVINGS: {memory_savings:.1f}%")
        else:
            print("\nMEMORY SAVINGS: N/A (original index is empty)")

        print(f"\n✓ Optimization successful!")
        print(f"✓ Backup saved in: {self.backup_dir}")

        return True


def parse_cli_args(argv):
    """Parse ``[processed_dir] [--dry-run]`` from an ``sys.argv``-style list.

    Arguments starting with ``--`` are treated as flags and skipped when
    locating the positional directory. This lets ``--dry-run`` appear before
    or after the directory, or on its own.

    Returns:
        ``(processed_dir, dry_run)``. ``processed_dir`` defaults to
        ``../processed`` relative to this file's directory.
    """
    args = argv[1:]
    dry_run = '--dry-run' in args
    positional = [arg for arg in args if not arg.startswith('--')]
    if positional:
        processed_dir = Path(positional[0])
    else:
        processed_dir = Path(__file__).parent.parent / "processed"
    return processed_dir, dry_run


if __name__ == "__main__":
    import sys

    processed_dir, dry_run = parse_cli_args(sys.argv)

    optimizer = FAISSOptimizer(processed_dir)
    success = optimizer.optimize(dry_run=dry_run)

    sys.exit(0 if success else 1)
