#!/usr/bin/env python3
"""
Advanced Legal Document RAG Service
Incorporates sophisticated PDF processing, FAISS vector search, BM25 hybrid scoring,
and DeepSeek/Tesseract OCR fallback.

Port: 5007 (configurable via RAG_PORT)
"""

import os
import sys
import json
import time
import hashlib
import uuid
import re
import math
import threading
import queue
from datetime import datetime, timezone
from pathlib import Path
from functools import wraps
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple

from flask import Flask, request, jsonify, send_from_directory, render_template
from werkzeug.utils import secure_filename

# ============================================================================
# Configuration
# ============================================================================
@dataclass
class RAGConfig:
    """RAG service configuration.

    Path fields left as ``None`` are derived in ``__post_init__``:
    ``data_dir`` comes from the ``DATA_DIR`` environment variable (falling
    back to ``base_dir/'data'``), with ``cache_dir`` and ``pdf_dir`` nested
    beneath it. The cache and upload directories are created eagerly when
    the config is constructed.
    """
    # Paths - use the ai-lawyer-rag directory, NOT the parent eventheodds directory
    base_dir: Path = Path(__file__).parent
    data_dir: Optional[Path] = None   # resolved in __post_init__ ($DATA_DIR or base_dir/'data')
    cache_dir: Optional[Path] = None  # resolved in __post_init__ (data_dir/'rag_cache')
    pdf_dir: Optional[Path] = None    # resolved in __post_init__ (data_dir/'csv')

    # Chunking parameters (character targets; overridden per document type
    # by DOCUMENT_TYPE_CONFIGS)
    chunk_size: int = 1000
    chunk_min_size: int = 400
    chunk_max_size: int = 1600
    sentence_overlap: int = 2   # sentences shared between adjacent chunks
    min_sentences: int = 3

    # Search parameters for hybrid (vector + keyword + BM25) scoring
    hybrid_vector_weight: float = 0.50
    hybrid_keyword_weight: float = 0.50
    hybrid_bm25_weight: float = 0.25
    max_chunks_per_source: int = 2  # diversity cap per source document
    default_k: int = 10

    # OCR fallback URLs (DeepSeek preferred, Tesseract as backup)
    deepseek_ocr_url: str = "http://127.0.0.1:5003/ocr"
    tesseract_ocr_url: str = "http://127.0.0.1:5002/ocr"

    def __post_init__(self):
        """Resolve unset path fields and ensure working directories exist."""
        if self.data_dir is None:
            self.data_dir = Path(os.environ.get('DATA_DIR', self.base_dir / 'data'))
        if self.cache_dir is None:
            self.cache_dir = self.data_dir / 'rag_cache'
        if self.pdf_dir is None:
            self.pdf_dir = self.data_dir / 'csv'  # Default upload dir

        # Ensure directories exist
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.pdf_dir.mkdir(parents=True, exist_ok=True)

# Document type configurations
# Per-type chunking overrides applied on top of the RAGConfig defaults;
# the 'default' entry mirrors the RAGConfig dataclass values exactly.
DOCUMENT_TYPE_CONFIGS = {
    'technical': {'chunk_size': 800, 'chunk_min_size': 400, 'chunk_max_size': 1200, 'sentence_overlap': 1, 'min_sentences': 2},
    'research': {'chunk_size': 1000, 'chunk_min_size': 500, 'chunk_max_size': 1500, 'sentence_overlap': 2, 'min_sentences': 3},
    'legal': {'chunk_size': 900, 'chunk_min_size': 500, 'chunk_max_size': 1400, 'sentence_overlap': 2, 'min_sentences': 3},
    'sports': {'chunk_size': 1000, 'chunk_min_size': 400, 'chunk_max_size': 1500, 'sentence_overlap': 2, 'min_sentences': 3},
    'default': {'chunk_size': 1000, 'chunk_min_size': 400, 'chunk_max_size': 1600, 'sentence_overlap': 2, 'min_sentences': 3}
}

# ============================================================================
# Tiered Legal Keywords for RAG Optimization
# Higher tiers get stronger boost in hybrid search
# ============================================================================

# Tier 1: Core Legal Reasoning & Concepts (highest boost: 1.0)
LEGAL_TIER_1 = [
    'legal doctrine', 'holding', 'dicta', 'precedent', 'stare decisis', 'analogy',
    'distinguishing', 'syllogism', 'rule application', 'fact pattern', 'element',
    'prong', 'test', 'daubert test', 'balancing test', 'multi-factor test',
    'bright-line rule', 'standard of review', 'de novo', 'abuse of discretion',
    'clear error', 'burden of production', 'burden of persuasion', 'presumption',
    'inference', 'prima facie case', 'affirmative defense', 'genuine issue of material fact',
    'moving party', 'non-moving party', 'celotex trilogy', 'reasonable doubt',
    'preponderance of evidence', 'clear and convincing', 'strict scrutiny',
    'intermediate scrutiny', 'rational basis', 'mens rea', 'actus reus',
]

# Tier 2: Legal Actions & Processes (boost: 0.85)
LEGAL_TIER_2 = [
    'cause of action', 'claim', 'motion to dismiss', '12(b)(6)', 'rule 12',
    'motion for summary judgment', 'discovery', 'interrogatory', 'deposition',
    'request for production', 'request for admission', 'in limine', 'voir dire',
    'direct examination', 'cross-examination', 'closing argument', 'opening statement',
    'jury instruction', 'appeal', 'writ of certiorari', 'remand', 'settlement',
    'negotiation', 'mediation', 'arbitration', 'injunction', 'temporary restraining order',
    'preliminary injunction', 'permanent injunction', 'specific performance',
    'compensatory damages', 'punitive damages', 'liquidated damages', 'nominal damages',
    'treble damages', 'class action', 'joinder', 'intervention', 'interpleader',
    'declaratory judgment', 'mandamus', 'prohibition', 'quo warranto',
]

# Tier 3: Legal Document & Discourse Types (boost: 0.75)
LEGAL_TIER_3 = [
    'complaint', 'answer', 'counterclaim', 'cross-claim', 'third-party complaint',
    'appellate brief', 'trial brief', 'memorandum of law', 'memorandum in support',
    'memorandum in opposition', 'reply brief', 'amicus brief', 'majority opinion',
    'concurring opinion', 'dissenting opinion', 'plurality opinion', 'per curiam',
    'bench memorandum', 'demand letter', 'cease and desist', 'engagement letter',
    'retainer agreement', 'unilateral contract', 'bilateral contract', 'indemnity clause',
    'hold harmless', 'force majeure', 'statute', 'regulation', 'ordinance',
    'legislative history', 'committee report', 'floor debate', 'regulatory comment',
    'notice and comment', 'proposed rule', 'final rule', 'consent decree',
]

# Tier 4: Key Entities & Roles (boost: 0.65)
LEGAL_TIER_4 = [
    'plaintiff', 'defendant', 'petitioner', 'respondent', 'appellant', 'appellee',
    'movant', 'non-movant', 'counsel of record', 'amicus curiae', 'fact witness',
    'expert witness', 'trier of fact', 'jury', 'bench trial', 'trial court',
    'appellate court', 'en banc', 'supreme court', 'circuit court', 'district court',
    'magistrate judge', 'administrative law judge', 'arbitrator', 'mediator',
    'special master', 'receiver', 'trustee', 'guardian ad litem', 'curator',
    'sec', 'epa', 'ftc', 'doj', 'nlrb', 'eeoc', 'cfpb', 'osha',
    'prosecutor', 'public defender', 'solicitor general', 'attorney general',
]

# Tier 5: Foundational Philosophical & Ethical Concepts (boost: 0.70)
LEGAL_TIER_5 = [
    'justice', 'procedural justice', 'distributive justice', 'equity', 'fairness',
    'due process', 'substantive due process', 'procedural due process', 'equal protection',
    'liberty', 'autonomy', 'privacy', 'efficiency', 'kaldor-hicks', 'pareto efficiency',
    'deterrence', 'retribution', 'rehabilitation', 'incapacitation', 'restitution',
    'good faith', 'bad faith', 'fiduciary duty', 'duty of care', 'duty of loyalty',
    'confidentiality', 'attorney-client privilege', 'work product doctrine',
    'conflict of interest', 'zealous advocacy', 'candor to the tribunal',
    'jurisdiction', 'personal jurisdiction', 'subject matter jurisdiction',
    'sovereignty', 'legitimacy', 'rule of law', 'separation of powers',
    'federalism', 'preemption', 'dormant commerce clause',
]

# Tier 6: Meta-Cognitive & Strategic Tags (boost: 0.80)
LEGAL_TIER_6 = [
    'strategic consideration', 'risk assessment', 'client counseling', 'ethical dilemma',
    'cost-benefit analysis', 'probabilistic outcome', 'worst-case scenario', 'leverage',
    'bargaining position', 'settlement value', 'litigation risk', 'exposure',
    'precedent strength', 'circuit split', 'emerging trend', 'novel issue',
    'matter of first impression', 'policy argument', 'textualist argument',
    'originalist argument', 'purposive interpretation', 'legislative intent',
    'slippery slope', 'parade of horribles', 'floodgates argument', 'chilling effect',
    'overruling', 'distinguishing precedent', 'limiting holding', 'extending doctrine',
    'case strategy', 'litigation posture', 'forum selection', 'choice of law',
]

# Additional common legal terms (boost: 0.50)
LEGAL_GENERAL = [
    'contract', 'agreement', 'breach', 'damages', 'liability', 'negligence', 'tort',
    'jurisdiction', 'venue', 'statute', 'constitution', 'amendment', 'civil rights',
    'criminal', 'felony', 'misdemeanor', 'arraignment', 'indictment', 'verdict', 'sentence',
    'judge', 'witness', 'testimony', 'evidence', 'motion', 'brief', 'pleading',
    'subpoena', 'summons', 'writ', 'habeas corpus', 'attorney', 'counsel', 'lawyer',
    'privilege', 'custody', 'alimony', 'divorce', 'probate', 'will', 'trust', 'estate',
    'corporation', 'llc', 'partnership', 'shareholder', 'fiduciary', 'patent', 'trademark',
    'copyright', 'employment', 'discrimination', 'harassment', 'real estate', 'lease',
    'mortgage', 'deed', 'easement', 'bankruptcy', 'creditor', 'lien', 'collateral',
    'immigration', 'asylum', 'deportation', 'naturalization',
]

# Combined flat list for backward compatibility
# NOTE: some terms appear in more than one tier (e.g. 'jurisdiction' in
# both Tier 5 and LEGAL_GENERAL), so this list can contain duplicates;
# get_term_tier() resolves such terms to the first (highest-priority) tier.
LEGAL_TERMS = (
    LEGAL_TIER_1 + LEGAL_TIER_2 + LEGAL_TIER_3 +
    LEGAL_TIER_4 + LEGAL_TIER_5 + LEGAL_TIER_6 + LEGAL_GENERAL
)

# Tier weights for hybrid search boosting
# Keyed by the tier numbers returned from get_term_tier().
LEGAL_TIER_WEIGHTS = {
    1: 1.0,   # Core Legal Reasoning - highest priority
    2: 0.85,  # Legal Actions & Processes
    3: 0.75,  # Legal Documents
    4: 0.65,  # Entities & Roles
    5: 0.70,  # Philosophical & Ethical
    6: 0.80,  # Strategic & Meta-Cognitive
    0: 0.50,  # General terms
}

def get_term_tier(term: str) -> int:
    """Return the tier level for a legal term.

    Args:
        term: Candidate legal term (matched case-insensitively, exact string).

    Returns:
        1-6 for tiered legal terms, 0 for general legal terms, -1 when the
        term is not in any legal vocabulary list. A term present in several
        tiers resolves to the first match, preserving the original
        elif-chain priority (tier 1 highest, then 2..6, then general).
    """
    term_lower = term.lower()
    # Priority-ordered scan; generators avoid materializing a lowercased
    # copy of every tier list on each call (the original rebuilt seven
    # throwaway lists per lookup).
    tier_lists = (
        (1, LEGAL_TIER_1), (2, LEGAL_TIER_2), (3, LEGAL_TIER_3),
        (4, LEGAL_TIER_4), (5, LEGAL_TIER_5), (6, LEGAL_TIER_6),
        (0, LEGAL_GENERAL),
    )
    for tier, terms in tier_lists:
        if any(t.lower() == term_lower for t in terms):
            return tier
    return -1  # Not a legal term

def get_term_weight(term: str) -> float:
    """Return the hybrid-search boost weight for *term*.

    Legal terms map through LEGAL_TIER_WEIGHTS by tier; non-legal terms
    (tier -1) and any unlisted tier fall back to 0.3.
    """
    tier = get_term_tier(term)
    # Tier -1 means "not a legal term"; everything else consults the table.
    return LEGAL_TIER_WEIGHTS.get(tier, 0.3) if tier >= 0 else 0.3

# ============================================================================
# Strategic Additions (Quick Wins)
# ============================================================================
# Strategic/evidentiary "quick win" terms folded into Tier 6 (boost 0.80).
STRATEGIC_ADDITIONS = [
    'issue spotting', 'elements analysis', 'counterargument anticipation',
    'rebuttal strategy', 'evidentiary foundation', 'preservation of error',
    'deferential review', 'harmless error', 'plain error', 'structural error',
    'waiver', 'forfeiture', 'estoppel', 'laches', 'clean hands', 'unclean hands',
    'parol evidence rule', 'statute of limitations', 'statute of frauds',
    'best evidence rule', 'hearsay', 'hearsay exception', 'business records exception',
]
# Extend LEGAL_TIER_6 with strategic additions
LEGAL_TIER_6.extend(STRATEGIC_ADDITIONS)
# Bug fix: LEGAL_TERMS (the "combined flat list" built above) was created
# before this extension, so the strategic additions were silently absent
# from it. Append them so the flat list remains a true union of all tiers.
LEGAL_TERMS.extend(STRATEGIC_ADDITIONS)

# ============================================================================
# Context-Aware Term Expansion (Concept Clusters)
# ============================================================================
# Concept clusters used by expand_query_with_concepts(): when a query
# mentions any term in a cluster (or the cluster's name with underscores
# replaced by spaces), the whole cluster's terms are added to the query.
LEGAL_CONCEPT_CLUSTERS = {
    'summary_judgment': [
        'summary judgment', 'motion for summary judgment', 'genuine issue of material fact',
        'moving party', 'non-moving party', 'Rule 56', 'Celotex trilogy',
        'no genuine dispute', 'material fact', 'summary adjudication',
        'Celotex', 'Anderson v. Liberty Lobby', 'Matsushita', 'burden of production'
    ],
    'attorney_client': [
        'attorney-client privilege', 'privilege', 'confidentiality', 'work product doctrine',
        'duty of confidentiality', 'ethics rule 1.6', 'client confidence',
        'privileged communication', 'common interest doctrine', 'joint defense',
        'crime-fraud exception', 'waiver of privilege', 'attorney client'
    ],
    'discovery': [
        'discovery', 'interrogatory', 'deposition', 'request for production',
        'request for admission', 'subpoena', 'mandatory disclosure',
        'protective order', 'privilege log', 'document production',
        'Rule 26', 'spoliation', 'ESI', 'electronically stored information'
    ],
    'jurisdiction': [
        'personal jurisdiction', 'subject matter jurisdiction', 'venue',
        'forum non conveniens', 'minimum contacts', 'specific jurisdiction',
        'general jurisdiction', 'long-arm statute', 'diversity jurisdiction',
        'federal question', 'supplemental jurisdiction', 'removal'
    ],
    'due_process': [
        'due process', 'procedural due process', 'substantive due process',
        'notice', 'opportunity to be heard', 'fundamental rights',
        'Mathews v. Eldridge', 'rational basis', 'strict scrutiny'
    ],
    'contract_formation': [
        'offer', 'acceptance', 'consideration', 'mutual assent',
        'meeting of the minds', 'contract formation', 'bargain',
        'promissory estoppel', 'detrimental reliance', 'statute of frauds'
    ],
    'negligence': [
        'negligence', 'duty of care', 'breach of duty', 'causation',
        'proximate cause', 'but-for causation', 'damages', 'foreseeability',
        'reasonable person', 'standard of care', 'res ipsa loquitur'
    ],
    'constitutional': [
        'first amendment', 'free speech', 'establishment clause', 'free exercise',
        'fourth amendment', 'search and seizure', 'probable cause', 'warrant',
        'fifth amendment', 'self-incrimination', 'double jeopardy', 'takings',
        'sixth amendment', 'right to counsel', 'confrontation clause',
        'fourteenth amendment', 'equal protection', 'due process', 'incorporation'
    ],
    'evidence': [
        'hearsay', 'relevance', 'probative value', 'prejudicial effect',
        'Rule 403', 'character evidence', 'prior bad acts', 'Rule 404',
        'expert testimony', 'Daubert', 'foundation', 'authentication'
    ],
    'appellate': [
        'appeal', 'standard of review', 'de novo', 'abuse of discretion',
        'clearly erroneous', 'harmless error', 'plain error', 'preservation',
        'brief', 'oral argument', 'remand', 'affirm', 'reverse', 'vacate'
    ]
}

# ============================================================================
# Procedural Stage Awareness
# ============================================================================
# Vocabulary per litigation phase; detect_litigation_phase() scores a text
# by counting how many of each phase's terms it contains (substring match).
LITIGATION_PHASES = {
    'pre_filing': [
        'demand letter', 'cease and desist', 'pre-litigation', 'settlement demand',
        'tolling agreement', 'notice of claim', 'administrative exhaustion'
    ],
    'pleadings': [
        'complaint', 'answer', 'motion to dismiss', 'amended complaint',
        '12(b)(6)', 'Rule 8', 'affirmative defense', 'counterclaim', 'cross-claim',
        'third-party complaint', 'default judgment', 'responsive pleading'
    ],
    'discovery': [
        'interrogatory', 'deposition', 'production request', 'expert disclosure',
        'Rule 26', 'initial disclosure', 'discovery dispute', 'motion to compel',
        'protective order', 'privilege log', 'subpoena duces tecum'
    ],
    'pre_trial': [
        'motion in limine', 'summary judgment', 'pretrial conference', 'trial brief',
        'Rule 56', 'pretrial order', 'witness list', 'exhibit list', 'Daubert motion'
    ],
    'trial': [
        'voir dire', 'opening statement', 'direct examination', 'jury instruction',
        'closing argument', 'cross-examination', 'redirect', 'motion for directed verdict',
        'judgment as a matter of law', 'jury deliberation', 'verdict'
    ],
    'post_trial': [
        'post-trial motion', 'motion for new trial', 'remittitur', 'additur',
        'judgment notwithstanding verdict', 'JNOV', 'Rule 59', 'Rule 60'
    ],
    'appeal': [
        'notice of appeal', 'appellate brief', 'oral argument', 'writ of certiorari',
        'petition for review', 'interlocutory appeal', 'mandamus', 'en banc'
    ],
    'enforcement': [
        'judgment enforcement', 'execution', 'garnishment', 'judgment lien',
        'contempt', 'collection', 'satisfaction of judgment', 'judgment debtor'
    ]
}

def detect_litigation_phase(text: str) -> str:
    """Return the litigation phase whose vocabulary best matches *text*.

    Each phase is scored by how many of its LITIGATION_PHASES terms occur
    in the lowercased text; ties go to the earlier phase in dict order.
    Returns 'general' when no phase term matches at all.
    """
    lowered = text.lower()
    best_phase, best_score = 'general', 0
    for phase, terms in LITIGATION_PHASES.items():
        hits = sum(1 for term in terms if term in lowered)
        # Strict '>' keeps the first phase on ties, matching max() over an
        # insertion-ordered dict.
        if hits > best_score:
            best_phase, best_score = phase, hits
    return best_phase

# ============================================================================
# Jurisdictional Intelligence
# ============================================================================
# Relative authority multipliers per court level (higher = more weight).
JURISDICTION_HIERARCHY = {
    'us_supreme_court': 1.5,
    'federal_circuit': 1.3,
    'federal_district': 1.1,
    'state_supreme': 1.2,
    'state_appellate': 1.0,
    'state_trial': 0.8,
    'administrative': 0.7,
    'secondary': 0.6,  # Treatises, law reviews
}

# Patterns to detect court level from text
# Regexes are matched against lowercased text by detect_court_level();
# first matching level (in dict order) wins.
COURT_PATTERNS = {
    'us_supreme_court': [r'supreme court of the united states', r'\d+\s+u\.?s\.?\s+\d+', r'scotus'],
    'federal_circuit': [r'circuit court of appeals', r'\d+\s+f\.(?:2d|3d|4th)\s+\d+', r'court of appeals'],
    'federal_district': [r'district court', r'\d+\s+f\.\s*supp', r'u\.s\. district'],
    'state_supreme': [r'supreme court of \w+', r'court of last resort'],
    'state_appellate': [r'court of appeal', r'appellate division', r'intermediate appellate'],
}

def detect_court_level(text: str) -> str:
    """Classify the court level of *text* via COURT_PATTERNS regexes.

    Levels are tried in dict order (most authoritative first); the first
    level with any matching pattern wins. Returns 'unknown' if none match.
    """
    lowered = text.lower()
    for level, patterns in COURT_PATTERNS.items():
        if any(re.search(pattern, lowered) for pattern in patterns):
            return level
    return 'unknown'

