
import os
import json
import torch
import requests
import threading
import time
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from datetime import datetime

app = Flask(__name__)
# Enable CORS on all routes using flask_cors defaults.
# NOTE(review): defaults allow any origin — confirm that is intended here.
CORS(app)

# Configuration
# Hugging Face model id; weights are downloaded/cached under MODEL_PATH.
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
MODEL_PATH = "/opt/deepseek-env/models"
# Persistent admin chat memory (a JSON document).
MEMORY_FILE = "/var/www/html/leadgen/airagagent/data/admin_memory.json"
# Append-only per-query evaluation log, one JSON object per line.
QUERY_LOG_FILE = "/var/www/html/leadgen/airagagent/data/query_log.jsonl"
# Local RAG service; receives POST {"question": ...}.
RAG_SERVICE_URL = "http://127.0.0.1:5000/ask"
# Default listening port; DEEPSEEK_CHAT_PORT overrides it when run as a script.
PORT = 5004

# Grok API Configuration
# Read from XAI_API_KEY; an empty value disables routing to Grok.
GROK_API_KEY = os.getenv('XAI_API_KEY', '')
GROK_API_URL = "https://api.x.ai/v1/chat/completions"
GROK_MODEL = "grok-2-1212"

def classify_query(query_text, complex_threshold=2):
    """Decide whether a query should go to Grok (complex) or DeepSeek (local).

    Five "complex analysis" signals are checked:
    comparison words, why/explain-style words, relationship words, a
    multi-part question joined by " and ", and a long (>100 chars) question
    containing '?'. If at least ``complex_threshold`` fire, the query is
    routed to Grok; otherwise it goes to the local DeepSeek model.

    Keyword checks are substring tests on the lower-cased query, so inflected
    forms such as "compared" or "patterns" also match.

    Args:
        query_text: The raw user question.
        complex_threshold: Minimum number of complex signals for Grok.

    Returns:
        ``'grok'`` or ``'deepseek'``.
    """
    query_lower = query_text.lower()

    complex_signals = [
        any(word in query_lower for word in ('compare', 'difference', 'versus', 'vs', 'contrast')),
        any(word in query_lower for word in ('why', 'explain', 'analyze', 'synthesize')),
        any(word in query_lower for word in ('relationship', 'connection', 'pattern')),
        # Multi-part question; checked case-insensitively like the other
        # signals so " AND " counts too.
        ' and ' in query_lower,
        # Long complex question
        '?' in query_text and len(query_text) > 100,
    ]

    # Simple/factual queries and everything else both go to the local model,
    # so only the complex-signal count affects the decision.
    return 'grok' if sum(complex_signals) >= complex_threshold else 'deepseek'

def log_query(query, response, sources, llm_used, response_time):
    """Append a JSON-lines record describing one answered query.

    Logging is best-effort: any failure is printed and swallowed so that
    it can never break the request being served.
    """
    try:
        entry = dict(
            timestamp=datetime.now().isoformat(),
            query=query,
            response_preview=(response[:200] if response else ''),
            num_sources=(len(sources) if sources else 0),
            sources=([item.get('source', 'Unknown') for item in sources] if sources else []),
            llm_used=llm_used,
            response_time_sec=round(response_time, 2),
            query_length=len(query),
            response_length=(len(response) if response else 0),
        )

        # Make sure the directory holding the log file is present.
        log_dir = os.path.dirname(QUERY_LOG_FILE)
        os.makedirs(log_dir, exist_ok=True)

        with open(QUERY_LOG_FILE, 'a') as handle:
            handle.write(json.dumps(entry) + '\n')
    except Exception as exc:
        print(f"⚠️  Query logging failed: {exc}")

