"""
Hardware Detection and Configuration
Auto-detects optimal settings for current environment
"""

import os
import psutil
import platform
from dataclasses import dataclass
from typing import Literal


@dataclass
class HardwareConfig:
    """Optimal configuration based on detected hardware.

    Plain value object: it stores the settings and performs no validation
    or detection itself. The per-field notes below describe the values that
    ``HardwareDetector`` in this module assigns; other producers may differ.
    """
    # Compute device name: 'cuda' when a GPU was detected, otherwise 'cpu'.
    device: str
    # Logical CPU count as reported by os.cpu_count() (4 if undeterminable).
    cpu_cores: int
    # Total physical memory, in units of 1024**3 bytes, rounded to 2 decimals.
    total_ram_gb: float
    # Memory available at detection time, same unit and rounding as above.
    available_ram_gb: float
    # Batch size chosen from the available-RAM tiers (16, 32 or 64).
    batch_size: int
    # Number of parallel workers: cpu_cores - 2, but never less than 1.
    num_workers: int
    # True when running without a GPU (quantization is used to save memory).
    use_quantization: bool
    # FAISS index type label: 'IVF_SQ8' without a GPU, 'IVF_FLAT' with one.
    faiss_index_type: str
    # Number of FAISS clusters (nlist).
    faiss_nlist: int
    # Number of FAISS clusters probed per search (nprobe).
    faiss_nprobe: int


class HardwareDetector:
    """Detect hardware and generate optimal configuration.

    Detection runs once, in the constructor; the resulting
    ``HardwareConfig`` is stored on ``self.config``.
    """

    def __init__(self):
        self.config = self._detect_hardware()

    @staticmethod
    def _cuda_available() -> bool:
        """Return True only if torch imports cleanly and reports CUDA.

        This is a best-effort probe: a missing torch (``ImportError``) or a
        broken installation, e.g. shared libraries that fail to load
        (``OSError`` / ``RuntimeError``), is treated as "no GPU" rather than
        aborting hardware detection.
        """
        try:
            import torch
            return bool(torch.cuda.is_available())
        except (ImportError, OSError, RuntimeError):
            return False

    def _detect_hardware(self) -> HardwareConfig:
        """Detect system hardware and return optimal config.

        Returns:
            A ``HardwareConfig`` populated from the CPU count, the current
            memory statistics and the GPU probe.
        """
        # CPU information (os.cpu_count() may return None; fall back to 4)
        cpu_count = os.cpu_count() or 4

        # Memory information, converted from bytes to units of 1024**3 bytes
        mem = psutil.virtual_memory()
        bytes_per_gb = 1024 ** 3
        total_ram_gb = mem.total / bytes_per_gb
        available_ram_gb = mem.available / bytes_per_gb

        # Check for GPU (torch/cuda)
        has_gpu = self._cuda_available()
        device = 'cuda' if has_gpu else 'cpu'

        # Batch size based on available RAM
        if available_ram_gb > 10:
            batch_size = 64
        elif available_ram_gb > 6:
            batch_size = 32
        else:
            batch_size = 16

        # Workers for parallel processing (leave 2 cores for system),
        # but always keep at least one worker.
        num_workers = max(1, cpu_count - 2)

        # Use quantization on CPU to save memory
        use_quantization = not has_gpu

        # FAISS index settings for ~6K documents
        # nlist = sqrt(num_docs) is a good heuristic
        faiss_nlist = 100  # sqrt(6000) ≈ 77, round to 100
        faiss_nprobe = 10  # Search 10% of clusters (speed/accuracy tradeoff)
        faiss_index_type = 'IVF_FLAT' if has_gpu else 'IVF_SQ8'

        return HardwareConfig(
            device=device,
            cpu_cores=cpu_count,
            total_ram_gb=round(total_ram_gb, 2),
            available_ram_gb=round(available_ram_gb, 2),
            batch_size=batch_size,
            num_workers=num_workers,
            use_quantization=use_quantization,
            faiss_index_type=faiss_index_type,
            faiss_nlist=faiss_nlist,
            faiss_nprobe=faiss_nprobe,
        )

    def print_config(self) -> HardwareConfig:
        """Print detected configuration.

        Returns:
            The stored ``HardwareConfig``, so callers can print and keep
            working with the configuration in a single call.
        """
        cfg = self.config
        print("=== Hardware Configuration ===")
        print(f"Device: {cfg.device.upper()}")
        print(f"CPU Cores: {cfg.cpu_cores}")
        print(f"Total RAM: {cfg.total_ram_gb} GB")
        print(f"Available RAM: {cfg.available_ram_gb} GB")
        print()
        print("=== Optimization Settings ===")
        print(f"Batch Size: {cfg.batch_size}")
        print(f"Workers: {cfg.num_workers}")
        print(f"Quantization: {cfg.use_quantization}")
        print()
        print("=== FAISS Index Settings ===")
        print(f"Index Type: {cfg.faiss_index_type}")
        print(f"nlist (clusters): {cfg.faiss_nlist}")
        print(f"nprobe (search): {cfg.faiss_nprobe}")
        print()
        print(f"Platform: {platform.system()} {platform.release()}")
        print(f"Python: {platform.python_version()}")
        return cfg


# Script entry point: detect the hardware, print the resulting configuration,
# then print recommendations derived from it.
if __name__ == "__main__":
    detector = HardwareDetector()
    config = detector.print_config()

    print("\n=== Recommendations ===")
    if config.available_ram_gb < 4:
        # A smaller batch lowers peak memory use, so on out-of-memory errors
        # the batch size has to go down, not up.
        print("⚠️  Low RAM detected. Consider reducing batch_size if OOM errors occur.")
    if config.cpu_cores <= 2:
        print("⚠️  Limited CPU cores. Parallel processing may not provide benefits.")
    if config.device == 'cpu':
        print("✓ CPU-only mode. Using scalar quantization to reduce memory.")
    else:
        print("✓ GPU detected! Using full precision for maximum quality.")