# ============================================================================
# Case Shepardization Signals
# ============================================================================
# Shepardization-style treatment signals; analyze_citation_treatment()
# counts occurrences of each phrase (substring match on lowercased text).
CITATION_SIGNALS = {
    'positive': ['affirmed', 'followed', 'adopted', 'approved', 'cited with approval',
                 'reaffirmed', 'extended', 'relied upon', 'consistent with'],
    'negative': ['reversed', 'overruled', 'disapproved', 'criticized', 'limited',
                 'rejected', 'declined to follow', 'abrogated', 'superseded'],
    'neutral': ['cited', 'discussed', 'mentioned', 'referenced', 'see also', 'cf.'],
    'caution': ['distinguished', 'questioned', 'called into doubt', 'but see',
                'contra', 'modified', 'clarified', 'narrowed']
}

def analyze_citation_treatment(text: str) -> dict:
    """Count citation-treatment signals present in *text*.

    Returns a dict with keys 'positive'/'negative'/'neutral'/'caution';
    each distinct signal phrase found in the lowercased text adds one to
    its category (presence, not frequency, per phrase).
    """
    lowered = text.lower()
    return {
        signal_type: sum(1 for signal in signals if signal in lowered)
        for signal_type, signals in CITATION_SIGNALS.items()
    }

# ============================================================================
# Doctrine-Specific Enhancement
# ============================================================================
# Doctrine-specific term clusters; the FIRST THREE terms of each list act
# as the doctrine's signature in get_relevant_doctrines(), and the full
# list feeds query expansion in expand_query_with_concepts().
LEGAL_DOCTRINES = {
    'chevron': [
        'Chevron deference', 'administrative deference', 'agency interpretation',
        'step one', 'step two', 'ambiguous statute', 'reasonable interpretation',
        'Chevron U.S.A. v. NRDC', 'Auer deference', 'Skidmore deference'
    ],
    'erie': [
        'Erie doctrine', 'substantive vs procedural', 'choice of law',
        'Erie Railroad v. Tompkins', 'federal common law', 'outcome determinative',
        'twin aims of Erie', 'Hanna v. Plumer', 'vertical choice of law'
    ],
    'justiciability': [
        'ripeness', 'mootness', 'standing', 'case or controversy',
        'injury in fact', 'causation', 'redressability', 'Lujan v. Defenders',
        'capable of repetition', 'political question', 'advisory opinion'
    ],
    'qualified_immunity': [
        'qualified immunity', 'clearly established law', 'government immunity',
        'reasonable officer', 'constitutional violation', 'Harlow v. Fitzgerald',
        'sequencing', 'Pearson v. Callahan', 'sovereign immunity'
    ],
    'class_action': [
        'class action', 'Rule 23', 'class certification', 'commonality', 'typicality',
        'adequacy', 'numerosity', 'predominance', 'superiority', 'opt-out',
        'settlement class', 'cy pres', 'class notice'
    ],
    'res_judicata': [
        'res judicata', 'claim preclusion', 'issue preclusion', 'collateral estoppel',
        'final judgment', 'same cause of action', 'same parties', 'mutuality',
        'offensive collateral estoppel', 'defensive collateral estoppel'
    ],
    'fourth_amendment': [
        'search and seizure', 'probable cause', 'warrant requirement', 'exclusionary rule',
        'fruit of the poisonous tree', 'good faith exception', 'exigent circumstances',
        'plain view', 'consent search', 'Terry stop', 'reasonable suspicion'
    ],
    'first_amendment': [
        'free speech', 'content-based', 'content-neutral', 'strict scrutiny',
        'intermediate scrutiny', 'time place manner', 'prior restraint',
        'public forum', 'commercial speech', 'symbolic speech', 'compelled speech'
    ]
}

def get_relevant_doctrines(query: str) -> list:
    """Return doctrine keys whose signature terms appear in *query*.

    Only the first three terms of each LEGAL_DOCTRINES cluster are
    checked (they serve as the doctrine's most distinctive identifiers);
    matching is case-insensitive substring containment.
    """
    lowered = query.lower()
    return [
        doctrine
        for doctrine, terms in LEGAL_DOCTRINES.items()
        if any(term.lower() in lowered for term in terms[:3])
    ]

# ============================================================================
# Legal Citation Pattern Recognition
# ============================================================================
# (regex, category) pairs for Bluebook-style reporter citations; applied
# case-insensitively by extract_citations(). Order matters only for the
# order of results, not for precedence - every pattern is tried.
LEGAL_CITATION_PATTERNS = [
    (r'\d+\s+U\.?S\.?\s+\d+', 'us_supreme_court'),           # Supreme Court: 410 U.S. 113
    (r'\d+\s+S\.?\s*Ct\.?\s+\d+', 'us_supreme_court'),       # S. Ct. citation
    (r'\d+\s+L\.?\s*Ed\.?\s*(?:2d)?\s+\d+', 'us_supreme_court'),  # L. Ed. citation
    (r'\d+\s+F\.(?:2d|3d|4th)?\s+\d+', 'federal_circuit'),   # Federal Reporter
    (r'\d+\s+F\.\s*Supp\.?(?:\s*2d|3d)?\s+\d+', 'federal_district'),  # Federal Supplement
    (r'\d+\s+S\.?\s*E\.?(?:\s*2d)?\s+\d+', 'state_appellate'),  # South Eastern Reporter
    (r'\d+\s+S\.?\s*W\.?(?:\s*2d|3d)?\s+\d+', 'state_appellate'),  # South Western Reporter
    (r'\d+\s+N\.?\s*E\.?(?:\s*2d|3d)?\s+\d+', 'state_appellate'),  # North Eastern Reporter
    (r'\d+\s+N\.?\s*W\.?(?:\s*2d)?\s+\d+', 'state_appellate'),  # North Western Reporter
    (r'\d+\s+P\.?(?:\s*2d|3d)?\s+\d+', 'state_appellate'),   # Pacific Reporter
    (r'\d+\s+A\.?(?:\s*2d|3d)?\s+\d+', 'state_appellate'),   # Atlantic Reporter
    (r'\d+\s+So\.?(?:\s*2d|3d)?\s+\d+', 'state_appellate'),  # Southern Reporter
    (r'\d+\s+U\.?S\.?C\.?(?:\.?A\.?)?\s*§?\s*\d+', 'statute'),  # U.S. Code
    (r'\d+\s+C\.?F\.?R\.?\s*§?\s*\d+', 'regulation'),        # Code of Federal Regulations
    (r'Fed\.?\s*R\.?\s*Civ\.?\s*P\.?\s*\d+', 'rule'),        # Federal Rules of Civil Procedure
    (r'Fed\.?\s*R\.?\s*Evid\.?\s*\d+', 'rule'),              # Federal Rules of Evidence
]

def extract_citations(text: str) -> list:
    """Find legal citations in *text* and tag each with its category.

    Every LEGAL_CITATION_PATTERNS regex is applied (case-insensitive);
    results are dicts with 'citation' (matched text), 'type' (category),
    and 'position' (character offset), grouped by pattern order.
    """
    found = []
    for pattern, citation_type in LEGAL_CITATION_PATTERNS:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            found.append({
                'citation': match.group(),
                'type': citation_type,
                'position': match.start()
            })
    return found

def count_authoritative_citations(text: str) -> dict:
    """Tally extracted citations into authority buckets.

    Returns counts keyed 'supreme'/'circuit'/'district'/'state'/'other';
    statute, regulation, and rule citations all land in 'other'.
    """
    # Citation-type -> bucket dispatch table (anything unlisted -> 'other').
    bucket_for = {
        'us_supreme_court': 'supreme',
        'federal_circuit': 'circuit',
        'federal_district': 'district',
        'state_appellate': 'state',
    }
    counts = {'supreme': 0, 'circuit': 0, 'district': 0, 'state': 0, 'other': 0}
    for cit in extract_citations(text):
        counts[bucket_for.get(cit['type'], 'other')] += 1
    return counts

# ============================================================================
# Query Expansion with Concept Clusters
# ============================================================================
def expand_query_with_concepts(query: str) -> list:
    """Expand *query* with terms from matching concept clusters.

    A cluster (from LEGAL_CONCEPT_CLUSTERS, then LEGAL_DOCTRINES, in that
    order) contributes its full term list when either any of its terms, or
    its underscore-free name (e.g. 'summary_judgment' -> 'summary judgment'),
    appears in the lowercased query. Duplicates are removed
    case-insensitively, keeping the first occurrence.
    """
    lowered = query.lower()
    expanded = []

    # Concept clusters and doctrine clusters share the same matching rule.
    for catalog in (LEGAL_CONCEPT_CLUSTERS, LEGAL_DOCTRINES):
        for name, terms in catalog.items():
            term_hit = any(term.lower() in lowered for term in terms)
            name_hit = name.replace('_', ' ') in lowered
            if term_hit or name_hit:
                expanded.extend(terms)

    # De-duplicate case-insensitively while preserving first-seen order.
    seen = set()
    unique = []
    for term in expanded:
        key = term.lower()
        if key not in seen:
            seen.add(key)
            unique.append(term)

    return unique

# ============================================================================
# Validation Queries for RAG Testing
# ============================================================================
# (query, expected-keywords) pairs for smoke-testing retrieval quality:
# a good RAG response to each query should surface the listed keywords.
VALIDATION_QUERIES = [
    ("What's the standard for summary judgment?",
     ["genuine issue", "material fact", "Celotex", "moving party"]),
    ("How does attorney-client privilege work?",
     ["confidential", "communication", "privilege", "waiver"]),
    ("What's required for personal jurisdiction?",
     ["minimum contacts", "purposeful", "jurisdiction"]),
    ("Explain Chevron deference",
     ["agency", "interpretation", "deference", "ambiguous"]),
    ("What are the elements of negligence?",
     ["duty", "breach", "causation", "damages"]),
    ("How does res judicata work?",
     ["claim preclusion", "final judgment", "same parties"]),
    ("What is qualified immunity?",
     ["clearly established", "reasonable", "immunity"]),
    ("Explain the Erie doctrine",
     ["substantive", "procedural", "state law", "federal"]),
]

# ============================================================================
# ADVANCED LEGAL REASONING SYSTEM
# ============================================================================

# Precedential Hierarchy - Authority Weighting
# Consumed by get_precedential_weight(): source names and tags are matched
# as lowercased substrings of chunk text/metadata; 'declining' has no
# sources and is triggered purely by its negative-treatment tags.
PRECEDENTIAL_HIERARCHY = {
    'gold_standard': {
        'weight': 2.0,
        'sources': ['U.S. Supreme Court', 'Supreme Court', 'SCOTUS', 'Restatement',
                    'Model Rules', 'Uniform Commercial Code', 'UCC'],
        'tags': ['landmark', 'seminal', 'watershed', 'controlling', 'binding']
    },
    'highly_persuasive': {
        'weight': 1.5,
        'sources': ['Circuit Court', 'Court of Appeals', 'Federal Appellate',
                    'Treatise', 'Wright & Miller', 'Prosser', 'Corbin', 'Williston'],
        'tags': ['well-reasoned', 'authoritative', 'influential', 'leading']
    },
    'persuasive': {
        'weight': 1.2,
        'sources': ['District Court', 'State Supreme Court', 'Law Review',
                    'American Law Reports', 'ALR'],
        'tags': ['instructive', 'informative', 'useful']
    },
    'declining': {
        'weight': 0.4,
        'sources': [],
        'tags': ['abrogated', 'superseded', 'overruled', 'questioned',
                 'criticized', 'outdated', 'distinguished away']
    }
}

def get_precedential_weight(text: str, metadata: dict = None) -> float:
    """Weight a chunk by the authority of its source.

    Check order matters: gold-standard sources win outright, then explicit
    negative-treatment tags ('overruled', 'superseded', ...) down-weight a
    chunk before the highly-persuasive and persuasive tiers are consulted.
    Defaults to 1.0 when nothing matches.
    """
    haystack = text.lower()
    source = (metadata or {}).get('source', '').lower()

    def source_hit(level: str) -> bool:
        # A source name counts if it appears in the chunk text OR metadata.
        return any(
            s.lower() in haystack or s.lower() in source
            for s in PRECEDENTIAL_HIERARCHY[level]['sources']
        )

    if source_hit('gold_standard'):
        return PRECEDENTIAL_HIERARCHY['gold_standard']['weight']

    if any(tag in haystack for tag in PRECEDENTIAL_HIERARCHY['declining']['tags']):
        return PRECEDENTIAL_HIERARCHY['declining']['weight']

    for level in ('highly_persuasive', 'persuasive'):
        if source_hit(level):
            return PRECEDENTIAL_HIERARCHY[level]['weight']

    return 1.0  # Default weight


# Temporal Reasoning - Precedent Evolution
def calculate_temporal_weight(chunk_metadata: dict, query_context: dict = None) -> float:
    """Weight a chunk by the temporal relevance of its legal authority.

    Statutory/regulatory material decays faster (recent interpretations
    matter more), landmark case law never decays, and ordinary case law
    decays slowly. Chunks without a 'year' get a neutral 1.0.
    ``query_context`` is currently unused but kept for interface stability.
    """
    year = chunk_metadata.get('year')
    if not year:
        # No year available (extraction from text is a possible future step).
        return 1.0

    age = datetime.now().year - year
    doc_type = chunk_metadata.get('document_type', '').lower()
    tags = chunk_metadata.get('tags', [])

    # Statutes/regulations: recent interpretations matter more.
    statutory = (
        'statute' in doc_type
        or 'regulation' in doc_type
        or any('statute' in t for t in tags)
    )
    if statutory:
        for limit, weight in ((5, 1.2), (10, 1.0), (20, 0.8)):
            if age < limit:
                return weight
        return 0.6

    # Landmark case law stays fully relevant regardless of age.
    if any(t in ('landmark', 'seminal', 'watershed') for t in tags):
        return 1.0

    # Ordinary case law: slow decay with age.
    for limit, weight in ((10, 1.0), (25, 0.85), (50, 0.7)):
        if age < limit:
            return weight
    return 0.5


# Argument Structure Recognition
# Cue-phrase inventories per reasoning schema; detect_argument_structure()
# scores each schema by how many of its patterns occur in a text. 'weight'
# is a downstream boost factor, not used in detection itself.
ARGUMENT_SCHEMAS = {
    'rule_synthesis': {
        'patterns': ['if', 'then', 'when', 'where', 'provided that', 'unless',
                     'therefore', 'thus', 'hence', 'accordingly', 'consequently'],
        'weight': 1.3,
        'type': 'deductive',
        'description': 'Rule-based deductive reasoning'
    },
    'policy_balancing': {
        'patterns': ['on one hand', 'on the other hand', 'balance', 'weigh',
                     'competing interests', 'countervailing', 'policy considerations',
                     'public interest', 'private rights'],
        'weight': 1.25,
        'type': 'dialectical',
        'description': 'Policy-based balancing analysis'
    },
    'analogical': {
        'patterns': ['similarly', 'like', 'analogous', 'distinguishable', 'akin to',
                     'comparable', 'in contrast', 'unlike', 'parallel to'],
        'weight': 1.4,
        'type': 'comparative',
        'description': 'Analogical reasoning from precedent'
    },
    'doctrinal_framework': {
        'patterns': ['three-part test', 'four-factor', 'elements', 'prongs',
                     'step one', 'step two', 'first', 'second', 'third',
                     'multi-factor', 'totality of circumstances'],
        'weight': 1.35,
        'type': 'framework',
        'description': 'Structured doctrinal analysis'
    },
    'textual_analysis': {
        'patterns': ['plain meaning', 'plain language', 'ordinary meaning',
                     'statutory text', 'unambiguous', 'literal interpretation',
                     'express terms', 'legislative intent'],
        'weight': 1.2,
        'type': 'textual',
        'description': 'Textual/statutory interpretation'
    },
    'historical_analysis': {
        'patterns': ['historically', 'at common law', 'traditional', 'longstanding',
                     'originally', 'founding era', 'framers', 'original understanding'],
        'weight': 1.15,
        'type': 'historical',
        'description': 'Historical/originalist analysis'
    }
}

def detect_argument_structure(text: str) -> list:
    """Identify which legal reasoning schemas appear in *text*.

    Scans the lowercased text for the indicator phrases of every schema in
    ARGUMENT_SCHEMAS and returns one record per schema that matched at least
    one phrase, sorted by descending confidence.
    """
    lowered = text.lower()
    findings = []

    for name, schema in ARGUMENT_SCHEMAS.items():
        hits = [phrase for phrase in schema['patterns'] if phrase in lowered]
        if not hits:
            continue

        # Confidence saturates at 1.0 once roughly 30% of the schema's
        # phrases are present in the text.
        confidence = min(1.0, len(hits) / (len(schema['patterns']) * 0.3))
        findings.append({
            'schema': name,
            'confidence': round(confidence, 2),
            'weight': schema['weight'],
            'type': schema['type'],
            'description': schema['description'],
            'matched_patterns': hits[:5]
        })

    # Highest-confidence schemas first (stable sort keeps dict order on ties).
    return sorted(findings, key=lambda item: item['confidence'], reverse=True)


# Legal Cross-References - Doctrine Relationships
# Maps each doctrine key (underscored; matched against queries/text with
# underscores replaced by spaces) to term lists used for query expansion and
# doctrine detection. Common list keys: 'related' (concepts to expand with),
# 'contrasting' (opposing concepts), 'tests' (governing legal tests),
# 'applications' (practice areas). Some entries use doctrine-specific keys
# instead ('canons', 'sources', 'remedies', 'standards', 'key_cases',
# 'procedures', 'types') — consumers read keys via .get(), so the schema may
# vary per doctrine.
LEGAL_CROSS_REFERENCES = {
    'due_process': {
        'related': ['equal protection', 'fundamental rights', 'procedural fairness',
                    'substantive due process', 'notice and hearing'],
        'contrasting': ['police power', 'state interest', 'rational basis'],
        'tests': ['strict scrutiny', 'intermediate scrutiny', 'rational basis review'],
        'applications': ['criminal procedure', 'administrative law', 'family law', 'immigration']
    },
    'equal_protection': {
        'related': ['due process', 'suspect classification', 'fundamental rights'],
        'contrasting': ['legitimate government interest', 'police power'],
        'tests': ['strict scrutiny', 'intermediate scrutiny', 'rational basis'],
        'applications': ['voting rights', 'education', 'employment discrimination']
    },
    'first_amendment': {
        'related': ['free speech', 'free exercise', 'establishment clause', 'free press'],
        'contrasting': ['compelling interest', 'content-neutral regulation'],
        'tests': ['strict scrutiny', 'intermediate scrutiny', 'time place manner'],
        'applications': ['political speech', 'commercial speech', 'religious liberty']
    },
    'fourth_amendment': {
        'related': ['search and seizure', 'probable cause', 'warrant requirement'],
        'contrasting': ['exigent circumstances', 'consent', 'plain view'],
        'tests': ['reasonable expectation of privacy', 'totality of circumstances'],
        'applications': ['criminal procedure', 'administrative searches', 'border searches']
    },
    # Uses 'canons'/'sources'/'remedies' instead of 'tests'/'applications'.
    'contract_interpretation': {
        'related': ['plain meaning', 'parol evidence', 'course of dealing'],
        'canons': ['contra proferentem', 'ejusdem generis', 'noscitur a sociis'],
        'sources': ['text', 'context', 'trade usage', 'course of performance'],
        'remedies': ['damages', 'specific performance', 'rescission', 'reformation']
    },
    'negligence': {
        'related': ['duty of care', 'breach', 'causation', 'damages', 'proximate cause'],
        'contrasting': ['contributory negligence', 'comparative fault', 'assumption of risk'],
        'tests': ['reasonable person', 'foreseeability', 'but-for causation'],
        'applications': ['medical malpractice', 'products liability', 'premises liability']
    },
    # Procedural doctrine: carries 'key_cases' consumed by get_related_concepts().
    'summary_judgment': {
        'related': ['genuine issue', 'material fact', 'moving party', 'non-moving party'],
        'standards': ['no genuine dispute', 'view evidence favorably', 'draw inferences'],
        'key_cases': ['Celotex', 'Anderson v. Liberty Lobby', 'Matsushita'],
        'procedures': ['Rule 56', 'statement of facts', 'response', 'reply']
    },
    'personal_jurisdiction': {
        'related': ['minimum contacts', 'purposeful availment', 'fair play'],
        'types': ['general jurisdiction', 'specific jurisdiction', 'consent'],
        'tests': ['International Shoe', 'stream of commerce', 'effects test'],
        'applications': ['internet jurisdiction', 'contract disputes', 'tort claims']
    }
}

def get_related_concepts(query: str) -> list:
    """Get related legal concepts for query expansion.

    For every doctrine in LEGAL_CROSS_REFERENCES: if the doctrine name
    (underscores read as spaces) appears in *query*, its related concepts,
    governing tests, and key cases are collected; otherwise, if any of its
    'related' terms appears in the query, its related concepts are collected.

    Returns:
        De-duplicated list of expansion terms in first-seen order. (The
        previous ``list(set(...))`` produced a hash-randomized order, making
        expansion output non-reproducible between runs.)
    """
    query_lower = query.lower()
    related = []

    for doctrine, refs in LEGAL_CROSS_REFERENCES.items():
        doctrine_clean = doctrine.replace('_', ' ')
        if doctrine_clean in query_lower:
            # Direct doctrine mention: expand with related concepts,
            # governing tests, and landmark cases (where defined).
            related.extend(refs.get('related', []))
            related.extend(refs.get('tests', []))
            related.extend(refs.get('key_cases', []))
        else:
            # Indirect mention: a related term in the query still pulls in
            # the doctrine's related concepts. Only checked when the doctrine
            # itself did not match, to avoid extending the same list twice.
            for rel_term in refs.get('related', []):
                if rel_term in query_lower:
                    related.extend(refs.get('related', []))
                    break

    # Order-preserving de-duplication for deterministic output.
    return list(dict.fromkeys(related))


