"""
Reranker Module for RAG System
================================
Implements cross-encoder based reranking to improve retrieval precision.

Based on research from "When Retrieval Succeeds and Fails: Rethinking RAG for LLMs"
which recommends reranking as a key step after initial retrieval to filter noise
and improve relevance ordering.

The cross-encoder scores (query, document) pairs jointly, providing more accurate
relevance scores than bi-encoder (embedding) similarity alone.
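
Example (illustrative sketch; candidate dicts only need the fields this module
reads: 'content' or 'text', an optional 'score', and metadata 'source'):

    candidates = [
        {'content': 'Reranking filters noisy retrieval hits.', 'score': 0.61,
         'metadata': {'source': 'notes.md'}},
        {'content': 'Unrelated text.', 'score': 0.35,
         'metadata': {'source': 'misc.md'}},
    ]
    results = full_rerank_pipeline("why rerank retrieved chunks?", candidates, top_k=2)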
"""

import logging
import math
import os
import re
import time
from typing import List, Dict, Any

logger = logging.getLogger(__name__)

# Configuration
RERANKER_ENABLED = os.environ.get("RERANKER_ENABLED", "true").lower() == "true"
RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_TOP_K = int(os.environ.get("RERANKER_TOP_K", "5"))  # Final results after reranking
RERANKER_CANDIDATES = int(os.environ.get("RERANKER_CANDIDATES", "20"))  # Candidates to rerank
RERANKER_MIN_SCORE = float(os.environ.get("RERANKER_MIN_SCORE", "-2.0"))  # Min score threshold (cross-encoder logits can be negative)
RERANKER_WEIGHT = float(os.environ.get("RERANKER_WEIGHT", "0.4"))  # Weight in final score
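
# How the knobs combine (matches rerank_candidates below):
#   final_score = (1 - RERANKER_WEIGHT) * retrieval_score
#                 + RERANKER_WEIGHT * sigmoid(cross_encoder_score)
# Example environment override (values are illustrative, not tuned recommendations):
#   export RERANKER_CANDIDATES=30
#   export RERANKER_WEIGHT=0.5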

# Lazy-loaded model
_reranker_model = None
_model_load_attempted = False


def get_reranker_model():
    """Lazy load the cross-encoder model."""
    global _reranker_model, _model_load_attempted
    
    if _reranker_model is not None:
        return _reranker_model
    
    if _model_load_attempted:
        return None
    
    _model_load_attempted = True
    
    try:
        from sentence_transformers import CrossEncoder
        logger.info(f"Loading reranker model: {RERANKER_MODEL}")
        start = time.time()
        _reranker_model = CrossEncoder(RERANKER_MODEL, max_length=512)
        logger.info(f"Reranker model loaded in {time.time() - start:.2f}s")
        return _reranker_model
    except ImportError:
        logger.warning("sentence_transformers not installed. Reranking disabled.")
        return None
    except Exception as e:
        logger.error(f"Failed to load reranker model: {e}")
        return None


def rerank_candidates(
    query: str,
    candidates: List[Dict[str, Any]],
    top_k: int = RERANKER_TOP_K,
    min_score: float = RERANKER_MIN_SCORE,
    reranker_weight: float = RERANKER_WEIGHT
) -> List[Dict[str, Any]]:
    """
    Rerank retrieval candidates using a cross-encoder model.
    
    Args:
        query: The user's search query
        candidates: List of candidate documents from initial retrieval
        top_k: Number of top results to return after reranking
        min_score: Minimum reranker score to include a result
        reranker_weight: Weight of reranker score in final combined score
    
    Returns:
        Reranked list of candidates with updated scores
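
    Example (sketch; only 'content'/'text' and an optional 'score' are read
    from each candidate dict):
        hits = [{'content': 'Cross-encoders score query-doc pairs jointly.',
                 'score': 0.6}]
        top = rerank_candidates("how are documents reranked?", hits, top_k=1)
        # each surviving result gains 'reranker_score', 'reranker_normalized',
        # 'original_score', and an updated combined 'score'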
    """
    if not candidates:
        return []
    
    if not RERANKER_ENABLED:
        logger.debug("Reranker disabled, returning original candidates")
        return candidates[:top_k]
    
    model = get_reranker_model()
    if model is None:
        logger.debug("Reranker model not available, returning original candidates")
        return candidates[:top_k]
    
    try:
        start_time = time.time()
        
        # Prepare query-document pairs for cross-encoder
        pairs = []
        for candidate in candidates:
            # Use content for reranking, fall back to text
            doc_text = candidate.get('content') or candidate.get('text', '')
            # Truncate long documents for efficiency (the model only sees
            # max_length tokens anyway)
            doc_text = doc_text[:1500]
            pairs.append([query, doc_text])
        
        # Get cross-encoder scores
        reranker_scores = model.predict(pairs, show_progress_bar=False)
        
        # Update candidates with reranker scores
        reranked = []
        for i, candidate in enumerate(candidates):
            reranker_score = float(reranker_scores[i])
            
            # Skip if below minimum threshold
            if reranker_score < min_score:
                continue
            
            # Normalize the raw cross-encoder score to (0, 1) with a sigmoid
            # (cross-encoder outputs are unbounded logits)
            normalized_reranker = 1 / (1 + math.exp(-reranker_score))
            
            # Combine with original score
            original_score = candidate.get('score', 0.5)
            combined_score = (
                (1 - reranker_weight) * original_score + 
                reranker_weight * normalized_reranker
            )
            
            reranked_candidate = candidate.copy()
            reranked_candidate['reranker_score'] = reranker_score
            reranked_candidate['reranker_normalized'] = normalized_reranker
            reranked_candidate['original_score'] = original_score
            reranked_candidate['score'] = combined_score
            reranked.append(reranked_candidate)
        
        # Sort by combined score
        reranked.sort(key=lambda x: x['score'], reverse=True)
        
        elapsed = time.time() - start_time
        logger.info(f"Reranked {len(candidates)} candidates to {min(len(reranked), top_k)} in {elapsed:.3f}s")
        
        return reranked[:top_k]
        
    except Exception as e:
        logger.error(f"Reranking failed: {e}")
        # Fallback to original ordering
        return candidates[:top_k]


def rerank_with_diversity(
    query: str,
    candidates: List[Dict[str, Any]],
    top_k: int = RERANKER_TOP_K,
    max_per_source: int = 2,
    diversity_weight: float = 0.1
) -> List[Dict[str, Any]]:
    """
    Rerank candidates with diversity constraints to avoid over-representation
    from single sources.
    
    Loosely inspired by Maximal Marginal Relevance (MMR): over-represented
    sources are down-weighted rather than dropped, so a strongly relevant
    duplicate can still surface in the final sort.
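
    Example (illustrative; source names are made up): with max_per_source=1,
    the second chunk drawn from "guide.md" keeps its rank only if its
    penalized score, score * (1 - diversity_weight), still beats the best
    chunk from another source.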
    """
    if not candidates:
        return []
    
    # First, get reranker scores
    reranked = rerank_candidates(query, candidates, top_k=len(candidates))
    
    if not reranked:
        return candidates[:top_k]
    
    # Apply diversity selection: penalize candidates whose source is already
    # well represented, then re-sort so penalized items can still surface
    selected = []
    source_counts = {}
    unpenalized_count = 0
    
    for candidate in reranked:
        source = candidate.get('metadata', {}).get('source', 'unknown')
        current_count = source_counts.get(source, 0)
        
        if current_count >= max_per_source:
            # Source already well represented: apply a diversity penalty.
            # Copy first so candidates owned by the caller are not mutated.
            candidate = candidate.copy()
            candidate['score'] *= (1 - diversity_weight)
        else:
            unpenalized_count += 1
        
        source_counts[source] = current_count + 1
        selected.append(candidate)
        
        # Stop once enough un-penalized candidates have been collected
        if unpenalized_count >= top_k:
            break
    
    # Re-sort after diversity adjustments
    selected.sort(key=lambda x: x['score'], reverse=True)
    
    return selected[:top_k]


def get_reranker_status() -> Dict[str, Any]:
    """Get status information about the reranker."""
    model = get_reranker_model() if RERANKER_ENABLED else None
    
    return {
        'enabled': RERANKER_ENABLED,
        'model': RERANKER_MODEL if model else None,
        'model_loaded': model is not None,
        'top_k': RERANKER_TOP_K,
        'candidates': RERANKER_CANDIDATES,
        'min_score': RERANKER_MIN_SCORE,
        'weight': RERANKER_WEIGHT
    }


# Lightweight reranking alternatives (no model required)

def keyword_boost_rerank(
    query: str,
    candidates: List[Dict[str, Any]],
    top_k: int = 5
) -> List[Dict[str, Any]]:
    """
    Simple keyword-based reranking boost.
    Useful as a fallback when the cross-encoder is not available.
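
    Example (worked numbers, purely illustrative): if 3 of 5 query words
    (4+ characters each) appear in a chunk, match_ratio = 0.6 and the score
    is boosted by 0.2 * 0.6 = 0.12, capped at 1.0 overall.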
    """
    query_lower = query.lower()
    query_words = set(re.findall(r'\b\w{4,}\b', query_lower))  # Words 4+ chars
    
    reranked = []
    for candidate in candidates:
        content = (candidate.get('content') or candidate.get('text', '')).lower()
        
        # Count matching query words
        matches = sum(1 for word in query_words if word in content)
        match_ratio = matches / len(query_words) if query_words else 0
        
        # Boost score based on keyword matches
        original_score = candidate.get('score', 0.5)
        boosted_score = original_score + (0.2 * match_ratio)
        
        reranked_candidate = candidate.copy()
        reranked_candidate['keyword_match_ratio'] = match_ratio
        reranked_candidate['original_score'] = original_score
        reranked_candidate['score'] = min(boosted_score, 1.0)
        reranked.append(reranked_candidate)
    
    reranked.sort(key=lambda x: x['score'], reverse=True)
    return reranked[:top_k]


def length_penalty_rerank(
    candidates: List[Dict[str, Any]],
    ideal_length: int = 800,
    penalty_factor: float = 0.1,
    top_k: int = 5
) -> List[Dict[str, Any]]:
    """
    Apply length penalty to favor chunks of optimal size.
    Too short = less context, too long = diluted relevance.
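
    Example (worked numbers, purely illustrative): a 400-character chunk with
    ideal_length=800 gives length_diff = 0.5, so penalty = 0.1 * 0.5 = 0.05
    and the original score is multiplied by 0.95.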
    """
    reranked = []
    for candidate in candidates:
        content = candidate.get('content') or candidate.get('text', '')
        length = len(content)
        
        # Calculate length penalty (0 = ideal, increases as length deviates)
        length_diff = abs(length - ideal_length) / ideal_length
        penalty = penalty_factor * min(length_diff, 1.0)
        
        original_score = candidate.get('score', 0.5)
        penalized_score = original_score * (1 - penalty)
        
        reranked_candidate = candidate.copy()
        reranked_candidate['length'] = length
        reranked_candidate['length_penalty'] = penalty
        reranked_candidate['original_score'] = original_score
        reranked_candidate['score'] = penalized_score
        reranked.append(reranked_candidate)
    
    reranked.sort(key=lambda x: x['score'], reverse=True)
    return reranked[:top_k]


# Combined reranking pipeline

def full_rerank_pipeline(
    query: str,
    candidates: List[Dict[str, Any]],
    top_k: int = RERANKER_TOP_K,
    use_cross_encoder: bool = True,
    use_diversity: bool = True,
    max_per_source: int = 2
) -> List[Dict[str, Any]]:
    """
    Full reranking pipeline combining multiple techniques.
    
    Pipeline:
    1. Cross-encoder reranking (if available and enabled)
    2. Diversity selection (if enabled)
    3. Final top-k selection
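
    A runnable end-to-end sketch with fabricated candidates lives in the
    __main__ block at the bottom of this module.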
    """
    if not candidates:
        return []
    
    result = candidates
    
    # Step 1: Cross-encoder reranking (score only the top RERANKER_CANDIDATES
    # to keep cross-encoder latency bounded)
    if use_cross_encoder and RERANKER_ENABLED:
        model = get_reranker_model()
        if model:
            pool = result[:RERANKER_CANDIDATES]
            result = rerank_candidates(query, pool, top_k=len(pool))
        else:
            # Fallback to keyword boost when the model is unavailable
            result = keyword_boost_rerank(query, result, top_k=len(result))
    
    # Step 2: Diversity selection
    if use_diversity:
        result = rerank_with_diversity(
            query, result, 
            top_k=top_k, 
            max_per_source=max_per_source
        )
    else:
        result = result[:top_k]
    
    return result
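

# Minimal smoke-test sketch. The candidate dicts below are fabricated and only
# mirror the shape this module expects ('content', 'score', metadata 'source').
# If sentence_transformers is not installed, the pipeline falls back to the
# keyword-boost reranker, so the demo still runs without the model download.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    _demo_candidates = [
        {'content': 'Cross-encoders score query-document pairs jointly.',
         'score': 0.55, 'metadata': {'source': 'notes.md'}},
        {'content': 'Bi-encoders embed queries and documents separately.',
         'score': 0.52, 'metadata': {'source': 'notes.md'}},
        {'content': 'Completely unrelated text about cooking pasta.',
         'score': 0.40, 'metadata': {'source': 'misc.md'}},
    ]
    for item in full_rerank_pipeline(
        "how do cross-encoders rerank retrieved documents?",
        _demo_candidates,
        top_k=2,
        max_per_source=1,
    ):
        print(f"{item['score']:.3f}  {item['content'][:60]}")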