def call_grok_api(messages, max_tokens=1024, *, temperature=0.7, timeout=30):
    """Send a non-streaming chat-completion request to the Grok (x.ai) API.

    Args:
        messages: Chat messages as a list of {"role", "content"} dicts.
        max_tokens: Upper bound on tokens in the completion.
        temperature: Sampling temperature forwarded to the API.
        timeout: Seconds before the HTTP request is abandoned.

    Returns:
        The assistant message text, or None if the API key is unset or the
        request/response handling fails. Callers treat None as "fall back to
        the local model".
    """
    if not GROK_API_KEY:
        print("⚠️  Grok API key not set, falling back to DeepSeek")
        return None

    headers = {
        "Authorization": f"Bearer {GROK_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": GROK_MODEL,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": False
    }

    try:
        response = requests.post(GROK_API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        data = response.json()
        return data['choices'][0]['message']['content']
    except requests.HTTPError as e:
        # Include (a prefix of) the API's error body; the status line alone
        # rarely says why the request was rejected.
        body = e.response.text[:500] if e.response is not None else ''
        print(f"❌ Grok API Error: {e} {body}")
        return None
    except Exception as e:
        # Deliberately broad: any failure (network, JSON, unexpected schema)
        # should trigger the DeepSeek fallback rather than break the request.
        print(f"❌ Grok API Error: {e}")
        return None

class MemoryManager:
    """JSON-file store for the admin chat history.

    The file holds a dict with "history" (a list of role/content/timestamp
    messages), "preferences" and "system_context" keys. Every read and write
    happens under ``self.lock``. add_message/clear_history hold the lock across
    the whole load-modify-save cycle, so concurrent threads in this process
    cannot overwrite each other's updates. Separate processes are not
    coordinated.
    """

    def __init__(self, filepath, max_history=50):
        """Bind to *filepath*, creating the file if needed.

        Args:
            filepath: Path of the JSON memory file.
            max_history: Number of most recent messages add_message keeps.
        """
        self.filepath = filepath
        self.max_history = max_history
        # Re-entrant: add_message/clear_history take the lock and then call
        # load_memory/save_memory, which acquire it again.
        self.lock = threading.RLock()
        self.ensure_memory_file()

    @staticmethod
    def _empty_memory():
        """Return a fresh, empty memory document."""
        return {"history": [], "preferences": {}, "system_context": {}}

    def ensure_memory_file(self):
        """Create the parent directory and an empty memory file if missing."""
        directory = os.path.dirname(self.filepath)
        # dirname() is '' for a bare filename, and os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with self.lock:
            if not os.path.exists(self.filepath):
                self._write_file(self._empty_memory())

    def load_memory(self):
        """Return the stored memory dict, or an empty one if it can't be read.

        Missing top-level keys are filled with empty defaults, so callers can
        index e.g. ``memory["history"]`` directly.
        """
        with self.lock:
            try:
                with open(self.filepath, 'r') as f:
                    data = json.load(f)
            except Exception as e:
                # Best-effort: an unreadable/corrupt file behaves like empty memory.
                print(f"⚠️  Could not read memory file {self.filepath}: {e}")
                return self._empty_memory()
        if not isinstance(data, dict):
            print(f"⚠️  Memory file {self.filepath} is not a JSON object; ignoring it")
            return self._empty_memory()
        for key, default in self._empty_memory().items():
            data.setdefault(key, default)
        return data

    def save_memory(self, data):
        """Replace the memory file's contents with *data*."""
        with self.lock:
            self._write_file(data)

    def _write_file(self, data):
        """Write *data* as JSON via a temp file + os.replace (caller holds lock).

        os.replace swaps the file in one step. A crash mid-write therefore
        leaves the previous contents intact, rather than a truncated file that
        would later load as empty memory.
        """
        # The pid suffix keeps temp names distinct if several processes share the file.
        tmp_path = f"{self.filepath}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, 'w') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.filepath)
        except Exception:
            # Don't leave a stray temp file behind.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def add_message(self, role, content):
        """Append a timestamped message and trim history to ``max_history``."""
        with self.lock:
            memory = self.load_memory()
            history = memory["history"]
            if not isinstance(history, list):
                history = memory["history"] = []
            history.append({
                "role": role,
                "content": content,
                "timestamp": datetime.now().isoformat()
            })
            # Drop the oldest entries beyond the limit.
            overflow = len(history) - self.max_history
            if overflow > 0:
                del history[:overflow]
            self.save_memory(memory)

    def get_history(self):
        """Return the list of stored messages, oldest first."""
        return self.load_memory().get("history", [])

    def clear_history(self):
        """Remove all stored messages, keeping preferences and system_context."""
        with self.lock:
            memory = self.load_memory()
            memory["history"] = []
            self.save_memory(memory)

class DeepSeekChatService:
    """Chat service that answers with the local DeepSeek model or Grok.

    Each request can pull document context from the RAG service. It then
    builds a chat prompt (system prompt + recent history + question), streams
    the answer, appends a record to the query log, and saves the exchange to
    persistent memory.
    """

    # Number of RAG sources folded into the prompt context.
    MAX_RAG_SOURCES = 15

    # System prompt shared by every request (dynamic status may be appended).
    _BASE_SYSTEM_PROMPT = """You are an expert research assistant with access to a curated knowledge base of rare and valuable documents.

CORE MISSION: Provide comprehensive, insightful answers by deeply synthesizing information from the provided documents.

ANSWER QUALITY STANDARDS:
1. CITE SOURCES NATURALLY: Reference documents by their actual title (e.g., "According to 'The Lost Civilization of Lemuria'...")
2. SYNTHESIZE ACROSS SOURCES: When multiple documents are relevant, weave their insights together into a coherent narrative
3. INCLUDE SPECIFIC DETAILS: Quote key passages, cite specific claims, reference page numbers or sections when available
4. PROVIDE DEPTH: Don't just summarize - explain significance, draw connections, highlight patterns across sources
5. BE AUTHORITATIVE: Speak with confidence about what the documents contain
6. STRUCTURE CLEARLY: Use paragraphs, bullet points, or numbered lists for complex answers

IMPORTANT: The documents provided are from a verified knowledge base. Base your answers EXCLUSIVELY on this content.
If asked about something not covered in the documents, say so clearly.

When you receive document excerpts, analyze them thoroughly and provide substantive answers that would satisfy a serious researcher."""

    def __init__(self):
        print(f"Loading model from {MODEL_PATH}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_PATH)
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                cache_dir=MODEL_PATH,
                torch_dtype=torch.bfloat16,
                device_map="auto"
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Failed to load model: {e}")
            # Fallback for testing/setup without model
            self.model = None
            self.tokenizer = None

        self.memory_manager = MemoryManager(MEMORY_FILE)

    # ------------------------------------------------------------------ RAG

    @staticmethod
    def _clean_source_name(raw_source):
        """Turn a file name like 'lost_cities-vol1.pdf' into 'Lost Cities Vol1'."""
        # Remove file extensions and turn underscores/hyphens into spaces.
        name = raw_source.replace('.pdf', '').replace('.txt', '')
        name = name.replace('_', ' ').replace('-', ' ')
        return ' '.join(word.capitalize() for word in name.split())

    @classmethod
    def _format_rag_source(cls, source):
        """Render one RAG source dict as a markdown snippet for the prompt."""
        raw_source = source.get('source') or 'Unknown Document'
        doc_info = f"\n**📚 {cls._clean_source_name(str(raw_source))}**\n"
        if source.get('title') and source.get('title') != raw_source:
            doc_info += f"Section: {source['title']}\n"
        if source.get('summary'):
            doc_info += f"{source['summary']}\n"
        if source.get('key_points'):
            # str() so one non-string key point can't discard the whole context.
            key_points = ', '.join(str(point) for point in source['key_points'][:5])
            doc_info += f"Key Points: {key_points}\n"
        excerpt = source.get('preview') or source.get('content')
        if excerpt:
            doc_info += f"Excerpt: \"{str(excerpt)[:400]}...\"\n"
        return doc_info

    @classmethod
    def _build_rag_context(cls, sources):
        """Join formatted sources with '---' separators; None if none are usable."""
        parts = [cls._format_rag_source(s) for s in sources if isinstance(s, dict)]
        return "\n---\n".join(parts) if parts else None

    def _fetch_rag(self, query):
        """Query the RAG service and return ``(context_text, sources_used)``.

        Returns ``(None, [])`` when the service errors, responds non-200, or
        provides no usable sources.
        """
        print(f"🔍 RAG Query: {query}")
        try:
            # RAG queries can take 100+ seconds, so use long timeout
            response = requests.post(RAG_SERVICE_URL, json={"question": query}, timeout=120)
            print(f"📡 RAG Response Status: {response.status_code}")
            if response.status_code != 200:
                return None, []
            data = response.json()
            print(f"📦 RAG Data Keys: {list(data.keys())}")
            # Check if sources exist (don't require source=='rag' as it's not always set)
            top_sources = (data.get('sources') or [])[:self.MAX_RAG_SOURCES]
            sources = [s for s in top_sources if isinstance(s, dict)]
            result = self._build_rag_context(sources)
            if result:
                print(f"✅ RAG Context Built: {len(result)} chars from {len(sources)} docs")
                return result, sources
            return None, []
        except Exception as e:
            print(f"❌ RAG Error: {e}")
            return None, []

    def get_rag_context(self, query):
        """Return formatted RAG context for *query*, or None if none is available."""
        context, _sources = self._fetch_rag(query)
        return context

    # --------------------------------------------------------------- prompt

    def _build_conversation(self, user_message, history, rag_context, system_context):
        """Assemble chat messages: system prompt, recent history, then the question."""
        system_prompt = self._BASE_SYSTEM_PROMPT
        # Inject dynamic system context (stats, user info) if provided
        if system_context:
            system_prompt += f"\n\nCurrent System Status:\n{json.dumps(system_context, indent=2)}"

        conversation = [{"role": "system", "content": system_prompt}]

        # Last 10 stored messages; user and assistant turns are stored separately.
        # NOTE(review): assistant entries are the raw model output; the /chat
        # route strips reasoning only from what it sends to the client.
        conversation.extend(
            {"role": msg["role"], "content": msg["content"]} for msg in history[-10:]
        )

        # Add RAG context if available - make it very explicit
        if rag_context:
            context_msg = f"""IMPORTANT: I am providing you with VERIFIED DOCUMENTS from the knowledge base.
You MUST base your answer EXCLUSIVELY on this information. Do NOT use external knowledge.

=== DOCUMENT CONTEXT START ===
{rag_context}
=== DOCUMENT CONTEXT END ===

Now, using ONLY the information from the documents above, answer this question.
Start your response by citing which document you're using.

User's Question: {user_message}"""
            conversation.append({"role": "user", "content": context_msg})
        else:
            conversation.append({"role": "user", "content": user_message})
        return conversation

    # ----------------------------------------------------------- generation

    def generate_response(self, user_message, use_rag=True, system_context=None):
        """Stream an answer to *user_message* and record the exchange.

        Args:
            user_message: The user's question.
            use_rag: If truthy, add document context from the RAG service.
            system_context: Optional JSON-serializable data appended to the
                system prompt as "Current System Status".

        Yields:
            Response text chunks (a single chunk when Grok answers).

        After the last chunk, the query is appended to the query log. Both the
        user message and the full response are then saved to memory.
        """
        start_time = time.time()

        if not self.model:
            yield "Model not loaded correctly. Please check server logs."
            return

        # 0. Classify Query - decide which LLM to use
        llm_choice = classify_query(user_message)
        print(f"🤖 Query Classification: {llm_choice.upper()}")

        # 1. Get Context (Memory + RAG)
        print(f"📥 New Request | use_rag={use_rag} | message: {user_message[:50]}")
        history = self.memory_manager.get_history()
        rag_context = None
        sources_list = []

        if use_rag:
            print("🔄 RAG is enabled, querying...")
            # Sources come back with the context so they can be logged below.
            rag_context, sources_list = self._fetch_rag(user_message)
            if rag_context:
                print(f"✅ RAG Context retrieved successfully")
            else:
                print(f"⚠️ RAG returned no context")
        else:
            print("❌ RAG is DISABLED")

        # 2. Construct Prompt
        conversation = self._build_conversation(user_message, history, rag_context, system_context)

        # 3. Generate - Route to appropriate LLM
        chunks = []
        if llm_choice == 'grok' and GROK_API_KEY:
            # Use Grok for complex synthesis
            print("🚀 Routing to Grok API for complex analysis...")
            grok_response = call_grok_api(conversation, max_tokens=1024)
            if grok_response:
                chunks.append(grok_response)
                yield grok_response
                print(f"✅ Grok response: {len(grok_response)} chars")
            else:
                # Fallback to DeepSeek if Grok fails
                print("⚠️ Grok failed, falling back to DeepSeek")
                llm_choice = 'deepseek_fallback'
                for chunk in self._generate_with_deepseek(conversation):
                    chunks.append(chunk)
                    yield chunk
        else:
            # Use DeepSeek for factual/quick queries
            print(f"⚡ Using DeepSeek (local) for {llm_choice} query...")
            for chunk in self._generate_with_deepseek(conversation):
                chunks.append(chunk)
                yield chunk
        full_response = "".join(chunks)

        # 4. Log query for evaluation
        response_time = time.time() - start_time
        log_query(user_message, full_response, sources_list, llm_choice, response_time)

        # 5. Save to Memory
        self.memory_manager.add_message("user", user_message)
        self.memory_manager.add_message("assistant", full_response)

    def _generate_with_deepseek(self, conversation, max_new_tokens=512, temperature=0.7):
        """Stream text from the local model for a chat-formatted *conversation*.

        ``model.generate`` runs in a background thread that feeds a
        TextIteratorStreamer, and this generator yields decoded text as it
        arrives. If generation raises, the exception is re-raised here instead
        of leaving the consumer blocked on the streamer.

        NOTE(review): 512 new tokens may cut off an R1-style reasoning model
        before its final answer — consider raising max_new_tokens.
        """
        input_text = self.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)

        # Decode options must be plain keyword arguments: the streamer
        # forwards **kwargs to tokenizer.decode(). Passing them wrapped in a
        # 'decode_kwargs' dict meant skip_special_tokens was never applied.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # temperature only takes effect when sampling
            temperature=temperature,
        )
        errors = []

        def _run_generation():
            try:
                self.model.generate(**generation_kwargs)
            except Exception as exc:
                errors.append(exc)
                # Signal end-of-stream; otherwise the loop below would wait
                # forever for text that will never arrive.
                streamer.end()

        thread = threading.Thread(target=_run_generation)
        thread.start()

        for new_text in streamer:
            yield new_text

        thread.join()
        if errors:
            raise errors[0]