def extract_primary_doctrine(text: str) -> Optional[str]:
    """Extract the primary legal doctrine discussed in *text*.

    Each doctrine in LEGAL_CROSS_REFERENCES is checked first by name
    (underscores read as spaces), then by its top three 'related' terms.
    The first doctrine that matches wins, so dict insertion order acts as a
    priority ranking.

    Returns:
        The doctrine key (e.g. 'due_process'), or None when nothing matches.
        (The original annotation claimed ``str`` but the fall-through path
        returns None; the annotation now reflects that.)
    """
    text_lower = text.lower()

    for doctrine, refs in LEGAL_CROSS_REFERENCES.items():
        if doctrine.replace('_', ' ') in text_lower:
            return doctrine

        # Fall back to the doctrine's three strongest related terms.
        for term in refs.get('related', [])[:3]:
            if term in text_lower:
                return doctrine

    return None


# Enhanced Chunk Metadata Schema
# Reference template for the advanced legal metadata that enhance_chunk_metadata()
# attaches to chunks. NOTE(review): if this dict is ever copied onto chunks by
# reference, the mutable list defaults would be shared between chunks — copy
# per chunk instead.
ENHANCED_METADATA_FIELDS = {
    'primary_doctrine': None,        # set via extract_primary_doctrine()
    'supporting_doctrines': [],      # not populated in the visible pipeline
    'argument_structures': [],       # top schemas from detect_argument_structure()
    'precedential_weight': 1.0,      # set via get_precedential_weight()
    'temporal_weight': 1.0,          # set via calculate_temporal_weight()
    'authority_level': 'unknown',    # not populated in the visible pipeline
    'jurisdictional_relevance': 1.0, # not populated in the visible pipeline
    'citation_count': 0,             # number of citations found by extract_citations()
    'year': None                     # not populated in the visible pipeline
}

def enhance_chunk_metadata(chunk: dict) -> dict:
    """Attach advanced legal metadata to *chunk* in place and return it.

    Adds argument-structure detection, precedential/temporal ranking weights,
    the primary doctrine, and citation statistics.
    """
    content = chunk.get('content', '')
    metadata = chunk.get('metadata', {})

    # Reasoning patterns: keep only the three strongest matches.
    structures = detect_argument_structure(content)
    if structures:
        chunk['argument_structures'] = structures[:3]
        chunk['primary_argument_type'] = structures[0]['type']

    # Authority- and age-based ranking multipliers.
    chunk['precedential_weight'] = get_precedential_weight(content, metadata)
    chunk['temporal_weight'] = calculate_temporal_weight(metadata)

    # Dominant doctrine (None when nothing matched).
    chunk['primary_doctrine'] = extract_primary_doctrine(content)

    # Citation statistics.
    found_citations = extract_citations(content)
    chunk['citation_count'] = len(found_citations)
    chunk['has_supreme_court_citation'] = any(
        cite['type'] == 'us_supreme_court' for cite in found_citations
    )

    return chunk


# BM25 parameters (Okapi BM25 ranking function)
BM25_K1 = 1.5   # term-frequency saturation: higher k1 -> slower saturation
BM25_B = 0.75   # document-length normalization: 0 = none, 1 = full

# ============================================================================
# PDF Processing
# ============================================================================
# Try to import PDF libraries
try:
    import fitz  # PyMuPDF
    HAS_PYMUPDF = True
except ImportError:
    HAS_PYMUPDF = False

try:
    from pdfminer.high_level import extract_text as pdfminer_extract
    HAS_PDFMINER = True
except ImportError:
    HAS_PDFMINER = False

try:
    from PyPDF2 import PdfReader
    HAS_PYPDF2 = True
except ImportError:
    HAS_PYPDF2 = False

# Try numpy for embeddings
try:
    import numpy as np
    HAS_NUMPY = True
except ImportError:
    HAS_NUMPY = False

# Try sentence-transformers for embeddings
try:
    from sentence_transformers import SentenceTransformer
    HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
    HAS_SENTENCE_TRANSFORMERS = False

# Try FAISS for vector search
try:
    import faiss
    HAS_FAISS = True
except ImportError:
    HAS_FAISS = False

# Try requests for OCR fallback
try:
    import requests
    HAS_REQUESTS = True
except ImportError:
    HAS_REQUESTS = False


# ============================================================================
# LLM Integration for Enrichment
# ============================================================================
class LLMService:
    """LLM service using local DeepSeek for enrichment, Grok for chat fallback.

    ``generate()`` tries the local DeepSeek-compatible endpoint first; when it
    is down or errors, the hosted Grok API is used (if an API key is set).
    """

    def __init__(self):
        # Local DeepSeek LLM (primary for enrichment - GPU accelerated)
        self.deepseek_url = os.environ.get('DEEPSEEK_LLM_URL', 'http://127.0.0.1:5004/v1/chat/completions')
        self.deepseek_model = os.environ.get('DEEPSEEK_MODEL', 'Qwen/QwQ-32B')
        self.deepseek_available = True

        # Grok API (fallback for chat if DeepSeek fails). Keys pasted into env
        # files are sometimes wrapped in quotes, so strip them defensively.
        self.grok_api_key = (os.environ.get('GROK_API_KEY') or os.environ.get('XAI_API_KEY', '')).strip().strip('"').strip("'")
        self.grok_url = "https://api.x.ai/v1/chat/completions"
        self.grok_model = "grok-4-fast-reasoning"

        # Probe the local endpoint once at startup.
        self._check_deepseek()

        if not self.deepseek_available and not self.grok_api_key:
            print("[RAG] WARNING: Neither DeepSeek nor Grok available. LLM features will fail.")

    def _check_deepseek(self):
        """Health-check the local DeepSeek endpoint and set ``deepseek_available``."""
        if not HAS_REQUESTS:
            self.deepseek_available = False
            return

        try:
            # Quick health check against the OpenAI-compatible /v1/models route.
            base_url = self.deepseek_url.replace('/v1/chat/completions', '/v1/models')
            response = requests.get(base_url, timeout=5)
            if response.status_code == 200:
                print(f"[RAG] Local DeepSeek available: {self.deepseek_model}")
                self.deepseek_available = True
            else:
                print(f"[RAG] Local DeepSeek not responding, will use Grok fallback")
                self.deepseek_available = False
        except Exception as e:
            print(f"[RAG] Local DeepSeek check failed: {e}, will use Grok fallback")
            self.deepseek_available = False

    def generate(self, prompt: str, max_tokens: int = 800, temperature: float = 0.5) -> Optional[str]:
        """Generate text using DeepSeek (primary) or Grok (fallback).

        Returns:
            The completion text, or None when both backends fail.
        """
        # Try DeepSeek first (local, faster for enrichment)
        if self.deepseek_available:
            result = self._call_deepseek(prompt, max_tokens, temperature)
            if result:
                return result

        # Fallback to Grok
        return self._call_grok(prompt, max_tokens, temperature)

    def _call_deepseek(self, prompt: str, max_tokens: int, temperature: float) -> Optional[str]:
        """Call the local DeepSeek LLM.

        Returns None on any error and marks the backend unavailable for the
        rest of the session so later calls go straight to Grok.
        """
        if not HAS_REQUESTS:
            return None

        try:
            response = requests.post(
                self.deepseek_url,
                headers={'Content-Type': 'application/json'},
                json={
                    'model': self.deepseek_model,
                    'messages': [{'role': 'user', 'content': prompt}],
                    'max_tokens': max_tokens,
                    'temperature': temperature,
                    'stream': False
                },
                timeout=120  # Longer timeout for local model
            )

            if response.status_code == 200:
                data = response.json()
                if 'choices' in data and len(data['choices']) > 0:
                    content = data['choices'][0]['message']['content']
                    # DeepSeek R1 sometimes includes thinking - strip the
                    # <think>...</think> block and keep the final answer.
                    # Uses the module-level `re` import; the previous local
                    # `import re` here was redundant.
                    if '<think>' in content and '</think>' in content:
                        content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
                    return content
            else:
                print(f"[RAG] DeepSeek error: {response.status_code}")
                self.deepseek_available = False  # Mark as unavailable for this session
        except Exception as e:
            print(f"[RAG] DeepSeek exception: {e}")
            self.deepseek_available = False

        return None

    def _call_grok(self, prompt: str, max_tokens: int, temperature: float) -> Optional[str]:
        """Call the hosted Grok API (fallback); returns None on any error."""
        if not HAS_REQUESTS or not self.grok_api_key:
            print("[RAG] Grok API not configured (no API key)")
            return None

        try:
            print(f"[RAG] Calling Grok fallback ({self.grok_model})...")
            response = requests.post(
                self.grok_url,
                headers={
                    'Authorization': f'Bearer {self.grok_api_key}',
                    'Content-Type': 'application/json'
                },
                json={
                    'model': self.grok_model,
                    'messages': [{'role': 'user', 'content': prompt}],
                    'max_tokens': max_tokens,
                    'temperature': temperature,
                    'stream': False
                },
                timeout=45
            )

            if response.status_code == 200:
                data = response.json()
                if 'choices' in data and len(data['choices']) > 0:
                    return data['choices'][0]['message']['content']
            else:
                print(f"[RAG] Grok API error: {response.status_code} - {response.text[:500] if response.text else '(empty)'}")
        except Exception as e:
            print(f"[RAG] Grok API exception: {e}")

        return None


class PDFProcessor:
    """Advanced PDF processor with structure-aware chunking.

    Text extraction prefers PyMuPDF, then pdfminer, then PyPDF2; plain .txt
    files are read directly. Chunking is heading-aware when headings can be
    detected, otherwise sentence-based with overlap.
    """

    def __init__(self, config: "RAGConfig"):
        """Store config and the boilerplate filter patterns.

        Args:
            config: service configuration (forward-referenced string
                annotation so the class does not require RAGConfig to be
                defined at import time).
        """
        self.config = config
        # Regexes that flag a chunk as boilerplate/noise in is_valid_chunk().
        # NOTE(review): the '^...$' patterns are searched without re.MULTILINE,
        # so they only match chunks that are entirely blank / numeric.
        self.boilerplate_patterns = [
            r'copyright\s+©?\s*\d{4}',
            r'all rights reserved',
            r'page\s+\d+\s+of\s+\d+',
            r'^\s*$',
            r'^\d+$',
            r'[a-f0-9]{32,}',  # long hex runs (hashes, ids)
            r'^[\s\W]*$',
        ]

    def extract_text(self, file_path: Path) -> str:
        """Extract raw text from a .pdf or .txt file using the best available
        backend; returns '' when every backend fails or the type is unknown."""
        if file_path.suffix.lower() == '.pdf':
            # Try PyMuPDF first (best quality)
            if HAS_PYMUPDF:
                try:
                    return self._extract_with_pymupdf(file_path)
                except Exception as e:
                    print(f"PyMuPDF failed: {e}")

            # Try pdfminer
            if HAS_PDFMINER:
                try:
                    return pdfminer_extract(str(file_path))
                except Exception as e:
                    print(f"pdfminer failed: {e}")

            # Try PyPDF2
            if HAS_PYPDF2:
                try:
                    reader = PdfReader(file_path)
                    return '\n'.join(page.extract_text() or '' for page in reader.pages)
                except Exception as e:
                    print(f"PyPDF2 failed: {e}")

            return ""

        elif file_path.suffix.lower() == '.txt':
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()
            except UnicodeDecodeError:
                # latin-1 maps every byte, so this second attempt cannot fail.
                with open(file_path, 'r', encoding='latin-1') as f:
                    return f.read()

        return ""

    def _extract_with_pymupdf(self, file_path: Path) -> str:
        """High-fidelity extraction using PyMuPDF (dehyphenation enabled)."""
        doc = fitz.open(file_path)
        pages = []
        try:
            for page in doc:
                text = page.get_text("text", flags=fitz.TEXT_DEHYPHENATE | fitz.TEXT_PRESERVE_WHITESPACE)
                pages.append(text)
        finally:
            doc.close()
        return "\n".join(pages)

    def detect_document_type(self, text: str, filename: str) -> str:
        """Classify the document ('sports'/'technical'/'research'/'default')
        so chunking can use type-specific parameters. A type is chosen when at
        least two of its indicator words appear in the text."""
        text_lower = text.lower()
        filename_lower = filename.lower()  # kept for future filename-based rules

        # Sports betting indicators
        sports_indicators = ['betting', 'odds', 'spread', 'moneyline', 'parlay', 'sportsbook', 'handicapping']
        if sum(1 for ind in sports_indicators if ind in text_lower) >= 2:
            return 'sports'

        # Technical indicators
        tech_indicators = ['specification', 'technical', 'api', 'protocol', 'algorithm']
        if sum(1 for ind in tech_indicators if ind in text_lower) >= 2:
            return 'technical'

        # Research indicators
        research_indicators = ['abstract', 'methodology', 'results', 'conclusion', 'references']
        if sum(1 for ind in research_indicators if ind in text_lower) >= 2:
            return 'research'

        return 'default'

    def split_into_sentences(self, text: str) -> List[str]:
        """Split *text* into sentences, dropping fragments of <= 10 chars."""
        if not text:
            return []

        # Split after sentence-ending punctuation followed by whitespace and
        # an uppercase letter (or at end of string).
        sentence_endings = r'(?<=[.!?])\s+(?=[A-Z])|(?<=[.!?])\s*$|(?<=[.!?"\'])\s+(?=[A-Z])'
        sentences = re.split(sentence_endings, text)
        return [s.strip() for s in sentences if len(s.strip()) > 10]

    def detect_headings(self, text: str) -> List[Tuple[str, int, int]]:
        """Detect headings; returns (heading_text, line_index, level) tuples.

        Two heuristics: ALL-CAPS lines (level 1) and numbered lines like
        '1.2 Title' (level = number of numeric components, capped at 3).
        """
        headings = []
        lines = text.split('\n')

        for i, line in enumerate(lines):
            line_stripped = line.strip()
            if not line_stripped or len(line_stripped) < 3:
                continue

            # ALL CAPS headings (length-bounded to avoid shouting paragraphs)
            if line_stripped.isupper() and len(line_stripped) >= 4 and len(line_stripped) < 80:
                headings.append((line_stripped, i, 1))
                continue

            # Numbered headings, e.g. '2 Title' or '2.1 Title'
            numbered_match = re.match(r'^\s*(\d+)(?:\.(\d+))?\s+(.+)$', line_stripped)
            if numbered_match:
                depth = sum(1 for g in numbered_match.groups()[:2] if g)
                heading_text = numbered_match.group(3)
                if len(heading_text) > 3:
                    headings.append((heading_text, i, min(depth, 3)))

        return headings

    def chunk_text(self, text: str, doc_type: str = 'default') -> List[Dict[str, Any]]:
        """Structure-aware chunking: section-based when headings exist, else
        sentence-based with overlap. Returns a list of chunk dicts.

        (Previously this method copied every config value into unused locals;
        only the minimum-size check needs a value here — helpers receive the
        whole config dict.)
        """
        config = DOCUMENT_TYPE_CONFIGS.get(doc_type, DOCUMENT_TYPE_CONFIGS['default'])

        if not text or len(text.strip()) < config['chunk_min_size']:
            return []

        # Try structure-aware chunking first
        headings = self.detect_headings(text)

        if headings:
            chunks = self._chunk_by_sections(text, headings, config)
            if chunks:
                return chunks

        # Fall back to sentence-based chunking
        return self._sentence_based_chunking(text, config)

    def _chunk_by_sections(self, text: str, headings: List[Tuple], config: dict) -> List[Dict]:
        """Chunk the text between consecutive detected headings, tagging each
        chunk with its section title and heading level."""
        chunks = []
        lines = text.split('\n')

        # Process sections between headings
        for i, (heading_text, heading_pos, level) in enumerate(headings):
            end_pos = headings[i + 1][1] if i + 1 < len(headings) else len(lines)
            section_content = '\n'.join(lines[heading_pos + 1:end_pos]).strip()

            if section_content and len(section_content) >= config['chunk_min_size']:
                section_chunks = self._sentence_based_chunking(section_content, config)
                for chunk in section_chunks:
                    chunk['section'] = heading_text
                    chunk['section_level'] = level
                chunks.extend(section_chunks)

        return chunks

    def _sentence_based_chunking(self, text: str, config: dict) -> List[Dict]:
        """Greedy sentence packing: emit a chunk once it reaches chunk_size
        (or would exceed chunk_max_size), carrying sentence_overlap trailing
        sentences into the next chunk for context continuity."""
        sentences = self.split_into_sentences(text)
        if len(sentences) < config['min_sentences']:
            return []

        chunks = []
        current_chunk = []
        current_length = 0

        for sentence in sentences:
            sentence_length = len(sentence) + 1  # +1 for the joining space

            # Hard cap: flush before this sentence would overflow chunk_max_size.
            if current_length + sentence_length > config['chunk_max_size'] and current_chunk:
                candidate = ' '.join(current_chunk).strip()
                if len(candidate) >= config['chunk_min_size']:
                    chunks.append({
                        'content': candidate,
                        'sentence_count': len(current_chunk),
                        'char_count': len(candidate)
                    })

                # Start new chunk with overlap
                overlap_count = min(config['sentence_overlap'], len(current_chunk))
                current_chunk = current_chunk[-overlap_count:] if overlap_count > 0 else []
                current_length = sum(len(s) + 1 for s in current_chunk)

            current_chunk.append(sentence)
            current_length += sentence_length

            # Soft target: flush once chunk_size is reached with enough sentences.
            if current_length >= config['chunk_size'] and len(current_chunk) >= config['min_sentences']:
                candidate = ' '.join(current_chunk).strip()
                if len(candidate) >= config['chunk_min_size']:
                    chunks.append({
                        'content': candidate,
                        'sentence_count': len(current_chunk),
                        'char_count': len(candidate)
                    })

                overlap_count = min(config['sentence_overlap'], len(current_chunk))
                current_chunk = current_chunk[-overlap_count:] if overlap_count > 0 else []
                current_length = sum(len(s) + 1 for s in current_chunk)

        # Final chunk
        if current_chunk:
            candidate = ' '.join(current_chunk).strip()
            if len(candidate) >= config['chunk_min_size']:
                chunks.append({
                    'content': candidate,
                    'sentence_count': len(current_chunk),
                    'char_count': len(candidate)
                })

        return chunks

    def is_valid_chunk(self, chunk: Dict, config: dict) -> bool:
        """Reject chunks that are too short, boilerplate, or low-content
        (fewer than 5 words, or under 30% unique words)."""
        content = chunk.get('content', '')
        if len(content) < config.get('chunk_min_size', 400):
            return False

        # Check for boilerplate
        content_lower = content.lower()
        for pattern in self.boilerplate_patterns:
            if re.search(pattern, content_lower, re.IGNORECASE):
                return False

        # Check for meaningful content
        words = re.findall(r'\b\w+\b', content_lower)
        if len(words) < 5:
            return False

        # Check uniqueness ratio (filters repeated headers/footers)
        unique_ratio = len(set(words)) / len(words) if words else 0
        if unique_ratio < 0.3:
            return False

        return True

    def process_file(self, file_path: Path, source_name: str = None) -> List[Dict]:
        """Extract, type-detect, chunk, and validate a file; returns chunk
        dicts annotated with source / chunk_id / total_chunks / document_type.
        Files yielding fewer than 100 chars of text are skipped entirely."""
        if source_name is None:
            source_name = file_path.name

        # Extract text
        text = self.extract_text(file_path)
        if not text or len(text.strip()) < 100:
            return []

        # Detect document type
        doc_type = self.detect_document_type(text, source_name)
        config = DOCUMENT_TYPE_CONFIGS.get(doc_type, DOCUMENT_TYPE_CONFIGS['default'])

        # Chunk text
        raw_chunks = self.chunk_text(text, doc_type)

        # Filter and add metadata
        chunks = []
        for i, chunk in enumerate(raw_chunks):
            if self.is_valid_chunk(chunk, config):
                chunk['source'] = source_name
                chunk['chunk_id'] = i
                # chunk_id indexes raw_chunks, so ids may be non-contiguous
                # after filtering; total_chunks counts the unfiltered set.
                chunk['total_chunks'] = len(raw_chunks)
                chunk['document_type'] = doc_type
                chunks.append(chunk)

        return chunks

