"""
Evaluation Framework for RAG Chunking System
Phase 3.2: Implement evaluation framework (coherence, precision, size distribution)
"""

import json
from pathlib import Path
from typing import List, Dict, Optional
import re


class ChunkEvaluator:
    """Evaluate chunk quality and system performance"""
    
    def __init__(self):
        self.metrics = {
            'coherence_scores': [],
            'size_distribution': [],
            'boilerplate_rate': 0,
            'quality_scores': [],
        }
    
    def evaluate_chunk_coherence(self, chunk: str) -> float:
        """
        Evaluate semantic coherence of a chunk.
        Returns score 0-1 (higher = more coherent)
        """
        if not chunk:
            return 0.0
        
        # Check sentence boundary integrity
        sentences = re.split(r'[.!?]+', chunk)
        sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
        
        if len(sentences) < 2:
            return 0.5  # Single sentence chunks are somewhat coherent
        
        # Check for topic consistency (simple heuristic: word overlap between sentences)
        words_per_sentence = [set(re.findall(r'\b\w+\b', s.lower())) for s in sentences]
        
        # Calculate average word overlap between adjacent sentences
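        # e.g. (hypothetical sets) {'the', 'cat', 'sat', 'mat'} vs {'the', 'cat', 'ran'}:
        # shared words = 2, larger set = 4, so overlap = 2 / 4 = 0.5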
        overlaps = []
        for i in range(len(words_per_sentence) - 1):
            set1 = words_per_sentence[i]
            set2 = words_per_sentence[i + 1]
            if len(set1) > 0 and len(set2) > 0:
                overlap = len(set1 & set2) / max(len(set1), len(set2))
                overlaps.append(overlap)
        
        if overlaps:
            avg_overlap = sum(overlaps) / len(overlaps)
            # Map average overlap onto a 0-1 score: 0.1 overlap -> 0.0, 0.25 -> 0.5, 0.4+ -> 1.0
            coherence = min(1.0, max(0.0, (avg_overlap - 0.1) / 0.3))
        else:
            coherence = 0.5
        
        return coherence
    
    def evaluate_size_distribution(self, chunks: List[Dict]) -> Dict:
        """
        Evaluate chunk size distribution.
        Returns statistics about chunk sizes.
        """
        if not chunks:
            return {}
        
        sizes = [len(chunk.get('content', '')) for chunk in chunks]
        
        if not sizes:
            return {}
        
        total = len(sizes)
        avg_size = sum(sizes) / total
        min_size = min(sizes)
        max_size = max(sizes)
        
        # Calculate distribution percentages
        target_range_min = 600
        target_range_max = 1400
        
        in_target_range = sum(1 for s in sizes if target_range_min <= s <= target_range_max)
        below_min = sum(1 for s in sizes if s < 400)
        above_max = sum(1 for s in sizes if s > 1600)
        
        return {
            'total_chunks': total,
            'avg_size': round(avg_size, 1),
            'min_size': min_size,
            'max_size': max_size,
            'std_dev': round((sum((s - avg_size)**2 for s in sizes) / total)**0.5, 1),
            'in_target_range_600_1400': in_target_range,
            'in_target_range_pct': round(in_target_range / total * 100, 1),
            'below_400': below_min,
            'below_400_pct': round(below_min / total * 100, 1),
            'above_1600': above_max,
            'above_1600_pct': round(above_max / total * 100, 1),
        }
    
    def evaluate_boilerplate_rate(self, chunks: List[Dict]) -> float:
        """
        Calculate percentage of chunks containing boilerplate content.
        """
        if not chunks:
            return 0.0
        
        boilerplate_patterns = [
            r'copyright\s+©?\s*\d{4}',
            r'all rights reserved',
            r'page\s+\d+\s+of\s+\d+',
        ]
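        # Illustrative strings these patterns are meant to catch:
        #   'Copyright © 2023 Example Corp', 'All rights reserved', 'Page 3 of 12'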
        
        boilerplate_count = 0
        for chunk in chunks:
            content = chunk.get('content', '')  # casing handled by re.IGNORECASE below
            for pattern in boilerplate_patterns:
                if re.search(pattern, content, re.IGNORECASE):
                    boilerplate_count += 1
                    break
        
        return round(boilerplate_count / len(chunks) * 100, 2)
    
    def evaluate_chunks(self, chunks: List[Dict]) -> Dict:
        """
        Comprehensive evaluation of chunks.
        Returns evaluation report.
        """
        if not chunks:
            return {'error': 'No chunks provided'}
        
        # Evaluate coherence
        coherence_scores = []
        for chunk in chunks:
            coherence = self.evaluate_chunk_coherence(chunk.get('content', ''))
            coherence_scores.append(coherence)
        
        avg_coherence = sum(coherence_scores) / len(coherence_scores) if coherence_scores else 0
        
        # Evaluate size distribution
        size_stats = self.evaluate_size_distribution(chunks)
        
        # Evaluate boilerplate rate
        boilerplate_rate = self.evaluate_boilerplate_rate(chunks)
        
        # Overall quality score (weighted combination)
        quality_score = (
            avg_coherence * 0.4 +  # Coherence weight
            (1.0 - boilerplate_rate / 100) * 0.3 +  # Low boilerplate weight
            min(1.0, size_stats.get('in_target_range_pct', 0) / 80) * 0.3  # Size distribution weight
        )
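        # Worked example (hypothetical inputs): coherence 0.75, boilerplate 1%, 85% in target range
        # -> 0.75*0.4 + 0.99*0.3 + min(1.0, 85/80)*0.3 = 0.300 + 0.297 + 0.300 = 0.897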
        
        return {
            'total_chunks': len(chunks),
            'coherence': {
                'avg_score': round(avg_coherence, 3),
                'min_score': round(min(coherence_scores), 3) if coherence_scores else 0,
                'max_score': round(max(coherence_scores), 3) if coherence_scores else 0,
            },
            'size_distribution': size_stats,
            'boilerplate_rate_pct': boilerplate_rate,
            'overall_quality_score': round(quality_score, 3),
            'recommendations': self._generate_recommendations(avg_coherence, boilerplate_rate, size_stats),
        }
    
    def _generate_recommendations(self, coherence: float, boilerplate_rate: float, size_stats: Dict) -> List[str]:
        """Generate recommendations based on evaluation results"""
        recommendations = []
        
        if coherence < 0.7:
            recommendations.append("Consider adjusting chunk boundaries to improve coherence")
        
        if boilerplate_rate > 2:
            recommendations.append(f"Boilerplate rate ({boilerplate_rate}%) is above target (2%). Improve filtering.")
        
        target_pct = size_stats.get('in_target_range_pct', 0)
        if target_pct < 80:
            recommendations.append(f"Only {target_pct}% of chunks in target range (600-1400 chars). Adjust chunk sizes.")
        
        below_min = size_stats.get('below_400', 0)
        total_chunks = size_stats.get('total_chunks', 0)
        if below_min > total_chunks * 0.1:
            recommendations.append(f"{below_min} chunks below minimum size. Consider merging small chunks.")
        
        if not recommendations:
            recommendations.append("Chunk quality metrics are within acceptable ranges")
        
        return recommendations
    
    def save_evaluation_report(self, evaluation: Dict, output_path: Path):
        """Save evaluation report to JSON file"""
        with open(output_path, 'w') as f:
            json.dump(evaluation, f, indent=2)
    
    def load_evaluation_report(self, input_path: Path) -> Dict:
        """Load evaluation report from JSON file"""
        with open(input_path, 'r') as f:
            return json.load(f)
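

# Minimal usage sketch (illustrative only): evaluate a couple of in-memory chunk
# dicts without touching the filesystem. The helper name and sample contents
# below are hypothetical; real chunks come from the *_chunks.json files.
def example_evaluation() -> Dict:
    """Run the evaluator on two hand-written sample chunks and return the report."""
    evaluator = ChunkEvaluator()
    sample_chunks = [
        {'content': 'Retrieval chunks should keep related sentences together. '
                    'Related sentences share vocabulary, which raises the coherence score.'},
        {'content': 'Copyright © 2024 Example Corp. All rights reserved. Page 1 of 2.'},
    ]
    return evaluator.evaluate_chunks(sample_chunks)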


def evaluate_processed_documents(documents_dir: Path, output_file: Optional[Path] = None) -> Dict:
    """
    Evaluate all processed documents in the documents directory.
    """
    evaluator = ChunkEvaluator()
    all_evaluations = {}
    
    # Find all chunk JSON files
    chunk_files = list(documents_dir.glob("*_chunks.json"))
    
    print(f"Evaluating {len(chunk_files)} documents...")
    
    for chunk_file in chunk_files:
        try:
            with open(chunk_file, 'r', encoding='utf-8') as f:
                chunks = json.load(f)
            
            # Extract chunks (handle both old and new formats; skip empty files)
            if isinstance(chunks, list) and chunks:
                chunk_list = chunks if isinstance(chunks[0], dict) else [{'content': c} for c in chunks]
            else:
                chunk_list = []
            
            evaluation = evaluator.evaluate_chunks(chunk_list)
            all_evaluations[chunk_file.stem] = evaluation
            
            print(f"  {chunk_file.stem}: Quality={evaluation['overall_quality_score']:.3f}, "
                  f"Coherence={evaluation['coherence']['avg_score']:.3f}, "
                  f"Boilerplate={evaluation['boilerplate_rate_pct']:.1f}%")
        
        except Exception as e:
            print(f"  Error evaluating {chunk_file.name}: {e}")
    
    # Calculate aggregate statistics
    if all_evaluations:
        avg_quality = sum(e['overall_quality_score'] for e in all_evaluations.values()) / len(all_evaluations)
        avg_coherence = sum(e['coherence']['avg_score'] for e in all_evaluations.values()) / len(all_evaluations)
        avg_boilerplate = sum(e['boilerplate_rate_pct'] for e in all_evaluations.values()) / len(all_evaluations)
        
        summary = {
            'total_documents': len(all_evaluations),
            'aggregate_metrics': {
                'avg_quality_score': round(avg_quality, 3),
                'avg_coherence': round(avg_coherence, 3),
                'avg_boilerplate_rate_pct': round(avg_boilerplate, 2),
            },
            'document_evaluations': all_evaluations,
        }
        
        if output_file:
            evaluator.save_evaluation_report(summary, output_file)
            print(f"\nEvaluation report saved to {output_file}")
        
        return summary
    
    return {}


if __name__ == "__main__":
    from config import DOCUMENTS_DIR
    
    output_path = Path("processed/evaluation_report.json")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    evaluate_processed_documents(DOCUMENTS_DIR, output_path)

