"""
Answer Synthesis Pipeline - 3-Layer Enhancement for RAG Quality
Fixes: Answer incoherence, chunk mixing, reasoning leaks
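
Layer 1: filter, clean, and dedupe retrieved chunks (score threshold, metadata artifacts)
Layer 2: group chunks by source document and build grounded verbatim excerpts
Layer 3: synthesize a single cited answer via the Grok API (structured-context fallback)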
"""
import re
import requests
import os
import hashlib
from typing import List, Dict, Optional

# Import Grok API config from existing module
try:
    from grok_api import GROK_API_KEY, GROK_MODEL
    GROK_API_URL = "https://api.x.ai/v1/chat/completions"
except ImportError:
    GROK_API_KEY = os.getenv('XAI_API_KEY', '')
    GROK_API_URL = "https://api.x.ai/v1/chat/completions"
    GROK_MODEL = "grok-2-1212"

# Quality thresholds
MIN_CHUNK_SCORE = 0.7  # Stricter than default 0.5
MAX_CHUNKS_PER_DOC = 3


def call_grok_api(
    prompt: str,
    system_prompt: Optional[str] = None,
    max_tokens: int = 900,
    temperature: float = 0.2,
    timeout: int = 60
) -> Optional[str]:
    """Call Grok API for synthesis (grounded final answer)."""
    if not GROK_API_KEY:
        return None
    
    try:
        headers = {
            "Authorization": f"Bearer {GROK_API_KEY}",
            "Content-Type": "application/json"
        }

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload = {
            "model": GROK_MODEL,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False
        }
        
        response = requests.post(GROK_API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        
        return response.json()['choices'][0]['message']['content']
    except Exception as e:
        print(f"Grok API Error: {e}")
        return None


def strip_reasoning(text: str) -> str:
    """
    Remove AI reasoning/thinking from response.
    Handles DeepSeek R1 model's verbose reasoning about documents.
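
    Example: '<think>Checking the sources...</think>The main themes are X and Y.'
    becomes 'The main themes are X and Y.'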
    """
    if not text:
        return text
    
    # Handle DeepSeek R1's <think>...</think> tags
    # Remove everything before </think> if present
    if '</think>' in text:
        text = text.split('</think>')[-1].strip()
    
    # Also handle <think> without closing tag (incomplete)
    if '<think>' in text and '</think>' not in text:
        text = text.split('<think>')[0].strip()
    
    # Strip DeepSeek end-of-sentence tokens
    text = text.replace('<｜end▁of▁sentence｜>', '').strip()
    text = text.replace('<|end_of_sentence|>', '').strip()
    
    # Common reasoning patterns to remove (line starters)
    reasoning_starters = [
        r"^(Okay|Alright|Let me|I'll|I need to|Looking at|Based on|First,? I)",
        r"^(The user is asking|The question asks|To answer this)",
        r"^(I see that|I notice that|I can see|I found)",
        r"^(Checking|Analyzing|Reviewing|Searching)",
        r"^(First, there's|Document \d|Now,? looking at)",
        r"^(Putting it all together|So,? from this)",
        r"^(That's about|That's not relevant|Not related)",
        r"^(Also unrelated|Again not relevant)",
    ]
    
    # Find the actual answer section - look for clear answer markers
    answer_markers = [
        "The answer is",
        "In summary,",
        "To summarize,",
        "In conclusion,",
        "Therefore,",
        "The book is about",
        "William Cooper wrote",
        "The main themes are",
        "Based on the documents:",
        "**From ",  # Our structured format
    ]
    
    # Try to find where the real answer starts
    text_lower = text.lower()
    answer_start = -1
    for marker in answer_markers:
        pos = text_lower.find(marker.lower())
        if pos != -1:
            if answer_start == -1 or pos < answer_start:
                answer_start = pos
    
    # If we found an answer marker, extract from there
    if 0 <= answer_start < len(text) // 2:
        return text[answer_start:].strip()
    
    # Otherwise, filter line by line
    lines = text.split('\n')
    clean_lines = []
    skip_reasoning = True  # Still scanning the leading reasoning region
    
    for line in lines:
        line_stripped = line.strip()
        
        # Skip empty lines at the start
        if skip_reasoning and not line_stripped:
            continue
        
        # Check if the line matches a reasoning starter
        is_reasoning = any(re.match(pattern, line_stripped, re.IGNORECASE) for pattern in reasoning_starters)
        
        # Also detect document analysis patterns
        if re.match(r'^Document \d+ (is|talks|discusses|mentions)', line_stripped, re.IGNORECASE):
            is_reasoning = True
        
        if is_reasoning:
            continue
        
        # Keep the line once we reach substantive content; short fragments that
        # appear before the first substantial line are treated as leftover reasoning.
        if len(line_stripped) > 50 or not skip_reasoning:
            skip_reasoning = False
            clean_lines.append(line)
    
    cleaned = '\n'.join(clean_lines).strip()
    
    # If we stripped almost everything, use a different approach
    if len(cleaned) < 100 and len(text) > 200:
        # Find the last paragraph as fallback (often the conclusion)
        paragraphs = text.split('\n\n')
        for para in reversed(paragraphs):
            para = para.strip()
            if len(para) > 100 and not any(re.match(p, para, re.IGNORECASE) for p in reasoning_starters):
                return para
        # Ultimate fallback - return original
        return text
    
    return cleaned if cleaned else text


def clean_chunk_text(text: str) -> str:
    """Remove metadata artifacts from chunk text"""
    # Remove chunk references like "(Chunk 72)"
    text = re.sub(r'\(Chunk \d+\)', '', text)
    # Remove "Adjacent Context:" markers
    text = re.sub(r'Adjacent Context:.*?(?:\n|$)', '', text, flags=re.IGNORECASE)
    # Remove excessive whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def _chunk_fingerprint(text: str) -> str:
    """Stable-ish fingerprint for deduping near-identical chunks."""
    normalized = re.sub(r'\s+', ' ', (text or '').strip().lower())
    if not normalized:
        return ''
    # Hash a prefix to keep it fast and stable
    return hashlib.md5(normalized[:600].encode('utf-8')).hexdigest()


def dedupe_chunks(chunks: List[Dict]) -> List[Dict]:
    """Remove duplicate/near-identical chunks to reduce repetition in synthesis."""
    seen = set()
    deduped = []
    for chunk in chunks:
        fp = _chunk_fingerprint(chunk.get('content', ''))
        if not fp:
            continue
        if fp in seen:
            continue
        seen.add(fp)
        deduped.append(chunk)
    return deduped


def filter_chunks(chunks: List[Dict], min_score: float = MIN_CHUNK_SCORE) -> List[Dict]:
    """Filter chunks by score and quality"""
    filtered = []
    for chunk in chunks:
        score = chunk.get('score', 0)
        if score >= min_score:
            chunk['content'] = clean_chunk_text(chunk.get('content', ''))
            filtered.append(chunk)
    return filtered


def group_chunks_by_document(chunks: List[Dict]) -> Dict[str, List[Dict]]:
    """Group chunks by their source document"""
    grouped = {}
    for chunk in chunks:
        source = chunk.get('metadata', {}).get('source', 'Unknown')
        # Clean source name
        source = re.sub(r'\.pdf$|\.txt$', '', source, flags=re.IGNORECASE)
        source = source.replace('_', ' ').replace('-', ' ')
        
        if source not in grouped:
            grouped[source] = []
        grouped[source].append(chunk)
    
    # Limit chunks per document
    for source in grouped:
        grouped[source] = sorted(grouped[source], key=lambda x: x.get('score', 0), reverse=True)[:MAX_CHUNKS_PER_DOC]
    
    return grouped


def summarize_document_chunks(doc_name: str, chunks: List[Dict], query: str) -> str:
    """
    Build a compact, grounded snippet from top chunks.
    We prefer verbatim excerpts (trimmed) over weak extractive summaries.
    """
    if not chunks:
        return ""

    # Use top chunk excerpts (already limited per-doc by caller).
    excerpts = []
    seen = set()
    for chunk in chunks:
        content = clean_chunk_text(chunk.get('content', ''))
        if len(content) < 40:
            continue
        excerpt = content[:520].strip()
        fp = _chunk_fingerprint(excerpt)
        if fp and fp in seen:
            continue
        if fp:
            seen.add(fp)
        excerpts.append(excerpt)
        if len(excerpts) >= 2:
            break

    if not excerpts:
        return ""

    bullets = "\n".join([f"- {e}" for e in excerpts])
    return bullets


def enhance_retrieved_context(query: str, raw_chunks: List[Dict]) -> str:
    """
    LAYER 1-2: Transform raw chunks into clean, organized context.
    
    Args:
        query: User's question
        raw_chunks: Raw chunks from vector search
        
    Returns:
        Enhanced, organized context string
    """
    # LAYER 0: Dynamic document name extraction and boosting
    query_lower = query.lower()
    query_words = set(re.findall(r'\b\w+\b', query_lower))
    
    # Extract all unique document names from chunks
    doc_names = set()
    for chunk in raw_chunks:
        source = chunk.get('metadata', {}).get('source', '')
        if source:
            doc_names.add(source)
    
    # Build dynamic keyword-to-document mapping
    # Extract keywords from each document name
    doc_keywords = {}  # {keyword: doc_name}
    for doc_name in doc_names:
        # Clean and tokenize document name
        clean_name = re.sub(r'\.pdf$|\.txt$', '', doc_name, flags=re.IGNORECASE)
        clean_name = clean_name.replace('_', ' ').replace('-', ' ')
        # Extract significant words (3+ chars, not common words)
        stop_words = {'the', 'and', 'for', 'pdf', 'com', 'org', 'www', 'oceanofpdf', 'by'}
        words = [w.lower() for w in re.findall(r'\b\w{3,}\b', clean_name) if w.lower() not in stop_words]
        
        for word in words:
            if word not in doc_keywords:
                doc_keywords[word] = []
            doc_keywords[word].append(doc_name)
    
    # Find which documents the query is asking about
    matching_docs = set()
    for query_word in query_words:
        if query_word in doc_keywords:
            matching_docs.update(doc_keywords[query_word])
        # Also check partial matches (e.g., "cooper" in "William_Cooper_...");
        # require reasonably long words so stop words do not match every document
        if len(query_word) > 3:
            for keyword, docs in doc_keywords.items():
                if query_word in keyword or keyword in query_word:
                    matching_docs.update(docs)
    
    # Layer 1: Filter and clean, boost matching documents
    filtered = []
    for chunk in raw_chunks:
        score = chunk.get('score', 0)
        source = chunk.get('metadata', {}).get('source', '')
        
        # Boost score if document matches query
        boost = 0
        if source in matching_docs:
            boost = 5.0  # Strong boost for matching document
        
        adjusted_score = score + boost
        
        if adjusted_score >= 0.6 or boost > 0:  # Include if boosted or high score
            chunk['content'] = clean_chunk_text(chunk.get('content', ''))
            chunk['adjusted_score'] = adjusted_score
            filtered.append(chunk)
    
    # Sort by adjusted score
    filtered.sort(key=lambda x: x.get('adjusted_score', x.get('score', 0)), reverse=True)

    # Dedupe to reduce repeated passages from the same OCR/run
    filtered = dedupe_chunks(filtered)
    
    if not filtered:
        return "No sufficiently relevant information found in the knowledge base."
    
    # Layer 2: Group by document and summarize
    grouped = group_chunks_by_document(filtered[:18])  # Top N chunks after dedupe
    
    # Prioritize documents that match query
    ordered_docs = []
    other_docs = []
    for doc_name in grouped.keys():
        # Check if any keyword in doc_name matches query
        clean_doc = doc_name.lower()
        matches_query = any(qw in clean_doc for qw in query_words if len(qw) > 3)
        if matches_query:
            ordered_docs.append(doc_name)
        else:
            other_docs.append(doc_name)
    
    ordered_docs.extend(other_docs)  # Query-matching docs first
    
    context_parts = []
    for doc_name in ordered_docs:
        doc_chunks = grouped[doc_name]
        summary = summarize_document_chunks(doc_name, doc_chunks, query)
        if summary:
            context_parts.append(f"SOURCE: {doc_name}\n{summary}")
    
    if not context_parts:
        return "No sufficiently relevant information found in the knowledge base."
    
    return "\n\n".join(context_parts)


def synthesize_answer(query: str, enhanced_context: str, use_grok: bool = True) -> str:
    """
    LAYER 3: Generate final coherent answer using LLM.
    
    Args:
        query: User's question
        enhanced_context: Clean, organized context from Layer 1-2
        use_grok: Whether to use the Grok API (the structured context is returned as-is if False, if no API key is configured, or if no usable context exists)
        
    Returns:
        Final synthesized answer
    """
    # Always synthesize when we have a key + usable context.
    has_context = enhanced_context and "No sufficiently relevant information" not in enhanced_context

    if use_grok and GROK_API_KEY and has_context:
        system_prompt = (
            "You are CashHive RAG, an expert research assistant.\n"
            "You MUST follow these rules:\n"
            "- Use ONLY the provided SOURCES and their excerpts. Do not add outside facts.\n"
            "- Write a single coherent answer that merges the sources.\n"
            "- After factual claims, add citations like [SourceName]. Multiple citations allowed.\n"
            "- If sources conflict, explicitly note the conflict and cite both sides.\n"
            "- If the sources do not contain enough info, say what is missing.\n"
            "- Do NOT reveal chain-of-thought or hidden reasoning.\n"
        )

        synthesis_prompt = f"""QUESTION:
{query}

SOURCES (verbatim excerpts):
{enhanced_context}

Write the best possible grounded answer with citations."""

        grok_response = call_grok_api(
            prompt=synthesis_prompt,
            system_prompt=system_prompt,
            max_tokens=900,
            temperature=0.2,
            timeout=75
        )
        if grok_response:
            return strip_reasoning(grok_response)
    
    # Fallback: Format context as structured answer
    answer = f"{enhanced_context}"
    return answer


def process_rag_response(query: str, raw_response: Dict) -> Dict:
    """
    Main entry point: Process RAG response through 3-layer pipeline.
    
    Args:
        query: Original user query
        raw_response: Raw response from the RAG /ask endpoint (expects 'answer' and 'sources' keys)
        
    Returns:
        Enhanced response with synthesized answer
    """
    raw_chunks = raw_response.get('sources', [])
    original_answer = raw_response.get('answer', '')
    
    # Apply 3-layer enhancement
    enhanced_context = enhance_retrieved_context(query, raw_chunks)
    
    # Check if original answer needs enhancement
    needs_enhancement = (
        'Chunk' in original_answer or
        len(original_answer) < 50 or
        any(word in original_answer[:100].lower() for word in ['let me', 'looking at', 'i see'])
    )
    
    if needs_enhancement or True:  # Always enhance for now
        final_answer = synthesize_answer(query, enhanced_context)
    else:
        final_answer = strip_reasoning(original_answer)
    
    return {
        'answer': final_answer,
        'sources': raw_chunks[:5],  # Limit sources in response
        'enhanced': True,
        'source_count': len(raw_chunks)
    }


# For testing
if __name__ == "__main__":
    test_chunks = [
        {'content': '(Chunk 72) William Cooper discusses UFOs and government...', 'score': 0.85, 'metadata': {'source': 'William_Cooper_Pale_Horse.pdf'}},
        {'content': 'The secret societies have long hidden...', 'score': 0.72, 'metadata': {'source': 'William_Cooper_Pale_Horse.pdf'}},
        {'content': 'Hasheesh was introduced by doctors...', 'score': 0.45, 'metadata': {'source': 'Hasheesh_Eater.pdf'}},
    ]
    
    result = enhance_retrieved_context("What are the main themes in William Cooper's book?", test_chunks)
    print("Enhanced Context:")
    print(result)
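
    # Optional end-to-end sketch of the full pipeline on a mock /ask-style response.
    # This makes a live Grok call only if an API key is configured; otherwise
    # synthesize_answer() falls back to returning the structured context.
    mock_raw_response = {
        'answer': '(Chunk 72) Let me look at the documents...',
        'sources': test_chunks,
    }
    enhanced = process_rag_response("What are the main themes in William Cooper's book?", mock_raw_response)
    print("\nFinal Answer:")
    print(enhanced['answer'])
    print(f"Sources used: {enhanced['source_count']}")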