# ============================================================================
# Vector Store with Hybrid Search
# ============================================================================
class VectorStore:
    """FAISS-based vector store with BM25 hybrid search"""
    
    def __init__(self, config: RAGConfig):
        """Initialize the store and load the embedding model.

        Documents, embeddings, and BM25 statistics are parallel structures:
        documents[i] corresponds to searchable_texts[i], doc_lengths[i], and
        row i of the FAISS index (or the embeddings array).
        """
        self.config = config
        self.documents = []          # chunk dicts (content + metadata)
        self.embeddings = []         # used only when FAISS is unavailable
        self.index = None            # FAISS IndexFlatIP when available
        self.model = None            # SentenceTransformer, or None -> hash fallback
        self.searchable_texts = []   # lowercased per-doc text for keyword/BM25
        self.doc_lengths = []        # per-doc token counts for BM25
        self.avg_doc_length = 0      # BM25 average document length
        
        # File paths
        self.index_file = config.cache_dir / 'faiss_index.index'
        self.metadata_file = config.cache_dir / 'documents_metadata.json'
        self.embeddings_file = config.cache_dir / 'embeddings.npy'
        
        # Initialize embedding model
        self._init_model()
    
    def _init_model(self):
        """Load the sentence-transformer encoder; fall back to hash embeddings.

        Leaves self.model as None (set in __init__) when the library is
        missing or loading fails; embedding_dim is 384 in every case.
        """
        self.embedding_dim = 384
        if HAS_SENTENCE_TRANSFORMERS:
            try:
                self.model = SentenceTransformer('all-MiniLM-L6-v2')
                print(f"[RAG] Loaded SentenceTransformer model")
                return
            except Exception as e:
                print(f"[RAG] SentenceTransformer failed: {e}")
                self.model = None
        print("[RAG] Using hash-based embeddings (install sentence-transformers for better results)")
    
    def _get_embedding(self, text: str) -> np.ndarray:
        """Embed *text* with the model, or a deterministic hash fallback.

        Fix: sha384 digests are 48 bytes, so the old fallback produced
        48-dim vectors while embedding_dim (and any FAISS IndexFlatIP built
        from it) is 384 — adding them to the index failed on the dimension
        mismatch. The digest is now tiled out to embedding_dim and then
        L2-normalized.
        """
        if self.model:
            return self.model.encode(text, normalize_embeddings=True)
        # Hash-based fallback: deterministic pseudo-embedding from sha384.
        h = hashlib.sha384(text.encode()).digest()
        arr = np.frombuffer(h, dtype=np.uint8).astype(np.float32)
        reps = int(math.ceil(self.embedding_dim / arr.size))
        arr = np.tile(arr, reps)[:self.embedding_dim]
        return arr / (np.linalg.norm(arr) + 1e-9)
    
    def _get_embeddings_batch(self, texts: List[str]) -> np.ndarray:
        """Embed a batch of texts; stacks per-text fallbacks when no model."""
        if self.model is not None:
            return self.model.encode(texts, normalize_embeddings=True, show_progress_bar=True)
        return np.array([self._get_embedding(text) for text in texts])
    
    def load(self) -> bool:
        """Load persisted documents plus the FAISS index (or raw embeddings).

        Returns:
            True when metadata was loaded successfully, False otherwise.
        """
        if not self.metadata_file.exists():
            return False

        try:
            self.documents = json.loads(self.metadata_file.read_text(encoding='utf-8'))

            # Prefer the FAISS index; fall back to the raw embeddings dump.
            if HAS_FAISS and self.index_file.exists():
                self.index = faiss.read_index(str(self.index_file))
            elif self.embeddings_file.exists():
                self.embeddings = np.load(str(self.embeddings_file))

            self._prepare_searchable_texts()
            print(f"[RAG] Loaded {len(self.documents)} documents")
            return True
        except Exception as e:
            print(f"[RAG] Failed to load index: {e}")
            return False
    
    def save(self):
        """Persist documents metadata plus the FAISS index (or embeddings)."""
        try:
            self.metadata_file.write_text(
                json.dumps(self.documents, ensure_ascii=False, indent=2),
                encoding='utf-8',
            )

            if HAS_FAISS and self.index is not None:
                faiss.write_index(self.index, str(self.index_file))
            elif len(self.embeddings) > 0:
                np.save(str(self.embeddings_file), np.array(self.embeddings))

            print(f"[RAG] Saved {len(self.documents)} documents")
        except Exception as e:
            print(f"[RAG] Failed to save index: {e}")
    
    def _prepare_searchable_texts(self):
        """Rebuild the lowercased search cache and BM25 length statistics."""
        texts = [self._build_searchable_text(doc) for doc in self.documents]
        self.searchable_texts = [t.lower() for t in texts]
        self.doc_lengths = [len(t.split()) for t in texts]
        self.avg_doc_length = (
            sum(self.doc_lengths) / len(self.doc_lengths) if self.doc_lengths else 0
        )
    
    def _build_searchable_text(self, doc: Dict) -> str:
        """Concatenate source name, section, enrichment fields, and content.

        Enrichment fields (summary/key_points/themes) are appended so the
        keyword/BM25 side of hybrid search can match them without changing
        the stored embeddings.
        """
        pieces = []

        # Source filename, stripped of extension and separator characters,
        # e.g. 'My_File-v2.pdf' -> 'My File v2'.
        source = doc.get('source', '')
        if source:
            stripped = re.sub(r'\.(pdf|txt)$', '', source, flags=re.IGNORECASE)
            pieces.append(stripped.replace('_', ' ').replace('-', ' '))

        if doc.get('section'):
            pieces.append(doc['section'])

        # Enriched metadata (best-effort: malformed values are skipped).
        if doc.get('summary'):
            pieces.append(str(doc.get('summary')))
        if doc.get('key_points'):
            try:
                pieces.append(' '.join(str(point) for point in doc.get('key_points', [])))
            except Exception:
                pass
        if doc.get('themes'):
            try:
                pieces.append(' '.join(str(theme) for theme in doc.get('themes', [])))
            except Exception:
                pass

        # The chunk body itself comes last.
        pieces.append(doc.get('content', ''))
        return ' '.join(pieces)

    def save_metadata_only(self):
        """Persist documents metadata safely without rewriting the FAISS index/embeddings."""
        try:
            # Write to a sibling .tmp file first, then swap it into place so a
            # crash mid-write never leaves a truncated metadata file.
            tmp_path = self.metadata_file.with_suffix(self.metadata_file.suffix + '.tmp')
            tmp_path.write_text(
                json.dumps(self.documents, ensure_ascii=False, indent=2),
                encoding='utf-8',
            )
            os.replace(tmp_path, self.metadata_file)  # atomic on POSIX
        except Exception as e:
            print(f"[RAG] Failed to save metadata (checkpoint): {e}")
    
    def add_documents(self, chunks: List[Dict]):
        """Embed *chunks* and append them to the store (index + caches)."""
        if not chunks:
            return

        vectors = self._get_embeddings_batch([c['content'] for c in chunks])

        if HAS_FAISS:
            # Lazily create the index, then add L2-normalized float32 rows
            # (IndexFlatIP on unit vectors = cosine similarity).
            if self.index is None:
                self.index = faiss.IndexFlatIP(self.embedding_dim)
            dense = np.array(vectors).astype('float32')
            faiss.normalize_L2(dense)
            self.index.add(dense)
        elif isinstance(self.embeddings, np.ndarray) and len(self.embeddings) > 0:
            self.embeddings = np.vstack([self.embeddings, vectors])
        else:
            self.embeddings = vectors

        self.documents.extend(chunks)

        # Keep the keyword/BM25 caches aligned with the new documents.
        for chunk in chunks:
            searchable = self._build_searchable_text(chunk)
            self.searchable_texts.append(searchable.lower())
            self.doc_lengths.append(len(searchable.split()))

        if self.doc_lengths:
            self.avg_doc_length = sum(self.doc_lengths) / len(self.doc_lengths)

        print(f"[RAG] Added {len(chunks)} documents. Total: {len(self.documents)}")
    
    def replace_source(self, source_name: str, new_chunks: List[Dict]):
        """Replace all documents for a source.

        Drops every stored chunk whose 'source' equals *source_name*, rebuilds
        the vector index from the surviving chunks, then appends *new_chunks*.
        Note: the FAISS path re-embeds every kept document, because
        IndexFlatIP rows cannot be deleted in place.
        """
        # Remove old documents
        keep_docs = []
        keep_embeddings = []
        
        for i, doc in enumerate(self.documents):
            if doc.get('source') != source_name:
                keep_docs.append(doc)
                # Non-FAISS mode: carry the existing embedding row along so
                # kept documents do not need re-embedding.
                if not HAS_FAISS and len(self.embeddings) > i:
                    keep_embeddings.append(self.embeddings[i])
        
        self.documents = keep_docs
        
        # Rebuild index
        if HAS_FAISS:
            self.index = faiss.IndexFlatIP(self.embedding_dim)
            if keep_docs:
                contents = [d['content'] for d in keep_docs]
                embeddings = self._get_embeddings_batch(contents).astype('float32')
                faiss.normalize_L2(embeddings)
                self.index.add(embeddings)
        else:
            self.embeddings = np.array(keep_embeddings) if keep_embeddings else np.array([])
        
        self._prepare_searchable_texts()
        
        # Add new documents
        self.add_documents(new_chunks)
    
    def search(self, query: str, k: int = 10) -> List[Dict]:
        """Enhanced hybrid vector + keyword search with legal intelligence.

        Pipeline:
          1. Build weighted keyword patterns from the raw query words, known
             legal terms found in the query, and concept-expanded terms
             (expanded terms get a 0.7x weight discount).
          2. Retrieve an oversampled (4*k) candidate set by vector similarity
             (FAISS inner product on L2-normalized embeddings, or a NumPy
             cosine fallback).
          3. Score each candidate as a weighted mix of vector similarity,
             keyword hits, and BM25, per the config hybrid_*_weight values.
          4. Multiply in a series of legal-intelligence boost factors
             (litigation phase, court level, doctrine, citation authority,
             citation treatment, precedential weight, temporal relevance,
             argument structure, cross-referenced concepts).
          5. Sort by final score and cap results per source
             (config.max_chunks_per_source) to diversify sources.

        Args:
            query: Free-text user query.
            k: Maximum number of results to return.

        Returns:
            Up to `k` result dicts carrying content, metadata, the individual
            score components, matched terms, and every boost factor applied.
        """
        if not self.documents:
            return []

        query_lower = query.lower()
        query_words = re.findall(r'\b\w+\b', query_lower)

        # Detect litigation phase for phase-aware boosting
        query_phase = detect_litigation_phase(query)

        # Get relevant legal doctrines
        relevant_doctrines = get_relevant_doctrines(query)

        # Expand query with related legal concepts
        expanded_terms = expand_query_with_concepts(query)

        # Build keyword patterns
        keyword_patterns = []
        added_terms = set()

        # Add original query words
        for word in query_words:
            if len(word) > 2 and word not in added_terms:
                term_tier = get_term_tier(word)
                keyword_patterns.append({
                    'term': word,
                    'pattern': re.compile(r'\b' + re.escape(word) + r'\b', re.IGNORECASE),
                    # NOTE(review): `term_tier >= 0` looks always-true unless
                    # get_term_tier returns a negative tier for non-legal words — confirm.
                    'is_legal': term_tier >= 0,
                    'tier': term_tier,
                    'weight': get_term_weight(word)
                })
                added_terms.add(word)

        # Add legal terms from query (including multi-word terms)
        for term in LEGAL_TERMS:
            if term in query_lower and term not in added_terms:
                term_tier = get_term_tier(term)
                keyword_patterns.append({
                    'term': term,
                    'pattern': re.compile(r'\b' + re.escape(term) + r'\b', re.IGNORECASE),
                    'is_legal': True,
                    'tier': term_tier,
                    'weight': get_term_weight(term)
                })
                added_terms.add(term)

        # Add expanded concept terms (with slightly reduced weight)
        for term in expanded_terms[:15]:  # Limit expansion to top 15 terms
            if term.lower() not in added_terms:
                term_tier = get_term_tier(term)
                keyword_patterns.append({
                    'term': term,
                    'pattern': re.compile(r'\b' + re.escape(term) + r'\b', re.IGNORECASE),
                    'is_legal': True,
                    'tier': term_tier,
                    'weight': get_term_weight(term) * 0.7,  # Reduced weight for expanded terms
                    'expanded': True
                })
                added_terms.add(term.lower())
        
        # Compute IDF
        idf_table = self._compute_idf(keyword_patterns)
        
        # Vector search
        query_embedding = self._get_embedding(query)
        
        if HAS_FAISS and self.index is not None:
            query_embedding = np.array([query_embedding]).astype('float32')
            faiss.normalize_L2(query_embedding)
            # Oversample 4x so per-source capping below can still fill k results
            oversample = min(k * 4, len(self.documents))
            scores, indices = self.index.search(query_embedding, oversample)
            vector_results = list(zip(scores[0], indices[0]))
        else:
            # Cosine similarity fallback
            if isinstance(self.embeddings, np.ndarray) and len(self.embeddings) > 0:
                similarities = np.dot(self.embeddings, query_embedding)
                oversample = min(k * 4, len(self.documents))
                top_indices = np.argsort(similarities)[-oversample:][::-1]
                vector_results = [(similarities[i], i) for i in top_indices]
            else:
                vector_results = []
        
        # Score candidates
        candidates = []
        for vector_score, idx in vector_results:
            if idx >= len(self.documents):
                continue
            
            doc = self.documents[idx]
            searchable_text = self.searchable_texts[idx] if idx < len(self.searchable_texts) else ''
            doc_content = doc.get('content', searchable_text)

            # Keyword score with tiered weights
            keyword_score = 0.0
            matched_terms = []
            expanded_matches = []
            for p in keyword_patterns:
                if p['pattern'].search(searchable_text):
                    if p.get('expanded'):
                        expanded_matches.append(p['term'])
                    else:
                        matched_terms.append(p['term'])
                    keyword_score += p.get('weight', 0.3)

            keyword_score = min(1.0, keyword_score)

            # BM25 score
            bm25_score = self._compute_bm25(idx, keyword_patterns, idf_table)

            # Base combined score
            combined = (
                self.config.hybrid_vector_weight * float(vector_score) +
                self.config.hybrid_keyword_weight * keyword_score +
                self.config.hybrid_bm25_weight * bm25_score
            )

            # ============================================================
            # Advanced Legal Intelligence Boosting
            # (all factors multiply into the base score below)
            # ============================================================

            # 1. Phase-aware boosting: boost documents from same litigation phase
            doc_phase = detect_litigation_phase(doc_content)
            phase_boost = 1.0
            if query_phase != 'general' and doc_phase == query_phase:
                phase_boost = 1.15  # 15% boost for same phase

            # 2. Jurisdiction/court level boosting
            court_level = detect_court_level(doc_content)
            jurisdiction_boost = JURISDICTION_HIERARCHY.get(court_level, 1.0)

            # 3. Doctrine-specific boosting
            doctrine_boost = 1.0
            for doctrine in relevant_doctrines:
                doctrine_terms = LEGAL_DOCTRINES.get(doctrine, [])
                if any(term.lower() in searchable_text.lower() for term in doctrine_terms):
                    doctrine_boost *= 1.12  # 12% boost per matched doctrine

            # 4. Citation authority boosting
            citation_counts = count_authoritative_citations(doc_content)
            citation_boost = 1.0
            if citation_counts['supreme'] > 0:
                citation_boost += 0.08 * min(citation_counts['supreme'], 3)
            if citation_counts['circuit'] > 0:
                citation_boost += 0.05 * min(citation_counts['circuit'], 3)

            # 5. Citation treatment analysis (positive treatment preferred)
            treatment = analyze_citation_treatment(doc_content)
            treatment_boost = 1.0
            if treatment['positive'] > treatment['negative']:
                treatment_boost = 1.05
            elif treatment['negative'] > treatment['positive'] * 2:
                treatment_boost = 0.90  # Penalize heavily negative treatment

            # 6. Precedential authority weight (gold standard > persuasive > declining)
            precedential_boost = get_precedential_weight(doc_content, doc)

            # 7. Temporal relevance (recent for statutes, stability for case law)
            temporal_boost = calculate_temporal_weight(doc)

            # 8. Argument structure quality (well-structured arguments score higher)
            arg_structures = detect_argument_structure(doc_content)
            argument_boost = 1.0
            if arg_structures:
                best_arg = arg_structures[0]
                if best_arg['confidence'] > 0.3:
                    argument_boost = best_arg['weight']

            # 9. Cross-reference relevance (boost if related concepts present)
            related_concepts = get_related_concepts(query)
            cross_ref_boost = 1.0
            for concept in related_concepts[:5]:
                if concept.lower() in searchable_text.lower():
                    cross_ref_boost += 0.05  # 5% boost per related concept found
            cross_ref_boost = min(1.25, cross_ref_boost)  # Cap at 25% boost

            # Apply all boosts
            final_score = (combined * phase_boost * jurisdiction_boost * doctrine_boost *
                          citation_boost * treatment_boost * precedential_boost *
                          temporal_boost * argument_boost * cross_ref_boost)

            candidates.append({
                'content': doc_content,
                'metadata': {k: v for k, v in doc.items() if k != 'content'},
                'score': final_score,
                'base_score': combined,
                'vector_score': float(vector_score),
                'keyword_score': keyword_score,
                'bm25_score': bm25_score,
                'matched_terms': matched_terms,
                'expanded_matches': expanded_matches,
                'boosts': {
                    'phase': round(phase_boost, 2),
                    'jurisdiction': round(jurisdiction_boost, 2),
                    'doctrine': round(doctrine_boost, 2),
                    'citation': round(citation_boost, 2),
                    'treatment': round(treatment_boost, 2),
                    'precedential': round(precedential_boost, 2),
                    'temporal': round(temporal_boost, 2),
                    'argument': round(argument_boost, 2),
                    'cross_ref': round(cross_ref_boost, 2)
                },
                'phase': doc_phase,
                'court_level': court_level,
                'argument_structure': arg_structures[0] if arg_structures else None,
                'primary_doctrine': extract_primary_doctrine(doc_content)
            })
        
        # Sort and limit by source
        candidates.sort(key=lambda x: x['score'], reverse=True)
        
        results = []
        source_counts = {}
        for c in candidates:
            source = c['metadata'].get('source', 'unknown')
            if source_counts.get(source, 0) < self.config.max_chunks_per_source:
                results.append(c)
                source_counts[source] = source_counts.get(source, 0) + 1
                if len(results) >= k:
                    break
        
        return results
    
    def _compute_idf(self, patterns: List[Dict]) -> Dict[str, float]:
        """Compute IDF for terms"""
        idf_table = {}
        total_docs = len(self.documents)
        if total_docs == 0:
            return idf_table
        
        for p in patterns:
            term = p['term']
            if term in idf_table:
                continue
            
            df = sum(1 for text in self.searchable_texts if p['pattern'].search(text))
            idf = math.log((total_docs - df + 0.5) / (df + 0.5) + 1.0)
            idf_table[term] = max(idf, 0.0)
        
        return idf_table
    
    def _compute_bm25(self, idx: int, patterns: List[Dict], idf_table: Dict) -> float:
        """BM25 relevance of document `idx` against the query term patterns.

        Uses module-level BM25_K1 / BM25_B constants and the pre-computed
        per-document lengths plus corpus average length for normalization.
        """
        if not patterns or idx >= len(self.searchable_texts):
            return 0.0

        text = self.searchable_texts[idx]
        length = self.doc_lengths[idx] if idx < len(self.doc_lengths) else len(text.split())
        # Guard against a zero average (empty corpus) to avoid division by zero
        avg_length = self.avg_doc_length or length or 1

        score = 0.0
        for pat in patterns:
            tf = len(pat['pattern'].findall(text))
            if tf:
                idf = idf_table.get(pat['term'], 0.0)
                norm = tf + BM25_K1 * (1.0 - BM25_B + BM25_B * (length / avg_length))
                score += idf * (tf * (BM25_K1 + 1.0) / norm)

        return score


# ============================================================================
# OCR Fallback Service
# ============================================================================
class OCRService:
    """OCR fallback for scanned PDFs using DeepSeek OCR (port 5003) and Tesseract (port 5002).

    DeepSeek (GPU accelerated) is tried first; Tesseract (CPU) is the fallback.
    Service URLs come from environment variables, defaulting to the RAGConfig values.
    """
    
    def __init__(self, config: RAGConfig):
        self.config = config
        # FIX: previously the RAGConfig URL fields were ignored in favour of
        # re-hardcoded defaults; the env vars now fall back to the config values
        # (which carry the same defaults, so behavior is unchanged out of the box).
        self.deepseek_ocr_url = os.environ.get('DEEPSEEK_OCR_URL', config.deepseek_ocr_url)
        self.tesseract_ocr_url = os.environ.get('TESSERACT_OCR_URL', config.tesseract_ocr_url)
    
    def extract_text(self, file_path: Path) -> Optional[str]:
        """Return OCR'd text for `file_path`, or None if both services fail."""
        if not HAS_REQUESTS:
            return None
        
        # Try DeepSeek OCR first (GPU accelerated, faster)
        text = self._try_ocr(self.deepseek_ocr_url, file_path, "DeepSeek")
        if text:
            return text
        
        # Fall back to Tesseract (CPU)
        return self._try_ocr(self.tesseract_ocr_url, file_path, "Tesseract")
    
    def _try_ocr(self, url: str, file_path: Path, service_name: str) -> Optional[str]:
        """POST the file to one OCR service; return extracted text or None on any failure."""
        try:
            print(f"[RAG] Trying {service_name} OCR at {url}...")
            with open(file_path, 'rb') as f:
                files = {'file': (file_path.name, f, 'application/pdf')}
                # Long timeout: OCR of large scanned PDFs can take minutes
                response = requests.post(url, files=files, timeout=600)
            
            if response.status_code == 200:
                data = response.json()
                # Some OCR services return text without a 'success' flag, so any
                # non-empty 'text' is accepted (the old success-flag branch and the
                # fallback branch were byte-identical and have been collapsed).
                if data.get('text'):
                    print(f"[RAG] {service_name} OCR extracted {len(data['text'])} chars")
                    return data['text']
        except Exception as e:
            print(f"[RAG] {service_name} OCR failed: {e}")
        
        return None