# Module-level service instance shared by all routes. Constructing it loads
# the model at import time; on failure it logs the error and keeps model=None.
chat_service = DeepSeekChatService()

@app.route('/health', methods=['GET'])
def health():
    """Report that the server is up and which model it is configured for.

    ``model_loaded`` is False when the local model failed to load at startup.
    In that state, chat requests get an error message instead of an answer.
    """
    return jsonify({
        "status": "running",
        "model": MODEL_NAME,
        "model_loaded": chat_service.model is not None,
    })

@app.route('/chat', methods=['POST'])
def chat():
    """Answer a chat message and return the reply as plain text.

    Expects a JSON body: ``{"message": str, "use_rag": bool = True,
    "system_context": any = None}``. The full model output is collected first
    so that reasoning can be stripped before anything is sent. The client
    therefore receives one chunk rather than a token stream.

    Returns 400 if ``message`` is missing, empty/whitespace, or not a string.
    """
    # silent=True: a missing or non-JSON body yields None instead of raising,
    # so the client gets the 400 below rather than an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    message = data.get('message', '')
    use_rag = data.get('use_rag', True)
    system_context = data.get('system_context', None)

    if not isinstance(message, str) or not message.strip():
        return jsonify({"error": "Message required"}), 400

    # Import up front so a missing module fails the request before the model
    # spends time generating a full answer.
    from answer_synthesis import strip_reasoning

    def generate():
        parts = []
        for chunk in chat_service.generate_response(message, use_rag, system_context):
            parts.append(chunk)
        # Strip reasoning from response before sending
        yield strip_reasoning("".join(parts))

    return Response(stream_with_context(generate()), mimetype='text/plain')

