"""
Adaptive Vector Processing System
Supports CPU/GPU environments with optimized backends
"""

import numpy as np
import torch
from typing import List, Optional, Union
from pathlib import Path
import time
import gc
from dataclasses import dataclass

from config import ProcessingConfig


class AdaptiveEmbeddingEngine:
    """Embedding engine that loads the configured backend and encodes text to vectors.

    Backends (``config.backend``): ``"onnx"`` and ``"openvino"`` are loaded via
    optimum; any other value uses sentence-transformers on PyTorch. If an
    optimized backend fails to load, the PyTorch backend is used instead.

    Every encoding method returns L2-normalized float32 NumPy arrays, whichever
    backend is active, so vectors from the CPU path, the GPU path and
    ``encode_single`` are directly comparable under inner-product search.
    """

    _MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

    def __init__(self, config: "ProcessingConfig"):
        self.config = config
        self.model = None  # SentenceTransformer, or an optimum ORT/OpenVINO model
        self.tokenizer = None  # only set by the ONNX/OpenVINO loaders
        self._load_optimized_model()

    # ------------------------------------------------------------------ loading

    def _load_optimized_model(self):
        """Load the configured backend, falling back to PyTorch if ONNX/OpenVINO fails."""
        backend = self.config.backend
        if backend == "onnx":
            loader = self._load_onnx_model
        elif backend == "openvino":
            loader = self._load_openvino_model
        else:  # pytorch
            # No fallback here: retrying the same loader would only fail again.
            self._load_pytorch_model()
            return

        try:
            loader()
        except Exception as e:
            print(f"Failed to load optimized model ({backend}), falling back to PyTorch: {e}")
            self._load_pytorch_model()

    def _load_onnx_model(self):
        """Load the model through ONNX Runtime (optimum), optionally int8-quantized.

        Raises:
            ImportError: If ``optimum[onnxruntime]`` is not installed.
        """
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from optimum.onnxruntime.configuration import AutoQuantizationConfig
            from transformers import AutoTokenizer
        except ImportError as err:
            raise ImportError(
                "optimum[onnxruntime] not installed. Install with: pip install optimum[onnxruntime]"
            ) from err

        if self.config.precision == "int8":
            # NOTE(review): confirm from_pretrained honours quantization_config and that
            # "model_quantized.onnx" exists; optimum documents dynamic quantization
            # through ORTQuantizer.
            qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False)
            self.model = ORTModelForFeatureExtraction.from_pretrained(
                self._MODEL_NAME,
                file_name="model_quantized.onnx",
                quantization_config=qconfig,
            )
        else:
            self.model = ORTModelForFeatureExtraction.from_pretrained(self._MODEL_NAME)

        self.tokenizer = AutoTokenizer.from_pretrained(self._MODEL_NAME)
        print(f"Loaded ONNX model with {self.config.precision} precision")

    def _load_openvino_model(self):
        """Export and load the model through OpenVINO (optimum-intel).

        Raises:
            ImportError: If ``optimum[intel]`` is not installed.
        """
        try:
            from optimum.intel import OVModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError as err:
            raise ImportError(
                "optimum[intel] not installed. Install with: pip install optimum[intel]"
            ) from err

        self.model = OVModelForFeatureExtraction.from_pretrained(
            self._MODEL_NAME,
            export=True,
            compile=False,
            # NOTE(review): confirm optimum-intel accepts this dict form of quantization_config.
            quantization_config=None if self.config.precision == "fp32" else {"mode": "int8"},
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self._MODEL_NAME)
        print(f"Loaded OpenVINO model with {self.config.precision} precision")

    def _load_pytorch_model(self):
        """Load the sentence-transformers model on ``config.device``.

        On CUDA, optionally casts to fp16 and enables TF32 matmuls.

        Raises:
            ImportError: If ``sentence-transformers`` is not installed.
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as err:
            raise ImportError(
                "sentence-transformers not installed. Install with: pip install sentence-transformers"
            ) from err

        # Pin the device explicitly: without it SentenceTransformer picks CUDA
        # whenever it is available, even when the config asks for CPU.
        self.model = SentenceTransformer(self._MODEL_NAME, device=self.config.device)

        if self.config.device == "cuda":
            if self.config.precision == "fp16":
                self.model = self.model.half()
            # Enable TensorFloat-32 for Ampere GPUs
            if torch.cuda.is_available():
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        print(f"Loaded PyTorch model on {self.config.device} with {self.config.precision} precision")

    # ----------------------------------------------------------------- encoding

    def encode_parallel(self, chunks: List[str], progress_bar: bool = True) -> np.ndarray:
        """Encode ``chunks`` using the GPU or CPU path depending on ``config.device``.

        Returns:
            ``(len(chunks), dim)`` float32 array of L2-normalized embeddings.
        """
        if self.config.device == "cuda":
            return self._encode_gpu_batch(chunks, progress_bar)
        return self._encode_cpu_parallel(chunks, progress_bar)

    def _encode_gpu_batch(self, chunks: List[str], progress_bar: bool) -> np.ndarray:
        """Encode on the GPU in ``config.batch_size`` batches; result lives in host memory."""
        start_time = time.time()
        embeddings = self._embed(chunks, progress_bar)
        print(f"GPU encoding of {len(chunks)} chunks took {time.time() - start_time:.2f}s")
        return embeddings

    def _encode_cpu_parallel(self, chunks: List[str], progress_bar: bool) -> np.ndarray:
        """Encode on the CPU in ``config.batch_size`` batches.

        Encoding runs in-process with the already-loaded model. A process pool
        would have to pickle the model (or reload it in every worker), and it
        would compete with the backend's own intra-op threading.
        """
        start_time = time.time()
        embeddings = self._embed(chunks, progress_bar)
        print(f"CPU encoding of {len(chunks)} chunks took {time.time() - start_time:.2f}s")
        return embeddings

    def encode_single(self, text: str) -> np.ndarray:
        """Encode one text; returns a 1-D float32, L2-normalized vector."""
        return self._embed([text], progress_bar=False, batch_size=1)[0]

    def _embed(self, texts: List[str], progress_bar: bool = False,
               batch_size: Optional[int] = None) -> np.ndarray:
        """Backend-agnostic encoding into a ``(len(texts), dim)`` float32 array."""
        if batch_size is None:
            batch_size = self.config.batch_size
        if not texts:
            return np.empty((0, self._embedding_dim()), dtype=np.float32)

        if hasattr(self.model, "encode"):  # sentence-transformers model
            vectors = self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=progress_bar,
                convert_to_numpy=True,
                normalize_embeddings=True,
            )
            return np.asarray(vectors, dtype=np.float32)
        return self._embed_with_tokenizer(texts, batch_size)

    def _embed_with_tokenizer(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Encode with an optimum (ONNX/OpenVINO) model: mean-pool tokens, then L2-normalize.

        Mean pooling over non-padding tokens is assumed to match the
        sentence-transformers pooling for ``_MODEL_NAME``; re-check this if the
        model changes.
        """
        pooled_batches = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            hidden = torch.as_tensor(outputs.last_hidden_state)
            # Zero out padding tokens so they do not dilute the mean.
            mask = inputs["attention_mask"].unsqueeze(-1).to(device=hidden.device, dtype=hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)
            pooled = torch.nn.functional.normalize(summed / counts, p=2, dim=1)
            pooled_batches.append(pooled.cpu().numpy().astype(np.float32))
        return np.vstack(pooled_batches)

    def _embedding_dim(self) -> int:
        """Width of the vectors produced by the loaded model."""
        if hasattr(self.model, "get_sentence_embedding_dimension"):
            return int(self.model.get_sentence_embedding_dimension())
        return int(self.model.config.hidden_size)