# ============================================================================
# Enrichment Service
# ============================================================================
class EnrichmentService:
    """Service for LLM-based chunk enrichment"""
    
    def __init__(self, config: RAGConfig):
        self.config = config
        self.llm = LLMService()
        self.enrichment_budget = int(os.environ.get('ENRICHMENT_BUDGET', 999999))
        self.enrichment_used = 0
    
    def enrich_chunk(self, chunk: Dict) -> Dict:
        """Enrich a single chunk with LLM-generated metadata"""
        content = chunk.get('content', '')
        if not content or len(content) < 50:
            return chunk
        
        # Generate basic card first (no LLM)
        card = self._generate_basic_card(content, chunk)
        
        # Try LLM enhancement if within budget
        if self.enrichment_used < self.enrichment_budget:
            enhanced = self._enhance_with_llm(content, card)
            if enhanced:
                card.update(enhanced)
                self.enrichment_used += 1
        
        # Update chunk with enrichment
        enriched_chunk = chunk.copy()
        enriched_chunk['summary'] = card.get('summary', '')
        enriched_chunk['key_points'] = card.get('key_points', [])
        enriched_chunk['themes'] = card.get('themes', [])
        enriched_chunk['enriched'] = True

        # IMPORTANT:
        # Do NOT overwrite `content` during enrichment.
        # The vector embeddings/FAISS index are built from `content`; changing it would require
        # recomputing embeddings and rebuilding the index (expensive + fragile).
        # Instead, the VectorStore builds searchable text using summary/key_points/themes.
        return enriched_chunk
    
    def _generate_basic_card(self, content: str, chunk: Dict) -> Dict:
        """Generate basic knowledge card without LLM"""
        # Split into sentences
        sentences = re.split(r'(?<=[.!?])\s+', content)
        sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
        
        # Basic summary (first 2 sentences)
        summary = ' '.join(sentences[:2])[:400] if sentences else content[:400]
        
        # Key points (first 4 sentences)
        key_points = [s[:220] for s in sentences[:4]]
        
        # Extract themes via frequency analysis
        themes = self._extract_themes(content)
        
        return {
            'summary': summary,
            'key_points': key_points,
            'themes': themes,
            'source': chunk.get('source', 'unknown'),
            'chunk_id': chunk.get('chunk_id', 0)
        }
    
    def _extract_themes(self, text: str, max_themes: int = 5) -> List[str]:
        """Extract themes using frequency analysis"""
        words = re.findall(r'\b[a-zA-Z]{4,}\b', text.lower())
        if not words:
            return []
        
        # Stop words to exclude
        stop_words = {
            'that', 'this', 'with', 'have', 'from', 'they', 'were', 'their',
            'there', 'which', 'about', 'into', 'through', 'also', 'very',
            'than', 'them', 'some', 'only', 'even', 'most', 'more', 'like',
            'been', 'being', 'would', 'could', 'should', 'will', 'just',
            'because', 'when', 'where', 'what', 'your', 'other', 'each'
        }
        
        # Count word frequency
        word_counts = {}
        for word in words:
            if word not in stop_words and len(word) >= 4:
                word_counts[word] = word_counts.get(word, 0) + 1
        
        # Get top themes
        sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
        themes = [word for word, count in sorted_words[:max_themes] if count >= 2]

        # Add legal-specific themes, prioritizing higher tiers
        text_lower = text.lower()
        legal_themes_found = []
        for tier_terms, tier_num in [
            (LEGAL_TIER_1, 1), (LEGAL_TIER_6, 6), (LEGAL_TIER_2, 2),
            (LEGAL_TIER_5, 5), (LEGAL_TIER_3, 3), (LEGAL_TIER_4, 4)
        ]:
            for term in tier_terms:
                if term in text_lower and term not in themes and term not in legal_themes_found:
                    legal_themes_found.append(term)
                    if len(legal_themes_found) >= max_themes:
                        break
            if len(legal_themes_found) >= max_themes:
                break

        # Add legal themes to main themes list
        for term in legal_themes_found:
            if term not in themes:
                themes.append(term)
                if len(themes) >= max_themes * 2:  # Allow more themes for legal docs
                    break
        
        return themes[:max_themes]
    
    def _enhance_with_llm(self, content: str, base_card: Dict) -> Optional[Dict]:
        """Enhance card with LLM"""
        prompt = f"""You are refining a knowledge card for a sports betting document. Extract the most important information.

EXCERPT:
{content[:800]}

CURRENT SUMMARY:
{base_card['summary']}

Provide an improved knowledge card in this exact format:
Summary: <1-2 sentences capturing the main point about sports betting>
Key Points:
- <concise bullet 1>
- <concise bullet 2>
- <concise bullet 3>
Themes: theme1, theme2, theme3

Begin the card now:"""
        
        try:
            response = self.llm.generate(prompt, max_tokens=200, temperature=0.3)
            if response:
                return self._parse_llm_response(response)
        except Exception as e:
            print(f"[RAG] LLM enhancement error: {e}")
        
        return None
    
    def _parse_llm_response(self, text: str) -> Optional[Dict]:
        """Parse LLM output into structured card"""
        if not text:
            return None
        
        result = {}
        
        # Extract summary
        summary_match = re.search(r'Summary:\s*(.+?)(?=Key Points:|Themes:|$)', text, re.DOTALL)
        if summary_match:
            result['summary'] = summary_match.group(1).strip()[:400]
        
        # Extract key points
        key_points = []
        kp_match = re.search(r'Key Points:\s*(.*?)(?=Themes:|$)', text, re.DOTALL)
        if kp_match:
            lines = kp_match.group(1).strip().split('\n')
            for line in lines:
                line = line.strip()
                if line.startswith('-'):
                    line = line[1:].strip()
                if line and len(line) > 5:
                    key_points.append(line[:220])
        if key_points:
            result['key_points'] = key_points[:5]
        
        # Extract themes
        themes_match = re.search(r'Themes:\s*(.+)', text)
        if themes_match:
            themes = [t.strip() for t in themes_match.group(1).split(',')]
            result['themes'] = [t for t in themes if t][:5]
        
        return result if result else None
    
    def _compose_enriched_content(self, original: str, card: Dict) -> str:
        """Compose enriched content for storage"""
        lines = [
            f"[Source: {card.get('source', 'Unknown')} | Chunk {card.get('chunk_id', '?')}]",
            f"Summary: {card.get('summary', '')}"
        ]
        
        if card.get('key_points'):
            lines.append("Key Points:")
            for point in card['key_points']:
                lines.append(f"- {point}")
        
        if card.get('themes'):
            lines.append(f"Themes: {', '.join(card['themes'])}")
        
        lines.append(f"\nOriginal Content:\n{original}")
        
        return '\n'.join(lines)
    
    def enrich_all_chunks(self, chunks: List[Dict], progress_callback=None) -> List[Dict]:
        """Enrich all chunks with progress tracking"""
        enriched = []
        total = len(chunks)
        
        for i, chunk in enumerate(chunks):
            enriched_chunk = self.enrich_chunk(chunk)
            enriched.append(enriched_chunk)
            
            if progress_callback:
                progress_callback(i + 1, total)
            
            if (i + 1) % 10 == 0:
                print(f"[RAG] Enriched {i + 1}/{total} chunks (LLM used: {self.enrichment_used})")
        
        return enriched


# ============================================================================
# Flask Application
# ============================================================================
# Flask app; HTML templates (e.g. landing.html) live in ./templates
app = Flask(__name__, template_folder='templates')

# Register Case Management Blueprint (optional module)
try:
    from case_api import case_bp
    app.register_blueprint(case_bp)
    print("[RAG] Case management API loaded at /cases/*")
except ImportError as e:
    print(f"[RAG] Case management not available: {e}")

# Global instances
# Module-level singletons shared by all request handlers and worker threads.
config = RAGConfig()
pdf_processor = PDFProcessor(config)
vector_store = VectorStore(config)
ocr_service = OCRService(config)
enrichment_service = EnrichmentService(config)

# Progress tracking
# upload_progress maps job_id -> status dict; guard all access with progress_lock.
upload_progress = {}
# Single global enrichment job state; persisted to disk so interrupted runs can resume.
enrichment_status = {
    'status': 'idle',
    'progress': 0,
    'message': '',
    'llm_used': 0,
    'updated_at': None,
    'started_at': None,
    'source': None,
    'total_target': 0,
    'done_target': 0,
    'checkpoint_every': int(os.environ.get('ENRICHMENT_CHECKPOINT_EVERY', 10)),
}
progress_lock = threading.Lock()
enrichment_lock = threading.Lock()
processing_queue = queue.Queue()

# Enrichment runtime + persistence
enrichment_thread = None
ENRICHMENT_STATUS_FILE = config.cache_dir / 'enrichment_status.json'


def _save_enrichment_status_to_disk():
    """Write enrichment_status to disk atomically (tmp file + rename) so it survives restarts."""
    try:
        # Write to a sibling .tmp file first, then atomically swap it into place
        scratch = ENRICHMENT_STATUS_FILE.with_suffix(ENRICHMENT_STATUS_FILE.suffix + '.tmp')
        payload = json.dumps(enrichment_status, ensure_ascii=False, indent=2)
        scratch.write_text(payload, encoding='utf-8')
        os.replace(scratch, ENRICHMENT_STATUS_FILE)
    except Exception as e:
        print(f"[RAG] Failed to persist enrichment status: {e}")


def _load_enrichment_status_from_disk():
    """Restore enrichment status on startup (and mark stale running jobs as interrupted).

    Loads the JSON snapshot written by _save_enrichment_status_to_disk(), fills
    missing keys with defaults, and converts a leftover 'running' state (process
    died mid-job) into 'interrupted' so the operator can safely resume.
    """
    global enrichment_status
    try:
        if ENRICHMENT_STATUS_FILE.exists():
            with open(ENRICHMENT_STATUS_FILE, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, dict):
                enrichment_status.update(data)
    except Exception as e:
        print(f"[RAG] Failed to load enrichment status: {e}")

    # Normalize defaults so downstream code can assume every key exists
    defaults = {
        'status': 'idle',
        'progress': 0,
        'message': '',
        'llm_used': 0,
        'updated_at': None,
        'started_at': None,
        'source': None,
        'total_target': 0,
        'done_target': 0,
        'checkpoint_every': int(os.environ.get('ENRICHMENT_CHECKPOINT_EVERY', 10)),
    }
    for key, value in defaults.items():
        enrichment_status.setdefault(key, value)

    # If the process restarted while enrichment was running, mark as interrupted (resume is safe).
    if enrichment_status.get('status') == 'running':
        enrichment_status['status'] = 'interrupted'
        enrichment_status['message'] = (
            enrichment_status.get('message')
            or 'Previous enrichment was interrupted. You can resume.'
        )
        # BUG FIX: isoformat() on an aware datetime already ends in '+00:00';
        # the old code appended 'Z' on top, yielding an invalid '...+00:00Z'.
        enrichment_status['updated_at'] = datetime.now(timezone.utc).isoformat()
        _save_enrichment_status_to_disk()

# API key authentication
# NOTE(review): the hard-coded fallback key is a weak default — set FLASK_API_KEY in production.
FLASK_API_KEY = os.environ.get('FLASK_API_KEY', 'eventheodds-flask-api-key-2025')

def require_api_key(f):
    """Decorator: reject requests whose X-API-Key header does not match FLASK_API_KEY."""
    @wraps(f)
    def wrapper(*args, **kwargs):
        supplied = request.headers.get('X-API-Key', '')
        if supplied != FLASK_API_KEY:
            return jsonify({'error': 'Invalid API key'}), 401
        return f(*args, **kwargs)
    return wrapper


# Initialize on startup
# Load the persisted vector index and any saved enrichment state before serving requests.
print("[RAG] Initializing Advanced Legal Document RAG Service...")
vector_store.load()
_load_enrichment_status_from_disk()


def update_progress(job_id: str, status: str, progress: int, message: str, **extra):
    """Record the latest status of an upload job (thread-safe via progress_lock).

    Any keyword in `extra` is merged into the entry and may override the
    standard fields, matching the original dict-literal semantics.
    """
    entry = {
        'status': status,
        'progress': progress,
        'message': message,
        'updated': time.time(),
    }
    entry.update(extra)
    with progress_lock:
        upload_progress[job_id] = entry


def process_file_background(job_id: str, file_path: Path, filename: str):
    """Process an uploaded file end-to-end in a worker thread.

    Pipeline: extract text (with OCR fallback) -> detect document type ->
    chunk -> attach metadata -> replace the source in the vector store ->
    persist the index. Progress is reported through update_progress(job_id, ...);
    any exception marks the job 'failed' instead of propagating.
    """
    try:
        update_progress(job_id, 'processing', 10, 'Extracting text from document...')
        
        # Extract text
        text = pdf_processor.extract_text(file_path)
        
        # Try OCR if extraction failed
        # (< 100 chars is treated as effectively empty — e.g. a scanned/image-only PDF)
        if not text or len(text.strip()) < 100:
            update_progress(job_id, 'processing', 20, 'Text extraction failed, trying OCR...')
            text = ocr_service.extract_text(file_path)
        
        if not text or len(text.strip()) < 100:
            update_progress(job_id, 'failed', 0, 'Failed to extract text from document')
            return
        
        update_progress(job_id, 'processing', 40, f'Extracted {len(text)} characters. Chunking...')
        
        # Detect document type
        doc_type = pdf_processor.detect_document_type(text, filename)
        
        # Chunk text
        chunks = pdf_processor.chunk_text(text, doc_type)
        
        if not chunks:
            update_progress(job_id, 'failed', 0, 'No valid chunks extracted from document')
            return
        
        # Add metadata (source filename doubles as the replace_source key below)
        for i, chunk in enumerate(chunks):
            chunk['source'] = filename
            chunk['chunk_id'] = i
            chunk['total_chunks'] = len(chunks)
            chunk['document_type'] = doc_type
        
        update_progress(job_id, 'processing', 60, f'Created {len(chunks)} chunks. Generating embeddings...')
        
        # Replace in vector store (idempotent per filename: re-uploads overwrite)
        vector_store.replace_source(filename, chunks)
        
        update_progress(job_id, 'processing', 90, 'Saving index...')
        vector_store.save()
        
        update_progress(
            job_id, 'completed', 100,
            f'Successfully processed! Added {len(chunks)} chunks.',
            chunks_added=len(chunks),
            document_type=doc_type,
            total_documents=len(vector_store.documents)
        )
        
    except Exception as e:
        print(f"[RAG] Error processing file: {e}")
        import traceback
        traceback.print_exc()
        update_progress(job_id, 'failed', 0, f'Error: {str(e)}')


# ============================================================================
# API Endpoints
# ============================================================================

@app.route('/')
def landing():
    """Serve the public landing page (templates/landing.html)."""
    return render_template('landing.html')