@app.route('/history', methods=['GET'])
def get_history():
    """Return the stored conversation messages as JSON."""
    stored_messages = chat_service.memory_manager.get_history()
    return jsonify(stored_messages)

@app.route('/history', methods=['DELETE'])
def clear_history():
    """Erase the stored conversation messages and confirm with a status."""
    manager = chat_service.memory_manager
    manager.clear_history()
    return jsonify(status="cleared")

@app.route('/debug', methods=['POST'])
def debug():
    """Debug endpoint to see what's being received and processed.

    Echoes the parsed request, including the Python type of ``use_rag``. This
    helps spot e.g. a string "false", which is truthy. When ``use_rag`` is
    truthy, it also runs a live RAG lookup and previews the built context.

    NOTE(review): this exposes RAG content without authentication; the server
    binds to 127.0.0.1 by default — keep it that way or protect this route.
    """
    # silent=True: a missing/non-JSON body gives None instead of raising.
    data = request.get_json(silent=True)
    payload = data if isinstance(data, dict) else {}
    message = payload.get('message', 'test query')
    use_rag = payload.get('use_rag', True)

    debug_info = {
        "received_data": {
            "message": message,
            "use_rag": use_rag,
            "use_rag_type": type(use_rag).__name__,
            "raw_json": data
        },
        "rag_test": {}
    }

    # Test RAG query if enabled (same truthiness test /chat applies)
    if use_rag:
        try:
            rag_context = chat_service.get_rag_context(message)
            debug_info["rag_test"] = {
                "status": "success" if rag_context else "empty",
                "context_length": len(rag_context) if rag_context else 0,
                "context_preview": rag_context[:500] if rag_context else None
            }
        except Exception as e:
            debug_info["rag_test"] = {
                "status": "error",
                "error": str(e)
            }
    else:
        debug_info["rag_test"] = {"status": "disabled"}

    return jsonify(debug_info)

if __name__ == '__main__':
    # Loopback by default; env vars let deployments pick interface and port.
    bind_host = os.environ.get("DEEPSEEK_CHAT_HOST", "127.0.0.1")
    bind_port = int(os.environ.get("DEEPSEEK_CHAT_PORT", PORT))
    app.run(host=bind_host, port=bind_port)