class AdaptiveFAISSManager:
    """Build, fill, query and persist a FAISS inner-product index chosen from the config.

    All indexes created here score by inner product (``METRIC_INNER_PRODUCT``),
    so search results are similarities (higher is better). For L2-normalized
    vectors they are cosine similarities.
    """

    def __init__(self, config: "ProcessingConfig"):
        self.config = config
        self.index = None  # assigned by the caller or by load_index()
        self._gpu_resources = None  # kept alive for as long as a GPU index may use it

    def create_optimized_index(self, dimension: int = 768):
        """Create a new, untrained index suited to ``config.device``.

        Args:
            dimension: Vector width. The default of 768 is kept for backward
                compatibility; pass the real embedding width
                (e.g. ``embeddings.shape[1]``).

        Returns:
            A new FAISS index (not stored on ``self.index``). IVF-based indexes
            must be trained before vectors are added.
        """
        import faiss  # noqa: F401 - raise ImportError early if faiss is missing

        if self.config.device == "cuda":
            return self._create_gpu_index(dimension)
        return self._create_cpu_index(dimension)

    def _create_gpu_index(self, dimension: int):
        """Create an IVF-Flat or flat GPU index; fall back to a CPU index on any error."""
        import faiss

        try:
            res = faiss.StandardGpuResources()
            # Hold a reference so the GPU resources outlive this call, whether or
            # not FAISS's Python wrapper retains one itself.
            self._gpu_resources = res

            if self.config.faiss_index_type == "IVF_FLAT":
                quantizer = faiss.IndexFlatIP(dimension)
                # IVF indexes default to the L2 metric; request inner product so
                # scoring agrees with the IP coarse quantizer and the flat path.
                cpu_index = faiss.IndexIVFFlat(
                    quantizer, dimension, self.config.n_list, faiss.METRIC_INNER_PRODUCT
                )
                cpu_index.nprobe = self.config.n_probe
                return faiss.index_cpu_to_gpu(res, 0, cpu_index)

            # Fallback to FlatIP for maximum speed
            return faiss.GpuIndexFlatIP(res, dimension)

        except Exception as e:
            print(f"GPU FAISS failed, falling back to CPU: {e}")
            return self._create_cpu_index(dimension)

    def _create_cpu_index(self, dimension: int):
        """Create a compressed IVF index for the CPU (PQ or 8-bit scalar quantization).

        Side effect: sets FAISS's global OpenMP thread count to ``config.max_workers``.
        """
        import faiss

        faiss.omp_set_num_threads(self.config.max_workers)
        quantizer = faiss.IndexFlatIP(dimension)

        if self.config.faiss_quantization == "PQ":
            # Product quantization: 8 sub-quantizers x 8 bits each -> 8-byte codes.
            # ``dimension`` must be divisible by the number of sub-quantizers.
            index = faiss.IndexIVFPQ(
                quantizer,
                dimension,
                self.config.n_list,
                8,  # number of sub-quantizers (M)
                8,  # bits per sub-quantizer code
                faiss.METRIC_INNER_PRODUCT,
            )
        else:
            # Scalar quantization (1 byte per dimension) for speed
            index = faiss.IndexIVFScalarQuantizer(
                quantizer,
                dimension,
                self.config.n_list,
                faiss.ScalarQuantizer.QT_8bit,
                faiss.METRIC_INNER_PRODUCT,
            )

        index.nprobe = self.config.n_probe
        return index

    def batch_add_vectors(self, vectors: np.ndarray, ids: Optional[np.ndarray] = None,
                          batch_size: int = 10000):
        """Add ``vectors`` to ``self.index`` in chunks of ``batch_size`` rows.

        Args:
            vectors: 2-D array-like of vectors; converted per chunk to
                C-contiguous float32 as FAISS requires.
            ids: Optional per-vector ids (converted to int64); the index must
                support ``add_with_ids``.
            batch_size: Rows added per FAISS call.

        Raises:
            ValueError: If ``ids`` is given with a different length than ``vectors``.
        """
        if ids is not None and len(ids) != len(vectors):
            raise ValueError(f"got {len(ids)} ids for {len(vectors)} vectors")

        added_since_gc = 0
        for start in range(0, len(vectors), batch_size):
            batch = np.ascontiguousarray(vectors[start:start + batch_size], dtype=np.float32)

            if ids is None:
                self.index.add(batch)
            else:
                batch_ids = np.ascontiguousarray(ids[start:start + batch_size], dtype=np.int64)
                self.index.add_with_ids(batch, batch_ids)

            # Release temporaries periodically on large datasets.
            added_since_gc += len(batch)
            if added_since_gc >= 100_000:
                gc.collect()
                added_since_gc = 0

    def search(self, query_vectors: np.ndarray, k: int = 10):
        """Return ``(scores, ids)`` for the ``k`` nearest neighbours of each query.

        A single 1-D query vector is accepted and treated as a batch of one.
        """
        queries = np.ascontiguousarray(np.atleast_2d(query_vectors), dtype=np.float32)
        return self.index.search(queries, k)

    def save_index(self, path: Union[str, Path]):
        """Write ``self.index`` to ``path``, copying a GPU index to host memory first."""
        import faiss

        index = self.index
        gpu_index_type = getattr(faiss, "GpuIndex", None)  # absent in CPU-only builds
        if gpu_index_type is not None and isinstance(index, gpu_index_type):
            # GPU indexes cannot be serialized directly.
            index = faiss.index_gpu_to_cpu(index)
        faiss.write_index(index, str(path))

    def load_index(self, path: Union[str, Path]):
        """Read a FAISS index from ``path`` into ``self.index`` (as a CPU index)."""
        import faiss
        self.index = faiss.read_index(str(path))