@app.route('/admin')
def admin():
    """Serve the self-contained admin UI (inline HTML/CSS/JS).

    The page provides four tabs: system status, document upload (with a
    background-job progress poll), a paginated chunk viewer, and LLM
    enrichment controls.  All data is fetched client-side from this
    service's own JSON endpoints: /documents, /status, /chunks, /upload,
    /enrich, /enrich-upload, /enrichment-status and /progress/<job_id>.
    """
    # NOTE: the whole UI is returned as one string literal so the admin page
    # ships with the service and needs no template or static-file setup.
    # The literal must not be edited casually — it is the served page.
    return '''<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Legal Document RAG - Document Upload</title>
    <style>
        * { box-sizing: border-box; margin: 0; padding: 0; }
        body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%); min-height: 100vh; padding: 20px; color: #e4e4e7; }
        .container { max-width: 1000px; margin: 0 auto; }
        h1 { color: #fbbf24; margin-bottom: 10px; font-size: 2rem; }
        .subtitle { color: #9ca3af; margin-bottom: 30px; }
        .card { background: rgba(255,255,255,0.05); border-radius: 12px; padding: 24px; margin-bottom: 20px; border: 1px solid rgba(255,255,255,0.1); }
        .card h2 { color: #fbbf24; font-size: 1.25rem; margin-bottom: 16px; }
        .stats { display: grid; grid-template-columns: repeat(auto-fit, minmax(120px, 1fr)); gap: 16px; }
        .stat { background: rgba(251,191,36,0.1); padding: 16px; border-radius: 8px; text-align: center; }
        .stat-value { font-size: 1.75rem; font-weight: bold; color: #fbbf24; }
        .stat-label { color: #9ca3af; font-size: 0.75rem; }
        form { display: flex; flex-direction: column; gap: 16px; }
        input[type="text"], input[type="file"], select { padding: 12px; border-radius: 8px; border: 1px solid rgba(255,255,255,0.2); background: rgba(0,0,0,0.2); color: white; font-size: 1rem; }
        input[type="file"] { cursor: pointer; }
        button { background: #fbbf24; color: #1a1a2e; padding: 12px 20px; border: none; border-radius: 8px; font-weight: 600; font-size: 0.9rem; cursor: pointer; transition: all 0.2s; }
        button:hover { background: #f59e0b; transform: translateY(-1px); }
        button:disabled { opacity: 0.5; cursor: not-allowed; transform: none; }
        button.secondary { background: #6366f1; color: white; }
        button.secondary:hover { background: #4f46e5; }
        button.danger { background: #ef4444; color: white; }
        .btn-group { display: flex; gap: 10px; flex-wrap: wrap; }
        .progress-container { margin-top: 20px; }
        .progress-bar { height: 8px; background: rgba(255,255,255,0.1); border-radius: 4px; overflow: hidden; }
        .progress-fill { height: 100%; background: linear-gradient(90deg, #fbbf24, #f59e0b); transition: width 0.3s; }
        .progress-text { margin-top: 8px; color: #9ca3af; font-size: 0.875rem; }
        .result { padding: 16px; border-radius: 8px; margin-top: 16px; }
        .result.success { background: rgba(34,197,94,0.2); border: 1px solid rgba(34,197,94,0.3); }
        .result.error { background: rgba(239,68,68,0.2); border: 1px solid rgba(239,68,68,0.3); }
        .result.info { background: rgba(59,130,246,0.2); border: 1px solid rgba(59,130,246,0.3); }
        .tabs { display: flex; gap: 10px; margin-bottom: 20px; border-bottom: 1px solid rgba(255,255,255,0.1); padding-bottom: 10px; }
        .tab { padding: 8px 16px; cursor: pointer; border-radius: 6px; color: #9ca3af; }
        .tab.active { background: rgba(251,191,36,0.2); color: #fbbf24; }
        .tab:hover { background: rgba(255,255,255,0.05); }
        .chunk-list { max-height: 500px; overflow-y: auto; }
        .chunk-item { background: rgba(0,0,0,0.2); padding: 16px; border-radius: 8px; margin-bottom: 12px; border-left: 4px solid #6366f1; }
        .chunk-item.enriched { border-left-color: #22c55e; }
        .chunk-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 8px; }
        .chunk-source { color: #fbbf24; font-weight: 600; font-size: 0.875rem; }
        .chunk-badge { padding: 2px 8px; border-radius: 4px; font-size: 0.75rem; }
        .chunk-badge.enriched { background: rgba(34,197,94,0.3); color: #22c55e; }
        .chunk-badge.raw { background: rgba(156,163,175,0.3); color: #9ca3af; }
        .chunk-summary { color: #e4e4e7; margin-bottom: 8px; font-size: 0.9rem; }
        .chunk-themes { display: flex; gap: 6px; flex-wrap: wrap; margin-bottom: 8px; }
        .theme-tag { background: rgba(99,102,241,0.3); color: #a5b4fc; padding: 2px 8px; border-radius: 4px; font-size: 0.75rem; }
        .chunk-content { color: #9ca3af; font-size: 0.8rem; background: rgba(0,0,0,0.2); padding: 10px; border-radius: 4px; white-space: pre-wrap; max-height: 150px; overflow-y: auto; }
        .pagination { display: flex; justify-content: center; gap: 10px; margin-top: 16px; }
        .hidden { display: none !important; }
    </style>
</head>
<body>
    <div class="container">
        <h1>🏆 Legal Document RAG</h1>
        <p class="subtitle">Advanced Document Processing with FAISS + BM25 + DeepSeek LLM Enrichment</p>
        
        <div class="tabs">
            <div class="tab active" onclick="showTab('status')">📊 Status</div>
            <div class="tab" onclick="showTab('upload')">📤 Upload</div>
            <div class="tab" onclick="showTab('chunks')">📄 View Chunks</div>
            <div class="tab" onclick="showTab('enrich')">✨ Enrichment</div>
        </div>
        
        <!-- Status Tab -->
        <div id="tab-status" class="card">
            <h2>📊 System Status</h2>
            <div class="stats" id="stats">
                <div class="stat"><div class="stat-value" id="docCount">-</div><div class="stat-label">Total Chunks</div></div>
                <div class="stat"><div class="stat-value" id="enrichedCount">-</div><div class="stat-label">Enriched</div></div>
                <div class="stat"><div class="stat-value" id="rawCount">-</div><div class="stat-label">Raw</div></div>
                <div class="stat"><div class="stat-value" id="fileCount">-</div><div class="stat-label">Files</div></div>
                <div class="stat"><div class="stat-value" id="searchType">-</div><div class="stat-label">Search</div></div>
            </div>
        </div>
        
        <!-- Upload Tab -->
        <div id="tab-upload" class="card hidden">
            <h2>📤 Upload Document</h2>
            <form id="uploadForm">
                <input type="text" id="apiKey" placeholder="API Key" value="eventheodds-flask-api-key-2025">
                <input type="file" id="fileInput" accept=".pdf,.txt" required>
                <div class="btn-group">
                    <button type="submit" id="submitBtn">Upload & Process</button>
                    <button type="button" class="secondary" onclick="uploadWithEnrich()">Upload + Enrich</button>
                </div>
            </form>
            <div class="progress-container" id="progressContainer" style="display:none;">
                <div class="progress-bar"><div class="progress-fill" id="progressFill" style="width:0%"></div></div>
                <div class="progress-text" id="progressText">Starting...</div>
            </div>
            <div id="result"></div>
        </div>
        
        <!-- Chunks Tab -->
        <div id="tab-chunks" class="card hidden">
            <h2>📄 View Chunks</h2>
            <div style="display:flex; gap:10px; margin-bottom:16px; flex-wrap:wrap;">
                <select id="chunkFilter" onchange="loadChunks()">
                    <option value="all">All Chunks</option>
                    <option value="enriched">Enriched Only</option>
                    <option value="raw">Raw Only</option>
                </select>
                <select id="sourceFilter" onchange="loadChunks()">
                    <option value="">All Sources</option>
                </select>
                <button class="secondary" onclick="loadChunks()">🔄 Refresh</button>
            </div>
            <div class="chunk-list" id="chunkList">Loading...</div>
            <div class="pagination" id="pagination"></div>
        </div>
        
        <!-- Enrichment Tab -->
        <div id="tab-enrich" class="card hidden">
            <h2>✨ LLM Enrichment</h2>
            <p style="color:#9ca3af; margin-bottom:16px;">Use DeepSeek LLM to generate summaries, key points, and themes for all chunks.</p>
            <div class="btn-group">
                <button onclick="startEnrichment()">🚀 Enrich All Raw Chunks</button>
                <button class="secondary" onclick="resumeEnrichment()">▶️ Resume</button>
                <button class="secondary" onclick="checkEnrichmentStatus()">🔄 Check Status</button>
            </div>
            <div class="progress-container" id="enrichProgress" style="display:none;">
                <div class="progress-bar"><div class="progress-fill" id="enrichProgressFill" style="width:0%"></div></div>
                <div class="progress-text" id="enrichProgressText">Starting...</div>
            </div>
            <div id="enrichResult"></div>
        </div>
    </div>
    
    <script>
        const BASE_URL = window.location.pathname.replace(/\/admin.*$/, '').replace(/\/$/, '');
        let currentPage = 1;
        
        function showTab(tabName) {
            document.querySelectorAll('.card').forEach(c => c.classList.add('hidden'));
            document.querySelectorAll('.tab').forEach(t => t.classList.remove('active'));
            document.getElementById('tab-' + tabName).classList.remove('hidden');
            event.target.classList.add('active');
            
            if (tabName === 'chunks') loadChunks();
            if (tabName === 'status') loadStatus();
            if (tabName === 'enrich') checkEnrichmentStatus();
        }
        
        async function loadStatus() {
            try {
                const resp = await fetch(BASE_URL + '/documents', {cache:'no-store'});
                const data = await resp.json();
                document.getElementById('docCount').textContent = data.total_chunks || 0;
                document.getElementById('enrichedCount').textContent = data.total_enriched || 0;
                document.getElementById('rawCount').textContent = (data.total_chunks || 0) - (data.total_enriched || 0);
                document.getElementById('fileCount').textContent = data.total_sources || 0;
                
                // Update source filter
                const sourceSelect = document.getElementById('sourceFilter');
                sourceSelect.innerHTML = '<option value="">All Sources</option>';
                (data.documents || []).forEach(d => {
                    sourceSelect.innerHTML += `<option value="${d.source}">${d.source} (${d.chunk_count})</option>`;
                });
                
                // Get search type
                const statusResp = await fetch(BASE_URL + '/status', {cache:'no-store'});
                const statusData = await statusResp.json();
                document.getElementById('searchType').textContent = statusData.has_faiss ? 'FAISS' : 'Cosine';
            } catch(e) {
                console.error('Status load failed:', e);
            }
        }
        
        async function loadChunks() {
            const filter = document.getElementById('chunkFilter').value;
            const source = document.getElementById('sourceFilter').value;
            const chunkList = document.getElementById('chunkList');
            chunkList.innerHTML = 'Loading...';
            
            let url = BASE_URL + '/chunks?page=' + currentPage + '&per_page=10';
            if (filter === 'enriched') url += '&enriched=true';
            if (filter === 'raw') url += '&raw=true';
            if (source) url += '&source=' + encodeURIComponent(source);
            
            try {
                const resp = await fetch(url, {cache:'no-store'});
                const data = await resp.json();
                
                if (!data.chunks || data.chunks.length === 0) {
                    chunkList.innerHTML = '<p style="color:#9ca3af;">No chunks found</p>';
                    return;
                }
                
                chunkList.innerHTML = data.chunks.map(c => `
                    <div class="chunk-item ${c.enriched ? 'enriched' : ''}">
                        <div class="chunk-header">
                            <span class="chunk-source">${c.source} #${c.chunk_id}</span>
                            <span class="chunk-badge ${c.enriched ? 'enriched' : 'raw'}">${c.enriched ? '✓ Enriched' : 'Raw'}</span>
                        </div>
                        ${c.summary ? `<div class="chunk-summary"><strong>Summary:</strong> ${c.summary}</div>` : ''}
                        ${c.themes && c.themes.length ? `<div class="chunk-themes">${c.themes.map(t => `<span class="theme-tag">${t}</span>`).join('')}</div>` : ''}
                        <div class="chunk-content">${c.content_preview}</div>
                    </div>
                `).join('');
                
                // Pagination
                const pag = data.pagination;
                document.getElementById('pagination').innerHTML = `
                    <button ${pag.page <= 1 ? 'disabled' : ''} onclick="currentPage=${pag.page-1};loadChunks()">← Prev</button>
                    <span style="color:#9ca3af;">Page ${pag.page} of ${pag.pages} (${pag.total} chunks)</span>
                    <button ${pag.page >= pag.pages ? 'disabled' : ''} onclick="currentPage=${pag.page+1};loadChunks()">Next →</button>
                `;
            } catch(e) {
                chunkList.innerHTML = '<p style="color:#ef4444;">Error loading chunks</p>';
            }
        }
        
        async function startEnrichment() {
            const apiKey = document.getElementById('apiKey').value.trim() || 'eventheodds-flask-api-key-2025';
            document.getElementById('enrichProgress').style.display = 'block';
            document.getElementById('enrichResult').innerHTML = '';
            
            try {
                const resp = await fetch(BASE_URL + '/enrich', {
                    method: 'POST',
                    headers: {'X-API-Key': apiKey, 'Content-Type': 'application/json'}
                });
                const data = await resp.json();
                
                if (data.error) {
                    document.getElementById('enrichResult').innerHTML = `<div class="result error">❌ ${data.error}</div>`;
                } else {
                    pollEnrichment();
                }
            } catch(e) {
                document.getElementById('enrichResult').innerHTML = `<div class="result error">❌ Error: ${e.message}</div>`;
            }
        }

        async function resumeEnrichment() {
            // Resume is the same API call as start; the server will skip already-enriched chunks.
            await startEnrichment();
        }
        
        async function pollEnrichment() {
            const poll = async () => {
                const resp = await fetch(BASE_URL + '/enrichment-status', {cache:'no-store'});
                const data = await resp.json();
                
                document.getElementById('enrichProgressFill').style.width = data.progress + '%';
                document.getElementById('enrichProgressText').textContent = data.message || 'Processing...';
                
                if (data.status === 'completed') {
                    document.getElementById('enrichResult').innerHTML = `<div class="result success">✅ ${data.message}</div>`;
                    loadStatus();
                    return;
                } else if (data.status === 'failed') {
                    document.getElementById('enrichResult').innerHTML = `<div class="result error">❌ ${data.message}</div>`;
                    return;
                } else if (data.status === 'interrupted') {
                    document.getElementById('enrichResult').innerHTML =
                      `<div class="result info">⏸️ Interrupted. Progress saved. Click <strong>Resume</strong> to continue.</div>`;
                    return;
                } else if (data.status === 'idle') {
                    document.getElementById('enrichProgress').style.display = 'none';
                    return;
                }
                
                setTimeout(poll, 1000);
            };
            poll();
        }
        
        async function checkEnrichmentStatus() {
            const resp = await fetch(BASE_URL + '/enrichment-status', {cache:'no-store'});
            const data = await resp.json();
            
            if (data.status === 'running' || data.status === 'interrupted') {
                document.getElementById('enrichProgress').style.display = 'block';
                document.getElementById('enrichProgressFill').style.width = data.progress + '%';
                document.getElementById('enrichProgressText').textContent = data.message;
                if (data.status === 'running') {
                  pollEnrichment();
                } else {
                  document.getElementById('enrichResult').innerHTML =
                    `<div class="result info">⏸️ Interrupted. Progress saved. Click <strong>Resume</strong> to continue. (LLM calls: ${data.llm_used || 0})</div>`;
                }
            } else {
                document.getElementById('enrichProgress').style.display = 'none';
                document.getElementById('enrichResult').innerHTML = `<div class="result info">Status: ${data.status} | LLM calls: ${data.llm_used || 0}</div>`;
            }
        }
        
        async function uploadWithEnrich() {
            const fileInput = document.getElementById('fileInput');
            const apiKey = document.getElementById('apiKey').value.trim() || 'eventheodds-flask-api-key-2025';
            
            if (!fileInput.files[0]) {
                document.getElementById('result').innerHTML = '<div class="result error">Please select a file</div>';
                return;
            }
            
            document.getElementById('submitBtn').disabled = true;
            document.getElementById('progressContainer').style.display = 'block';
            document.getElementById('progressText').textContent = 'Uploading with enrichment...';
            
            const formData = new FormData();
            formData.append('file', fileInput.files[0]);
            
            try {
                const resp = await fetch(BASE_URL + '/enrich-upload', {
                    method: 'POST',
                    headers: {'X-API-Key': apiKey},
                    body: formData
                });
                const data = await resp.json();
                if (data.job_id) pollProgress(data.job_id, apiKey);
            } catch(e) {
                document.getElementById('result').innerHTML = `<div class="result error">❌ ${e.message}</div>`;
                document.getElementById('submitBtn').disabled = false;
            }
        }
        
        // Standard upload
        document.getElementById('uploadForm').onsubmit = async (e) => {
            e.preventDefault();
            const fileInput = document.getElementById('fileInput');
            const apiKey = document.getElementById('apiKey').value.trim() || 'eventheodds-flask-api-key-2025';
            
            if (!fileInput.files[0]) {
                document.getElementById('result').innerHTML = '<div class="result error">Please select a file</div>';
                return;
            }
            
            document.getElementById('submitBtn').disabled = true;
            document.getElementById('progressContainer').style.display = 'block';
            
            const formData = new FormData();
            formData.append('file', fileInput.files[0]);
            
            try {
                const resp = await fetch(BASE_URL + '/upload', {
                    method: 'POST',
                    headers: {'X-API-Key': apiKey},
                    body: formData
                });
                const data = await resp.json();
                if (data.job_id) pollProgress(data.job_id, apiKey);
            } catch(e) {
                document.getElementById('result').innerHTML = `<div class="result error">❌ ${e.message}</div>`;
                document.getElementById('submitBtn').disabled = false;
            }
        };
        
        async function pollProgress(jobId, apiKey) {
            const poll = async () => {
                const resp = await fetch(BASE_URL + '/progress/' + jobId, {cache:'no-store'});
                const data = await resp.json();
                
                document.getElementById('progressFill').style.width = data.progress + '%';
                document.getElementById('progressText').textContent = data.message || 'Processing...';
                
                if (data.status === 'completed') {
                    document.getElementById('result').innerHTML = `<div class="result success">✅ ${data.message}</div>`;
                    document.getElementById('submitBtn').disabled = false;
                    loadStatus();
                    return;
                } else if (data.status === 'failed') {
                    document.getElementById('result').innerHTML = `<div class="result error">❌ ${data.message}</div>`;
                    document.getElementById('submitBtn').disabled = false;
                    return;
                }
                setTimeout(poll, 500);
            };
            poll();
        }
        
        loadStatus();
    </script>
</body>
</html>'''


# ============================================================================
# Confidence Calculation Helpers for Progressive Disclosure UI
# ============================================================================

def calculate_precedential_confidence(boost: float) -> dict:
    """Map a precedential boost factor to a 0-100 confidence score with a label."""
    # Tiers are ordered strongest-first; the first threshold the boost meets wins.
    tiers = (
        (2.0, 95, 'Gold Standard', 'Supreme Court or Restatement authority'),
        (1.5, 80, 'Highly Persuasive', 'Circuit Court or major treatise'),
        (1.2, 65, 'Persuasive', 'District Court or state court authority'),
        (1.0, 50, 'Standard', 'General legal authority'),
    )
    for threshold, score, label, description in tiers:
        if boost >= threshold:
            return {'score': score, 'label': label, 'description': description}
    # Below 1.0 the authority is treated as weakening.
    return {'score': 25, 'label': 'Declining', 'description': 'May be overruled or superseded'}


def calculate_jurisdiction_confidence(boost: float) -> dict:
    """Map a jurisdiction boost factor to a 0-100 confidence score with a label."""
    # Ordered from most to least binding; first matching floor applies.
    tiers = (
        (1.5, 95, 'Supreme Court', 'Binding precedent nationwide'),
        (1.3, 80, 'Federal Circuit', 'Binding in circuit'),
        (1.2, 70, 'State Supreme', 'Binding in state'),
        (1.1, 60, 'State Appeals', 'Persuasive in state'),
    )
    for threshold, score, label, description in tiers:
        if boost >= threshold:
            return {'score': score, 'label': label, 'description': description}
    return {'score': 50, 'label': 'General', 'description': 'General jurisdiction'}


def calculate_argument_confidence(boost: float) -> dict:
    """Map an argument-structure boost factor to a 0-100 confidence score."""
    # Ordered from strongest to weakest structure; first matching floor applies.
    tiers = (
        (1.4, 90, 'Strong Structure', 'Clear analogical or doctrinal framework'),
        (1.3, 75, 'Good Structure', 'Well-organized legal argument'),
        (1.15, 60, 'Moderate', 'Standard argument structure'),
    )
    for threshold, score, label, description in tiers:
        if boost >= threshold:
            return {'score': score, 'label': label, 'description': description}
    return {'score': 45, 'label': 'Basic', 'description': 'Simple presentation'}


def calculate_overall_confidence(result: dict) -> dict:
    """Aggregate a search result's boost factors into one weighted 0-100 score.

    Each factor's boost is linearly normalized from the [0.4, 2.0] range onto
    [0, 100] (clamped), then combined using fixed weights that sum to 1.0.
    Missing boosts default to the neutral value 1.0.
    """
    boosts = result.get('boosts', {})

    # (boost value, weight) per factor; weights sum to 1.0.
    factors = {
        'precedential': (boosts.get('precedential', 1.0), 0.25),
        'jurisdiction': (boosts.get('jurisdiction', 1.0), 0.20),
        'citation': (boosts.get('citation', 1.0), 0.15),
        'argument': (boosts.get('argument', 1.0), 0.15),
        'doctrine': (boosts.get('doctrine', 1.0), 0.10),
        'treatment': (boosts.get('treatment', 1.0), 0.10),
        'temporal': (boosts.get('temporal', 1.0), 0.05),
    }

    def _normalize(boost_value: float) -> float:
        # Clamp the linear mapping of [0.4, 2.0] -> [0, 100].
        return min(100, max(0, (boost_value - 0.4) / 1.6 * 100))

    score = int(sum(_normalize(value) * weight for value, weight in factors.values()))

    # First tier whose floor the score meets determines the label.
    tiers = (
        (85, 'Very High', 'Highly authoritative and relevant'),
        (70, 'High', 'Strong legal authority'),
        (55, 'Moderate', 'Relevant but verify with additional sources'),
        (40, 'Low', 'Limited authority, use with caution'),
    )
    label, description = 'Very Low', 'Weak authority, requires verification'
    for floor, tier_label, tier_description in tiers:
        if score >= floor:
            label, description = tier_label, tier_description
            break

    return {
        'score': score,
        'label': label,
        'description': description,
        'factors': {name: round(pair[0], 2) for name, pair in factors.items()}
    }


@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: report index size and optional-dependency availability."""
    payload = {
        'status': 'healthy',
        'total_documents': len(vector_store.documents),
        'has_faiss': HAS_FAISS,
        'has_sentence_transformers': HAS_SENTENCE_TRANSFORMERS,
        'has_pymupdf': HAS_PYMUPDF,
    }
    return jsonify(payload)


@app.route('/status', methods=['GET'])
def status():
    """Report index size, source count, search backend, and LLM availability."""
    # Unique source filenames across all indexed chunks
    unique_sources = {doc.get('source', '') for doc in vector_store.documents}

    # LLM availability is exposed through the enrichment service's client
    llm = enrichment_service.llm
    llm_info = {
        'chat_llm': 'Grok API (grok-4-fast-reasoning)',
        'enrichment_llm': 'DeepSeek (local)' if llm.deepseek_available else 'Grok API',
        'deepseek_available': llm.deepseek_available,
        'deepseek_model': llm.deepseek_model if llm.deepseek_available else None,
        'grok_available': bool(llm.grok_api_key)
    }

    payload = {
        'total_documents': len(vector_store.documents),
        'total_files': len(unique_sources),
        'has_faiss': HAS_FAISS,
        'has_sentence_transformers': HAS_SENTENCE_TRANSFORMERS,
        'embedding_model': 'SentenceTransformer' if HAS_SENTENCE_TRANSFORMERS else 'hash-based',
        'search_type': 'FAISS + BM25 Hybrid' if HAS_FAISS else 'Cosine + BM25 Hybrid',
        'llm': llm_info
    }
    return jsonify(payload)


@app.route('/upload', methods=['POST'])
@require_api_key
def upload_file():
    """Upload a document and start background chunking/indexing.

    Expects a multipart form with a ``file`` field (.pdf or .txt only).
    The file is saved under ``config.pdf_dir`` and processed off the
    request thread; returns 202 with a ``job_id`` for polling
    ``/progress/<job_id>``.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'Empty filename'}), 400

    # Validate extension before touching the filesystem
    ext = Path(file.filename).suffix.lower()
    if ext not in ('.pdf', '.txt'):
        return jsonify({'error': 'Only PDF and TXT files are supported'}), 400

    # secure_filename() strips path separators and unsafe characters; it can
    # return an empty string (e.g. a name composed entirely of unsafe chars),
    # which would otherwise make file_path point at the upload directory itself.
    safe_filename = secure_filename(file.filename)
    if not safe_filename:
        return jsonify({'error': 'Invalid filename'}), 400

    file_path = config.pdf_dir / safe_filename
    file.save(str(file_path))

    # Register the job, then process off the request thread so the client
    # gets an immediate 202 and polls /progress/<job_id> for status.
    job_id = str(uuid.uuid4())[:8]
    update_progress(job_id, 'starting', 0, 'File uploaded, starting processing...')

    thread = threading.Thread(
        target=process_file_background,
        args=(job_id, file_path, safe_filename),
        daemon=True
    )
    thread.start()

    return jsonify({
        'success': True,
        'job_id': job_id,
        'message': 'File uploaded. Processing started.',
        'filename': safe_filename
    }), 202


@app.route('/progress/<job_id>', methods=['GET'])
def get_progress(job_id):
    """Return the progress record for an upload job, or 404 if unknown."""
    # Read the shared progress map under its lock to avoid racing writers.
    with progress_lock:
        record = upload_progress.get(job_id)
    if record is not None:
        return jsonify(record)
    return jsonify({'status': 'unknown', 'progress': 0, 'message': 'Job not found'}), 404


@app.route('/ask', methods=['POST'])
@require_api_key
def ask():
    """Answer a legal question using retrieved chunks plus an LLM.

    Expects JSON: ``question`` (required), ``use_enriched`` (default True,
    prefer LLM-generated summaries over raw chunk text when building the
    prompt context), ``include_analysis`` (default True, attach per-source
    scoring/boost/confidence details for the UI), and ``k`` (default 5,
    number of chunks to retrieve).  Returns the generated answer, per-source
    details, aggregate themes/key points, and timing.
    """
    import time as time_module  # NOTE(review): 'time' is already imported at module level; alias kept as-is
    start_time = time_module.time()

    data = request.get_json(silent=True) or {}
    question = data.get('question', '').strip()
    use_enriched = data.get('use_enriched', True)  # Default to using enriched data
    include_analysis = data.get('include_analysis', True)  # Include detailed analysis for UI

    if not question:
        return jsonify({'error': 'Question is required'}), 400

    # Number of chunks to retrieve; caller-tunable via the request body
    k = data.get('k', 5)

    # Search
    results = vector_store.search(question, k=k)
    
    if not results:
        return jsonify({
            'answer': "I don't have enough information to answer that question.",
            'sources': [],
            'chunks_searched': len(vector_store.documents),
            'enriched_count': 0
        })
    
    # Build context - prefer enriched summaries when available
    context_parts = []
    enriched_count = 0
    all_themes = set()
    all_key_points = []
    
    for r in results:
        meta = r.get('metadata', {})
        is_enriched = meta.get('enriched', False)
        
        if use_enriched and is_enriched and meta.get('summary'):
            # Use enriched summary + key points
            enriched_count += 1
            part = f"**Summary**: {meta.get('summary', '')}"
            
            key_points = meta.get('key_points', [])
            if key_points:
                all_key_points.extend(key_points[:2])  # Collect top 2 key points per chunk
                part += f"\n**Key Points**: {'; '.join(key_points[:3])}"
            
            themes = meta.get('themes', [])
            if themes:
                all_themes.update(themes)
            
            context_parts.append(part)
        else:
            # Fall back to raw content
            context_parts.append(r['content'][:500])
    
    context = '\n\n---\n\n'.join(context_parts)
    
    # Build prompt for LLM (Legal AI Assistant)
    prompt = f"""You are an AI Legal Research Assistant. Your role is to help users understand legal concepts,
case law, statutes, and legal procedures based on the retrieved legal documents.

=== RETRIEVED LEGAL KNOWLEDGE ===
{context[:3000]}

=== QUESTION ===
{question}

INSTRUCTIONS:
1. Answer based on the retrieved legal documents and your legal knowledge.
2. Cite specific sources, cases, or statutes when applicable.
3. Explain legal concepts clearly for the user's understanding.
4. If the question involves specific case law or procedures, reference the relevant details.
5. If you cannot find sufficient information, acknowledge limitations and suggest areas to research.
6. Always maintain professional legal language while being accessible.

IMPORTANT: This is for educational/research purposes only. Remind users to consult a licensed attorney for specific legal advice.

ANSWER:"""

    # Generate answer with Grok API (same as sports RAG)
    # Note: DeepSeek is used for enrichment, Grok for chat responses
    llm_used = None
    try:
        chat_llm = enrichment_service.llm

        # Use Grok for chat responses
        generated_answer = chat_llm.generate(prompt, max_tokens=1200, temperature=0.4)
        if generated_answer:
            # NOTE(review): generate() does not report which backend actually
            # answered, so the label below is necessarily ambiguous — confirm.
            llm_used = "DeepSeek or Grok"
            answer = generated_answer.strip()
        else:
            llm_used = "None (fallback to context)"
            answer = "I could not generate a specific answer. Here is the relevant context:\n\n" + context[:2500]
    except Exception as e:
        print(f"[RAG] LLM generation error: {e}")
        llm_used = f"Error: {str(e)[:50]}"
        answer = "Error generating answer. Here is the context:\n\n" + context[:2500]

    # Build detailed sources with enrichment info and boost breakdown
    sources = []
    for r in results:
        meta = r.get('metadata', {})
        source_data = {
            'source': meta.get('source', 'unknown'),
            'score': round(r['score'], 3),
            'preview': r['content'][:150],
            'enriched': meta.get('enriched', False),
            'summary': meta.get('summary', '')[:200] if meta.get('summary') else None,
            'themes': meta.get('themes', [])
        }

        # Include detailed analysis for progressive disclosure
        if include_analysis:
            source_data['analysis'] = {
                'base_score': round(r.get('base_score', r['score']), 3),
                'vector_score': round(r.get('vector_score', 0), 3),
                'keyword_score': round(r.get('keyword_score', 0), 3),
                'bm25_score': round(r.get('bm25_score', 0), 3),
                'matched_terms': r.get('matched_terms', []),
                'expanded_matches': r.get('expanded_matches', []),
                'boosts': r.get('boosts', {}),
                'phase': r.get('phase', 'general'),
                'court_level': r.get('court_level', 'unknown'),
                'argument_structure': r.get('argument_structure'),
                'primary_doctrine': r.get('primary_doctrine')
            }

            # Calculate confidence metrics
            boosts = r.get('boosts', {})
            source_data['confidence'] = {
                'precedential_strength': calculate_precedential_confidence(boosts.get('precedential', 1.0)),
                'jurisdictional_relevance': calculate_jurisdiction_confidence(boosts.get('jurisdiction', 1.0)),
                'argument_quality': calculate_argument_confidence(boosts.get('argument', 1.0)),
                'overall': calculate_overall_confidence(r)
            }

        sources.append(source_data)

    # Calculate response time
    response_time_ms = int((time_module.time() - start_time) * 1000)

    # Build response
    response_data = {
        'answer': answer,
        'sources': sources,
        'chunks_searched': len(vector_store.documents),
        'chunks_returned': len(results),
        'enriched_count': enriched_count,
        'themes': list(all_themes)[:10],
        'key_points': all_key_points[:5],
        'llm_used': llm_used,
        'response_time_ms': response_time_ms
    }

    # Include query analysis for UI
    if include_analysis:
        response_data['query_analysis'] = {
            'detected_phase': detect_litigation_phase(question),
            'relevant_doctrines': get_relevant_doctrines(question),
            'expanded_terms_count': len(expand_query_with_concepts(question))
        }

    # Track analytics
    # Analytics failures must never break the response path, hence best-effort.
    try:
        track_query(question, response_data)
    except Exception as e:
        print(f"[Analytics] Track error: {e}")

    return jsonify(response_data)
@app.route('/search', methods=['POST'])
@require_api_key
def search():
    """Search documents.

    JSON body:
        query (str): search text (required; must be non-empty after strip).
        k (int): number of results to return (default 10).

    Returns:
        JSON with the ranked results, the query echoed back, and the
        total number of indexed chunks. 400 if the query is missing.
    """
    data = request.get_json(silent=True) or {}
    query = data.get('query', '').strip()

    # Coerce k defensively: a non-integer value from the client previously
    # propagated into vector_store.search() and caused an HTTP 500.
    try:
        k = int(data.get('k', 10))
    except (TypeError, ValueError):
        k = 10
    k = max(1, k)

    if not query:
        return jsonify({'error': 'Query is required'}), 400

    results = vector_store.search(query, k=k)

    return jsonify({
        'results': results,
        'query': query,
        'total_documents': len(vector_store.documents)
    })


@app.route('/documents', methods=['GET'])
def list_documents():
    """List all documents with enrichment status"""
    by_source = {}
    for chunk in vector_store.documents:
        name = chunk.get('source', 'unknown')
        entry = by_source.get(name)
        if entry is None:
            entry = {
                'chunk_count': 0,
                'enriched_count': 0,
                'document_type': chunk.get('document_type', 'unknown')
            }
            by_source[name] = entry
        entry['chunk_count'] += 1
        if chunk.get('enriched'):
            entry['enriched_count'] += 1

    # Aggregate totals across the whole store
    total_chunks = len(vector_store.documents)
    total_enriched = sum(1 for d in vector_store.documents if d.get('enriched'))
    pct = round((total_enriched / total_chunks * 100) if total_chunks > 0 else 0, 1)

    return jsonify({
        'documents': [{'source': name, **info} for name, info in by_source.items()],
        'total_sources': len(by_source),
        'total_chunks': total_chunks,
        'total_enriched': total_enriched,
        'enrichment_percentage': pct
    })


@app.route('/chunks', methods=['GET'])
def list_chunks():
    """List chunks with pagination and filtering.

    Query params:
        source (str): restrict to chunks from one source file.
        enriched: 'true' -> only enriched chunks.
        raw: 'true' -> only non-enriched chunks (ignored if enriched=true).
        page (int): 1-based page number (default 1).
        per_page (int): page size (default 20).
    """
    # Query params
    source = request.args.get('source')
    enriched_only = request.args.get('enriched', '').lower() == 'true'
    raw_only = request.args.get('raw', '').lower() == 'true'

    # Coerce pagination params defensively: a non-numeric value previously
    # raised ValueError -> HTTP 500, and page<=0 produced negative slices.
    def _int_arg(name, default):
        try:
            value = int(request.args.get(name, default))
        except (TypeError, ValueError):
            value = default
        return max(1, value)

    page = _int_arg('page', 1)
    per_page = _int_arg('per_page', 20)

    # Filter chunks
    filtered = vector_store.documents

    if source:
        filtered = [d for d in filtered if d.get('source') == source]

    if enriched_only:
        filtered = [d for d in filtered if d.get('enriched')]
    elif raw_only:
        filtered = [d for d in filtered if not d.get('enriched')]

    # Paginate
    total = len(filtered)
    start = (page - 1) * per_page
    paginated = filtered[start:start + per_page]

    # Format chunks for display (previews only; /chunks/<index> returns full content)
    chunks = []
    for i, doc in enumerate(paginated, start=start):
        content = doc.get('content', '')
        chunks.append({
            'index': i,
            'source': doc.get('source', 'unknown'),
            'chunk_id': doc.get('chunk_id', i),
            'enriched': doc.get('enriched', False),
            'summary': doc.get('summary', '')[:200] if doc.get('summary') else None,
            'key_points': doc.get('key_points', [])[:3],
            'themes': doc.get('themes', []),
            'content_preview': content[:300] + '...' if len(content) > 300 else content,
            'content_length': len(content)
        })

    return jsonify({
        'chunks': chunks,
        'pagination': {
            'page': page,
            'per_page': per_page,
            'total': total,
            'pages': (total + per_page - 1) // per_page
        },
        'filters': {
            'source': source,
            'enriched_only': enriched_only,
            'raw_only': raw_only
        }
    })


@app.route('/chunks/<int:chunk_index>', methods=['GET'])
def get_chunk(chunk_index):
    """Get a single chunk with full details"""
    if not (0 <= chunk_index < len(vector_store.documents)):
        return jsonify({'error': 'Chunk not found'}), 404

    doc = vector_store.documents[chunk_index]
    content = doc.get('content', '')

    payload = {
        'index': chunk_index,
        'source': doc.get('source', 'unknown'),
        'chunk_id': doc.get('chunk_id'),
        'document_type': doc.get('document_type'),
        'enriched': doc.get('enriched', False),
        'summary': doc.get('summary'),
        'key_points': doc.get('key_points', []),
        'themes': doc.get('themes', []),
        'content': content,
        'content_length': len(content),
        'section': doc.get('section'),
        'total_chunks': doc.get('total_chunks')
    }
    return jsonify(payload)


@app.route('/chunks/by-source/<path:source_name>', methods=['GET'])
def get_chunks_by_source(source_name):
    """Get all chunks for a specific source"""
    safe_source = secure_filename(source_name)
    matching = [d for d in vector_store.documents if d.get('source') == safe_source]

    if not matching:
        return jsonify({'error': 'Source not found'}), 404

    def _preview(chunk):
        # Compact per-chunk view: truncated summary/content, no full text
        summary = chunk.get('summary', '')
        return {
            'chunk_id': chunk.get('chunk_id'),
            'enriched': chunk.get('enriched', False),
            'summary': summary[:200] if summary else None,
            'themes': chunk.get('themes', []),
            'content_preview': chunk.get('content', '')[:200]
        }

    return jsonify({
        'source': safe_source,
        'total_chunks': len(matching),
        'enriched_chunks': sum(1 for c in matching if c.get('enriched')),
        'chunks': [_preview(c) for c in matching]
    })


@app.route('/delete/<path:filename>', methods=['DELETE', 'POST'])
@require_api_key
def delete_document(filename):
    """Delete a document: drop its chunks, rebuild the index, remove the file.

    Accepts POST in addition to DELETE for clients that cannot send DELETE.
    Returns 404 if no indexed chunks matched the (sanitized) filename.
    """
    # secure_filename() strips path components, so sources are matched by
    # the sanitized bare filename only.
    safe_filename = secure_filename(filename)
    
    # Remove from vector store
    original_count = len(vector_store.documents)
    vector_store.documents = [d for d in vector_store.documents if d.get('source') != safe_filename]
    removed_count = original_count - len(vector_store.documents)
    
    if removed_count > 0:
        # Rebuild index over the surviving chunks (re-embeds everything)
        if vector_store.documents:
            contents = [d['content'] for d in vector_store.documents]
            embeddings = vector_store._get_embeddings_batch(contents)
            
            if HAS_FAISS:
                # Inner-product index over L2-normalized vectors, i.e.
                # cosine similarity search.
                vector_store.index = faiss.IndexFlatIP(vector_store.embedding_dim)
                embeddings_np = embeddings.astype('float32')
                faiss.normalize_L2(embeddings_np)
                vector_store.index.add(embeddings_np)
            else:
                vector_store.embeddings = embeddings
        else:
            # Store is now empty: reset to a fresh index / empty matrix
            if HAS_FAISS:
                vector_store.index = faiss.IndexFlatIP(vector_store.embedding_dim)
            else:
                vector_store.embeddings = np.array([])
        
        # Refresh BM25/keyword caches and persist the rebuilt store
        vector_store._prepare_searchable_texts()
        vector_store.save()
        
        # Delete source file from the upload directory, if still present
        file_path = config.pdf_dir / safe_filename
        if file_path.exists():
            file_path.unlink()
        
        return jsonify({
            'success': True,
            'message': f'Deleted {safe_filename} ({removed_count} chunks)',
            'chunks_removed': removed_count
        })
    
    return jsonify({'error': 'Document not found'}), 404


@app.route('/reload', methods=['POST'])
@require_api_key
def reload():
    """Reload the index"""
    vector_store.load()
    payload = {
        'success': True,
        'total_documents': len(vector_store.documents)
    }
    return jsonify(payload)


@app.route('/enrich', methods=['POST'])
@require_api_key
def enrich_documents():
    """Enrich documents with LLM-generated metadata.

    Launches a daemon thread that enriches every not-yet-enriched chunk
    (optionally restricted to one source via the JSON body's 'source' key),
    updating the shared `enrichment_status` under `enrichment_lock` and
    checkpointing to disk so an interrupted run can be resumed.

    Returns:
        202 immediately once the thread is started (poll /enrichment-status
        for progress), or 409 if an enrichment thread is already running.
    """
    global enrichment_status, enrichment_thread
    
    with enrichment_lock:
        if enrichment_status.get('status') == 'running' and enrichment_thread is not None and enrichment_thread.is_alive():
            return jsonify({
                'error': 'Enrichment already running',
                'status': enrichment_status
            }), 409
        # Handle stale "running" state (process restarted / thread died)
        if enrichment_status.get('status') == 'running':
            enrichment_status['status'] = 'interrupted'
            enrichment_status['message'] = enrichment_status.get('message') or 'Previous enrichment was interrupted. You can resume.'
            # NOTE(review): isoformat() on an aware datetime already ends in
            # '+00:00', so appending 'Z' yields '...+00:00Z' — confirm
            # consumers tolerate this (same pattern used throughout below).
            enrichment_status['updated_at'] = datetime.now(timezone.utc).isoformat() + 'Z'
            _save_enrichment_status_to_disk()
    
    data = request.get_json(silent=True) or {}
    source_filter = data.get('source')  # Optional: enrich only specific source
    
    def run_enrichment():
        """Background worker: enrich remaining chunks, tracking status/progress."""
        global enrichment_status
        try:
            # Chunks to enrich between metadata checkpoints (env overrides
            # any value carried over in the persisted status).
            checkpoint_every = int(os.environ.get('ENRICHMENT_CHECKPOINT_EVERY', enrichment_status.get('checkpoint_every', 10) or 10))

            # Ensure searchable caches exist
            try:
                vector_store._prepare_searchable_texts()
            except Exception:
                pass

            # Determine target set (for accurate progress + resume)
            if source_filter:
                target_indices = [i for i, d in enumerate(vector_store.documents) if d.get('source') == source_filter]
            else:
                target_indices = list(range(len(vector_store.documents)))

            total_target = len(target_indices)
            if total_target == 0:
                with enrichment_lock:
                    enrichment_status['status'] = 'completed'
                    enrichment_status['progress'] = 100
                    enrichment_status['message'] = 'No chunks found to enrich.'
                    enrichment_status['llm_used'] = enrichment_service.enrichment_used
                    enrichment_status['source'] = source_filter
                    enrichment_status['total_target'] = 0
                    enrichment_status['done_target'] = 0
                    enrichment_status['updated_at'] = datetime.now(timezone.utc).isoformat() + 'Z'
                    enrichment_status['checkpoint_every'] = checkpoint_every
                    _save_enrichment_status_to_disk()
                return

            # Already-enriched chunks count toward progress so a resumed run
            # starts from where the previous one stopped.
            already_enriched = sum(1 for i in target_indices if vector_store.documents[i].get('enriched'))
            remaining = total_target - already_enriched
            started_at = datetime.now(timezone.utc).isoformat() + 'Z'

            with enrichment_lock:
                enrichment_status['status'] = 'running'
                enrichment_status['progress'] = int((already_enriched / total_target) * 100) if total_target else 0
                enrichment_status['message'] = f'Resuming enrichment: {already_enriched}/{total_target} already enriched. Enriching remaining {remaining}...'
                enrichment_status['llm_used'] = enrichment_service.enrichment_used
                # Preserve the original start time across resumes
                enrichment_status['started_at'] = enrichment_status.get('started_at') or started_at
                enrichment_status['updated_at'] = started_at
                enrichment_status['source'] = source_filter
                enrichment_status['total_target'] = total_target
                enrichment_status['done_target'] = already_enriched
                enrichment_status['checkpoint_every'] = checkpoint_every
                _save_enrichment_status_to_disk()

            if remaining <= 0:
                with enrichment_lock:
                    enrichment_status['status'] = 'completed'
                    enrichment_status['progress'] = 100
                    enrichment_status['message'] = 'No chunks need enrichment (all already enriched)'
                    enrichment_status['llm_used'] = enrichment_service.enrichment_used
                    enrichment_status['updated_at'] = datetime.now(timezone.utc).isoformat() + 'Z'
                    _save_enrichment_status_to_disk()
                return

            enriched_this_run = 0

            for doc_index in target_indices:
                doc = vector_store.documents[doc_index]
                if doc.get('enriched'):
                    continue

                # Merge LLM-generated metadata into the chunk in place
                enriched_doc = enrichment_service.enrich_chunk(doc)
                doc.update(enriched_doc)
                doc['enriched_at'] = datetime.now(timezone.utc).isoformat() + 'Z'

                # Update cached searchable text for this doc (best effort;
                # length caches feed the BM25 average-document-length stat)
                try:
                    new_text = vector_store._build_searchable_text(doc)
                    if len(vector_store.searchable_texts) == len(vector_store.documents):
                        vector_store.searchable_texts[doc_index] = new_text.lower()
                    if len(vector_store.doc_lengths) == len(vector_store.documents):
                        old_len = vector_store.doc_lengths[doc_index]
                        new_len = len(new_text.split())
                        vector_store.doc_lengths[doc_index] = new_len
                        n = len(vector_store.doc_lengths)
                        if n > 0:
                            # Incrementally adjust the running average instead
                            # of recomputing over all documents
                            vector_store.avg_doc_length = (vector_store.avg_doc_length * n - old_len + new_len) / n
                except Exception:
                    pass

                enriched_this_run += 1
                done = already_enriched + enriched_this_run
                progress = int((done / total_target) * 100) if total_target else 0

                with enrichment_lock:
                    enrichment_status['status'] = 'running'
                    enrichment_status['progress'] = progress
                    enrichment_status['message'] = f'Enriched {done}/{total_target} chunks (this run: {enriched_this_run}/{remaining})'
                    enrichment_status['llm_used'] = enrichment_service.enrichment_used
                    enrichment_status['updated_at'] = datetime.now(timezone.utc).isoformat() + 'Z'
                    enrichment_status['done_target'] = done
                    enrichment_status['checkpoint_every'] = checkpoint_every

                # Periodic checkpoint so progress is not lost on interruptions
                if enriched_this_run % checkpoint_every == 0:
                    vector_store.save_metadata_only()
                    with enrichment_lock:
                        _save_enrichment_status_to_disk()

            # Final checkpoint and completion
            vector_store.save_metadata_only()
            with enrichment_lock:
                enrichment_status['status'] = 'completed'
                enrichment_status['progress'] = 100
                enrichment_status['message'] = f'Successfully enriched {enriched_this_run} chunks (target: {total_target})'
                enrichment_status['llm_used'] = enrichment_service.enrichment_used
                enrichment_status['updated_at'] = datetime.now(timezone.utc).isoformat() + 'Z'
                enrichment_status['done_target'] = total_target
                _save_enrichment_status_to_disk()
            
        except Exception as e:
            # Record the failure so /enrichment-status reflects it
            print(f"[RAG] Enrichment error: {e}")
            import traceback
            traceback.print_exc()
            with enrichment_lock:
                enrichment_status['status'] = 'failed'
                enrichment_status['message'] = f'Error: {str(e)}'
                enrichment_status['llm_used'] = enrichment_service.enrichment_used
                enrichment_status['updated_at'] = datetime.now(timezone.utc).isoformat() + 'Z'
                _save_enrichment_status_to_disk()
    
    # Run in background thread
    enrichment_thread = threading.Thread(target=run_enrichment, daemon=True)
    enrichment_thread.start()
    
    return jsonify({
        'success': True,
        'message': 'Enrichment started',
        'status': 'running'
    }), 202


@app.route('/enrichment-status', methods=['GET'])
def get_enrichment_status():
    """Get current enrichment status"""
    # Snapshot under the lock, serialize outside it
    with enrichment_lock:
        snapshot = dict(enrichment_status)
    return jsonify(snapshot)


@app.route('/processing-status', methods=['GET'])
def get_processing_status():
    """Get status of file processing jobs.

    Scans the upload directory and reports, for each PDF/TXT file, whether
    its chunks are present in the vector store ('completed') or not yet
    ('pending'), along with size and last-modified time.
    """
    files = []
    try:
        if config.pdf_dir.exists():
            for f in config.pdf_dir.iterdir():
                if f.suffix.lower() not in ('.pdf', '.txt'):
                    continue
                stat = f.stat()  # single stat() call for both size and mtime
                # A file counts as processed if any indexed chunk's source
                # matches its name (endswith tolerates path-prefixed sources)
                processed = any(
                    doc.get('source', '').endswith(f.name)
                    for doc in vector_store.documents
                )
                files.append({
                    'filename': f.name,
                    'status': 'completed' if processed else 'pending',
                    'size': stat.st_size,
                    # Emit a UTC-aware timestamp, consistent with the rest of
                    # the service (which stamps datetime.now(timezone.utc));
                    # the naive local-time form was ambiguous to consumers.
                    'modified': datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat()
                })
    except Exception as e:
        print(f"[RAG] Error getting processing status: {e}")

    return jsonify({'files': files})