class VectorProcessingPipeline:
    """Orchestrate embedding, FAISS indexing and persistence for a list of text chunks."""

    def __init__(self, config: Optional["ProcessingConfig"] = None):
        self.config = config or ProcessingConfig()
        self.embedding_engine = AdaptiveEmbeddingEngine(self.config)
        self.faiss_manager = AdaptiveFAISSManager(self.config)

    async def process_chunks(self, chunks: List[str], chunk_metadata: List[dict], output_dir: Union[str, Path]):
        """Embed ``chunks``, index them with FAISS and save the results to ``output_dir``.

        Writes ``metadata.json``, ``embeddings.npy`` and ``faiss_index.index``.

        Args:
            chunks: Texts to embed; must be non-empty.
            chunk_metadata: Per-chunk metadata, stored as-is in ``metadata.json``.
            output_dir: Destination directory (created if missing).

        Returns:
            Dict with the ``embeddings`` array, the populated ``index`` and the
            ``metadata`` dict.

        Raises:
            ValueError: If ``chunks`` is empty.
        """
        from datetime import datetime

        if not chunks:
            raise ValueError("process_chunks() requires at least one chunk")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        print(f"Starting pipeline on {self.config.device.upper()}")
        start_time = time.time()

        # Step 1: Generate embeddings
        print("Generating embeddings...")
        embeddings = self.embedding_engine.encode_parallel(chunks)
        embedding_time = time.time() - start_time

        # Step 2: Create, train and populate the FAISS index
        print("Building FAISS index...")
        self.faiss_manager.index = self._build_trained_index(embeddings)
        self.faiss_manager.batch_add_vectors(embeddings)
        faiss_time = time.time() - start_time - embedding_time

        # Step 3: Prepare metadata
        metadata = {
            "chunk_count": len(chunks),
            "embedding_dim": embeddings.shape[1],
            "processing_time": {
                "total": time.time() - start_time,
                "embedding": embedding_time,
                "faiss": faiss_time
            },
            # Copy so the returned metadata does not alias the live config object.
            "config": dict(self.config.__dict__),
            "timestamp": datetime.now().isoformat(),
            "chunk_metadata": chunk_metadata
        }

        # Step 4: Save everything
        print("Saving vectors and metadata...")
        await self._save_vectors_and_metadata(embeddings, metadata, output_dir)
        self._write_index(self.faiss_manager.index, output_dir / "faiss_index.index")

        total_time = time.time() - start_time
        print(f"Pipeline completed in {total_time:.2f}s")

        return {
            "embeddings": embeddings,
            "index": self.faiss_manager.index,
            "metadata": metadata
        }

    def _build_trained_index(self, embeddings: np.ndarray):
        """Create the configured index and train it on ``embeddings`` if it needs training.

        IVF/PQ training raises when there are too few vectors for the configured
        number of clusters. In that case the pipeline falls back to an exact
        ``IndexFlatIP``.
        """
        index = self.faiss_manager.create_optimized_index(dimension=embeddings.shape[1])
        if index.is_trained:
            return index

        try:
            # Train on the full set; FAISS's k-means subsamples oversized training sets.
            index.train(np.ascontiguousarray(embeddings, dtype=np.float32))
        except RuntimeError as e:
            import faiss
            print(f"FAISS index training failed ({e}); falling back to exact IndexFlatIP")
            index = faiss.IndexFlatIP(embeddings.shape[1])
        return index

    @staticmethod
    def _write_index(index, path: Path):
        """Serialize ``index`` to ``path``, copying a GPU index to host memory first."""
        import faiss

        gpu_index_type = getattr(faiss, "GpuIndex", None)  # absent in CPU-only builds
        if gpu_index_type is not None and isinstance(index, gpu_index_type):
            # faiss.write_index cannot serialize GPU indexes directly.
            index = faiss.index_gpu_to_cpu(index)
        faiss.write_index(index, str(path))

    async def _save_vectors_and_metadata(self, embeddings: np.ndarray, metadata: dict, output_dir: Path):
        """Write ``metadata.json`` (UTF-8) and ``embeddings.npy`` (float32) into ``output_dir``."""
        import json

        # Save metadata
        with open(output_dir / "metadata.json", 'w', encoding="utf-8") as f:
            # default=str: config/metadata may hold non-JSON values (e.g. Path
            # objects). Stringify them rather than fail after embeddings are computed.
            json.dump(metadata, f, indent=2, default=str)

        # Save embeddings as numpy (most compatible)
        np.save(output_dir / "embeddings.npy", embeddings.astype('float32'))