@app.route('/enrich-upload', methods=['POST'])
@require_api_key
def enrich_on_upload():
    """Upload and process with immediate enrichment.

    Saves the uploaded PDF/TXT into the upload directory, then runs the
    extract -> chunk -> LLM-enrich -> index pipeline in a background thread,
    reporting progress under a short job id (poll via update_progress's
    consumer endpoint).

    Returns:
        202 with the job id once the background thread is started,
        or 400 for a missing/empty/unsupported file.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    
    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'Empty filename'}), 400
    
    ext = Path(file.filename).suffix.lower()
    if ext not in ['.pdf', '.txt']:
        return jsonify({'error': 'Only PDF and TXT files supported'}), 400
    
    # Sanitize the client-supplied filename before touching the filesystem
    safe_filename = secure_filename(file.filename)
    file_path = config.pdf_dir / safe_filename
    file.save(str(file_path))
    
    # Short random job id for progress polling
    job_id = str(uuid.uuid4())[:8]
    update_progress(job_id, 'starting', 0, 'File uploaded, starting processing with enrichment...')
    
    def process_and_enrich():
        """Background worker: extract, chunk, enrich, and index the upload."""
        try:
            update_progress(job_id, 'processing', 10, 'Extracting text...')
            
            # Direct extraction first; fall back to OCR if it yields little text
            text = pdf_processor.extract_text(file_path)
            if not text or len(text.strip()) < 100:
                update_progress(job_id, 'processing', 15, 'Trying OCR...')
                text = ocr_service.extract_text(file_path)
            
            if not text or len(text.strip()) < 100:
                update_progress(job_id, 'failed', 0, 'Failed to extract text')
                return
            
            update_progress(job_id, 'processing', 30, 'Chunking document...')
            
            doc_type = pdf_processor.detect_document_type(text, safe_filename)
            chunks = pdf_processor.chunk_text(text, doc_type)
            
            if not chunks:
                update_progress(job_id, 'failed', 0, 'No valid chunks extracted')
                return
            
            # Stamp provenance metadata onto every chunk
            for i, chunk in enumerate(chunks):
                chunk['source'] = safe_filename
                chunk['chunk_id'] = i
                chunk['total_chunks'] = len(chunks)
                chunk['document_type'] = doc_type
            
            update_progress(job_id, 'processing', 50, f'Enriching {len(chunks)} chunks with LLM...')
            
            # Enrich chunks: map per-chunk progress onto the 50-80% band
            def enrich_progress(current, total):
                pct = 50 + int((current / total) * 30)
                update_progress(job_id, 'processing', pct, f'Enriching chunk {current}/{total}...')
            
            enriched_chunks = enrichment_service.enrich_all_chunks(chunks, enrich_progress)
            
            update_progress(job_id, 'processing', 85, 'Adding to vector store...')
            
            # Replace any previous chunks from the same filename (re-upload)
            vector_store.replace_source(safe_filename, enriched_chunks)
            
            update_progress(job_id, 'processing', 95, 'Saving index...')
            vector_store.save()
            
            update_progress(
                job_id, 'completed', 100,
                f'Successfully processed and enriched {len(enriched_chunks)} chunks!',
                chunks_added=len(enriched_chunks),
                llm_used=enrichment_service.enrichment_used
            )
            
        except Exception as e:
            print(f"[RAG] Error: {e}")
            import traceback
            traceback.print_exc()
            update_progress(job_id, 'failed', 0, f'Error: {str(e)}')
    
    thread = threading.Thread(target=process_and_enrich, daemon=True)
    thread.start()
    
    return jsonify({
        'success': True,
        'job_id': job_id,
        'message': 'File uploaded. Processing with enrichment started.',
        'filename': safe_filename
    }), 202


# ============================================================================
# Validation & Introspection Endpoints
# ============================================================================
@app.route('/validate', methods=['GET'])
def validate_rag():
    """Run validation queries to test RAG performance"""
    if not vector_store.documents:
        return jsonify({
            'status': 'no_documents',
            'message': 'No documents loaded. Upload legal documents first.',
            'results': []
        })

    results = []
    for query, expected_terms in VALIDATION_QUERIES:
        hits = vector_store.search(query, k=5)

        # Which expected terms appear anywhere in the returned chunks?
        haystack = ' '.join(h.get('content', '') for h in hits).lower()
        found_terms = [term for term in expected_terms if term.lower() in haystack]

        coverage = len(found_terms) / len(expected_terms) if expected_terms else 0
        results.append({
            'query': query,
            'expected_terms': expected_terms,
            'found_terms': found_terms,
            'coverage': round(coverage, 2),
            'result_count': len(hits),
            'top_score': hits[0]['score'] if hits else 0
        })

    overall = sum(r['coverage'] for r in results) / len(results) if results else 0

    return jsonify({
        'status': 'success',
        'overall_coverage': round(overall, 2),
        'total_documents': len(vector_store.documents),
        'results': results
    })


@app.route('/legal-intelligence', methods=['GET'])
def legal_intelligence_info():
    """Get information about the legal intelligence features.

    Static introspection endpoint: reports the sizes/keys of the
    module-level legal knowledge tables (tiered keywords, concept clusters,
    litigation phases, doctrines, jurisdiction/precedential hierarchies,
    citation patterns, argument schemas, cross-references) plus a
    human-readable summary of the search-boost factors.
    """
    return jsonify({
        'features': {
            'tiered_keywords': {
                'description': 'Legal terms weighted by importance tier (1-6)',
                'tier_1_count': len(LEGAL_TIER_1),
                'tier_2_count': len(LEGAL_TIER_2),
                'tier_3_count': len(LEGAL_TIER_3),
                'tier_4_count': len(LEGAL_TIER_4),
                'tier_5_count': len(LEGAL_TIER_5),
                'tier_6_count': len(LEGAL_TIER_6),
                'total_terms': len(LEGAL_TERMS)
            },
            'concept_clusters': {
                'description': 'Related legal concepts automatically expanded in queries',
                'clusters': list(LEGAL_CONCEPT_CLUSTERS.keys())
            },
            'litigation_phases': {
                'description': 'Phase-aware search boosting',
                'phases': list(LITIGATION_PHASES.keys())
            },
            'legal_doctrines': {
                'description': 'Doctrine-specific search enhancement',
                'doctrines': list(LEGAL_DOCTRINES.keys())
            },
            'jurisdiction_hierarchy': {
                'description': 'Court authority level boosting',
                'levels': list(JURISDICTION_HIERARCHY.keys())
            },
            'citation_recognition': {
                'description': 'Legal citation pattern detection',
                'pattern_count': len(LEGAL_CITATION_PATTERNS)
            },
            'precedential_hierarchy': {
                'description': 'Source authority weighting (gold standard, persuasive, declining)',
                'levels': list(PRECEDENTIAL_HIERARCHY.keys()),
                'weights': {k: v['weight'] for k, v in PRECEDENTIAL_HIERARCHY.items()}
            },
            'argument_schemas': {
                'description': 'Legal reasoning pattern recognition',
                'schemas': list(ARGUMENT_SCHEMAS.keys()),
                'types': [s['type'] for s in ARGUMENT_SCHEMAS.values()]
            },
            'cross_references': {
                'description': 'Doctrine relationship mapping for query expansion',
                'doctrines': list(LEGAL_CROSS_REFERENCES.keys())
            }
        },
        # Human-readable descriptions of the boost factors; the actual
        # numeric boosts are applied in the search scoring code.
        'search_boosts': {
            'phase_aware': 'Documents from same litigation phase get 15% boost',
            'jurisdiction': 'U.S. Supreme Court: 1.5x, Federal Circuit: 1.3x, State Supreme: 1.2x',
            'doctrine_specific': '12% boost per matched doctrine',
            'citation_authority': 'Up to 24% boost for Supreme Court citations',
            'citation_treatment': 'Positive treatment: +5%, Heavy negative: -10%',
            'precedential': 'Gold standard: 2.0x, Highly persuasive: 1.5x, Declining: 0.4x',
            'temporal': 'Recent statutes boosted, landmark cases stable',
            'argument_structure': 'Well-structured arguments boosted by schema weight',
            'cross_reference': 'Up to 25% boost for related legal concepts'
        },
        'total_boost_factors': 9
    })


@app.route('/analyze-query', methods=['POST'])
@require_api_key
def analyze_query():
    """Analyze a query to show how legal intelligence will process it.

    JSON body:
        query (str): the question to analyze (required).

    Returns the detected litigation phase, relevant doctrines, the first 20
    concept-expanded terms, and which tiered legal terms the query matches.
    """
    # silent=True keeps malformed/missing JSON from raising before our own
    # validation runs (consistent with the other POST endpoints above).
    data = request.get_json(silent=True) or {}
    # Strip like /search does, so a whitespace-only query is rejected too.
    query = data.get('query', '').strip()

    if not query:
        return jsonify({'error': 'Query required'}), 400

    # Analyze the query
    phase = detect_litigation_phase(query)
    doctrines = get_relevant_doctrines(query)
    expanded_terms = expand_query_with_concepts(query)

    # Bucket matching legal terms by importance tier (1 = most important)
    query_lower = query.lower()
    tier_matches = {1: [], 2: [], 3: [], 4: [], 5: [], 6: []}
    for term in LEGAL_TERMS:
        if term in query_lower:
            tier = get_term_tier(term)
            if tier > 0:
                tier_matches[tier].append(term)

    return jsonify({
        'query': query,
        'detected_phase': phase,
        'relevant_doctrines': doctrines,
        'expanded_terms': expanded_terms[:20],
        'tier_matches': {k: v for k, v in tier_matches.items() if v},
        'total_expansion_terms': len(expanded_terms)
    })


# ============================================================================
# Analytics System
# ============================================================================

# Analytics data store (in-memory with file persistence via save_analytics).
# All reads and writes must hold `analytics_lock`.
analytics_data = {
    'queries': [],  # List of query records (capped to the last 1000 in track_query)
    'daily_stats': {},  # 'YYYY-MM-DD' -> aggregated stats for that day
    'doctrinal_coverage': {},  # Doctrine name -> all-time query count
    'response_quality': [],  # Quality scores
    'session_start': datetime.now(timezone.utc).isoformat()
}
analytics_lock = threading.Lock()  # Guards analytics_data across request threads
ANALYTICS_FILE = None  # Set after config init (see init_analytics)

def init_analytics():
    """Initialize analytics file path and load existing data.

    Sets ANALYTICS_FILE under the configured data directory and, if a
    previous analytics.json exists, loads it and prunes query records
    older than 30 days.
    """
    global analytics_data, ANALYTICS_FILE
    # Bug fix: the module header only imports datetime/timezone from
    # `datetime`, so the bare `timedelta` reference below raised NameError
    # whenever an existing analytics file was loaded. Import it locally.
    from datetime import timedelta

    ANALYTICS_FILE = config.data_dir / 'analytics.json'
    if ANALYTICS_FILE.exists():
        try:
            with open(ANALYTICS_FILE, 'r') as f:
                loaded = json.load(f)
                analytics_data.update(loaded)
                # Keep only last 30 days of query data (ISO-8601 strings
                # with a fixed UTC offset compare chronologically)
                cutoff = (datetime.now(timezone.utc) - timedelta(days=30)).isoformat()
                analytics_data['queries'] = [q for q in analytics_data.get('queries', [])
                                             if q.get('timestamp', '') > cutoff]
        except Exception as e:
            print(f"[Analytics] Error loading analytics: {e}")

def save_analytics():
    """Persist the in-memory analytics store to ANALYTICS_FILE (best effort)."""
    if not ANALYTICS_FILE:
        return  # Not initialized yet; init_analytics sets the path
    try:
        serialized = json.dumps(analytics_data)
        with open(ANALYTICS_FILE, 'w') as f:
            f.write(serialized)
    except Exception as e:
        print(f"[Analytics] Error saving analytics: {e}")

def track_query(query: str, response_data: dict) -> None:
    """Track a query for analytics.

    Records a per-query entry (capped to the last 1000), updates the
    current day's aggregate stats (counts, rolling averages, phase and
    doctrine distributions), and bumps the global doctrinal-coverage
    counters. Persists to disk every 10th query. Thread-safe: the entire
    update runs under analytics_lock.

    Args:
        query: The raw user question.
        response_data: The response payload built by the /ask handler
            (reads 'sources', 'enriched_count', 'llm_used',
            'response_time_ms' and derives a quality score from it).
    """
    with analytics_lock:
        now = datetime.now(timezone.utc)
        today = now.strftime('%Y-%m-%d')

        # Extract query metadata
        phase = detect_litigation_phase(query)
        doctrines = get_relevant_doctrines(query)

        # Calculate response quality score (0-100)
        quality_score = calculate_response_quality(response_data)

        # Query record — stores only derived metadata (length, phase,
        # doctrines), not the query text itself
        query_record = {
            'timestamp': now.isoformat(),
            'query_length': len(query),
            'phase': phase,
            'doctrines': doctrines,
            'sources_count': len(response_data.get('sources', [])),
            'enriched_count': response_data.get('enriched_count', 0),
            'quality_score': quality_score,
            'llm_used': response_data.get('llm_used'),
            'response_time_ms': response_data.get('response_time_ms', 0)
        }

        # Add to queries list (keep last 1000)
        analytics_data['queries'].append(query_record)
        if len(analytics_data['queries']) > 1000:
            analytics_data['queries'] = analytics_data['queries'][-1000:]

        # Update daily stats (create today's bucket on first query of the day)
        if today not in analytics_data['daily_stats']:
            analytics_data['daily_stats'][today] = {
                'query_count': 0,
                'avg_quality': 0,
                'total_quality': 0,
                'phases': {},
                'doctrines_used': {},
                'avg_response_time': 0,
                'total_response_time': 0
            }

        # Averages are kept alongside their running totals so they can be
        # recomputed exactly on every update
        daily = analytics_data['daily_stats'][today]
        daily['query_count'] += 1
        daily['total_quality'] += quality_score
        daily['avg_quality'] = daily['total_quality'] / daily['query_count']
        daily['total_response_time'] += response_data.get('response_time_ms', 0)
        daily['avg_response_time'] = daily['total_response_time'] / daily['query_count']

        # Track phase distribution
        if phase not in daily['phases']:
            daily['phases'][phase] = 0
        daily['phases'][phase] += 1

        # Track doctrine usage
        for doctrine in doctrines:
            if doctrine not in daily['doctrines_used']:
                daily['doctrines_used'][doctrine] = 0
            daily['doctrines_used'][doctrine] += 1

            # Global doctrinal coverage
            if doctrine not in analytics_data['doctrinal_coverage']:
                analytics_data['doctrinal_coverage'][doctrine] = 0
            analytics_data['doctrinal_coverage'][doctrine] += 1

        # Save periodically (every 10 queries)
        if len(analytics_data['queries']) % 10 == 0:
            save_analytics()

def calculate_response_quality(response_data: dict) -> int:
    """Calculate a quality score (0-100) for a response.

    The score is a capped sum of five components:
      - sources found:     6 pts each, up to 30
      - enriched sources:  5 pts each, up to 20
      - answer length:     +10 (>100 chars), +10 (>300), +5 (>500)
      - themes identified: 3 pts each, up to 15
      - key points:        2 pts each, up to 10
    """
    num_sources = len(response_data.get('sources', []))
    enriched = response_data.get('enriched_count', 0)
    answer_len = len(response_data.get('answer', ''))
    num_themes = len(response_data.get('themes', []))
    num_points = len(response_data.get('key_points', []))

    score = min(30, num_sources * 6)
    score += min(20, enriched * 5) if enriched else 0
    # Length-based bonuses are cumulative: a 600-char answer earns all three
    score += sum(bonus for threshold, bonus in ((100, 10), (300, 10), (500, 5))
                 if answer_len > threshold)
    score += min(15, num_themes * 3)
    score += min(10, num_points * 2)

    return min(100, score)


@app.route('/analytics', methods=['GET'])
def get_analytics():
    """Get analytics dashboard data"""
    with analytics_lock:
        all_queries = analytics_data.get('queries', [])
        # Averages are computed over a rolling window of the last 100 queries
        recent_queries = all_queries[-100:]

        def _mean(field):
            # Mean of a numeric field across the recent window; 0 when empty
            if not recent_queries:
                return 0
            return sum(q.get(field, 0) for q in recent_queries) / len(recent_queries)

        avg_quality = _mean('quality_score')
        avg_sources = _mean('sources_count')
        avg_response_time = _mean('response_time_ms')

        # Distribution of query phases within the recent window
        phase_counts = {}
        for q in recent_queries:
            p = q.get('phase', 'unknown')
            phase_counts[p] = phase_counts.get(p, 0) + 1

        # Per-day trend for the 7 most recent days on record
        recent_days = sorted(analytics_data.get('daily_stats', {}).items(), reverse=True)[:7]
        daily_trends = [
            {
                'date': day,
                'queries': stats.get('query_count', 0),
                'avg_quality': round(stats.get('avg_quality', 0), 1),
                'avg_response_time': round(stats.get('avg_response_time', 0), 0)
            }
            for day, stats in recent_days
        ]

        # Ten most frequently hit doctrines across the whole session
        top_doctrines = sorted(
            analytics_data.get('doctrinal_coverage', {}).items(),
            key=lambda item: item[1],
            reverse=True
        )[:10]

        return jsonify({
            'summary': {
                'total_queries': len(all_queries),
                'avg_quality_score': round(avg_quality, 1),
                'avg_sources_per_query': round(avg_sources, 1),
                'avg_response_time_ms': round(avg_response_time, 0),
                'session_start': analytics_data.get('session_start')
            },
            'phase_distribution': phase_counts,
            'daily_trends': daily_trends,
            'top_doctrines': [{'doctrine': d, 'count': c} for d, c in top_doctrines],
            'doctrinal_coverage': {
                'covered': len(analytics_data.get('doctrinal_coverage', {})),
                'total_available': len(LEGAL_DOCTRINES)
            },
            'document_stats': {
                'total_documents': len(vector_store.documents),
                'enriched_documents': sum(1 for d in vector_store.documents if d.get('enriched')),
                'unique_sources': len(set(d.get('source', '') for d in vector_store.documents))
            }
        })


@app.route('/analytics/query-log', methods=['GET'])
def get_query_log():
    """Return a paginated slice of the query log, most recent first."""
    with analytics_lock:
        limit = int(request.args.get('limit', 50))
        offset = int(request.args.get('offset', 0))

        all_queries = analytics_data.get('queries', [])
        # Newest-first order, then take the requested page
        page = all_queries[::-1][offset:offset + limit]

        return jsonify({
            'queries': page,
            'total': len(all_queries),
            'limit': limit,
            'offset': offset
        })


@app.route('/analytics/doctrinal-coverage', methods=['GET'])
def get_doctrinal_coverage():
    """Get detailed doctrinal coverage report"""
    with analytics_lock:
        coverage = analytics_data.get('doctrinal_coverage', {})

        # Lowercase each document's content once, up front, since every
        # doctrine probes the same corpus
        lowered_docs = [doc.get('content', '').lower() for doc in vector_store.documents]
        denominator = max(1, len(vector_store.documents))

        # One report entry per known doctrine, covered or not
        report = []
        for doctrine, terms in LEGAL_DOCTRINES.items():
            probe_terms = [t.lower() for t in terms[:5]]  # Check first 5 terms
            doc_count = sum(
                1 for content in lowered_docs
                if any(t in content for t in probe_terms)
            )

            report.append({
                'doctrine': doctrine,
                'query_count': coverage.get(doctrine, 0),
                'document_count': doc_count,
                'key_terms': terms[:5],
                'coverage_score': min(100, (doc_count / denominator) * 100 * 2)
            })

        # Sort by document coverage
        report.sort(key=lambda entry: entry['document_count'], reverse=True)

        return jsonify({
            'doctrines': report,
            'total_doctrines': len(LEGAL_DOCTRINES),
            'covered_in_queries': sum(1 for d in report if d['query_count'] > 0),
            'covered_in_documents': sum(1 for d in report if d['document_count'] > 0)
        })


@app.route('/analytics/export', methods=['GET'])
def export_analytics():
    """Return the complete raw analytics dataset as a JSON document."""
    with analytics_lock:
        snapshot = analytics_data
        return jsonify(snapshot)


# Initialize analytics
# Runs at import time; failures are logged but never block the service.
try:
    # NOTE(review): timedelta is not used in this visible section — presumably
    # init_analytics() relies on it as a module global; confirm before removing.
    from datetime import timedelta
    init_analytics()
except Exception as e:
    print(f"[Analytics] Init error: {e}")


# ============================================================================
# Main
# ============================================================================
if __name__ == '__main__':
    # Port is configurable via RAG_PORT; defaults to 5007 per the module header
    port = int(os.environ.get('RAG_PORT', 5007))
    print(f"[RAG] Starting Advanced Legal Document RAG Service on port {port}")
    print(f"[RAG] FAISS: {HAS_FAISS}, SentenceTransformers: {HAS_SENTENCE_TRANSFORMERS}, PyMuPDF: {HAS_PYMUPDF}")
    print(f"[RAG] Documents loaded: {len(vector_store.documents)}")
    # Fix: init_analytics() already ran at import time (guarded try block above);
    # calling it again here redundantly re-initialized analytics state and, being
    # unguarded, could crash startup on an error the import-time call tolerated.
    # Binding 0.0.0.0 exposes the service on all interfaces — intentional for a
    # local-service mesh, but worth noting.
    app.run(host='0.0.0.0', port=port, debug=False)