#!/usr/bin/env python3
"""
Public OCR API Endpoint v2.0
With real-time progress tracking and RAG integration.
"""

import hmac
import json
import logging
import os
import re
import threading
import time
import uuid
from functools import wraps
from urllib.parse import urlparse

import requests
from flask import Flask, Response, jsonify, request
from flask_cors import CORS

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

app = Flask(__name__)
# NOTE(review): origins="*" lets any website call this API from a browser;
# restrict it if only known front-ends should have access.
CORS(app, origins="*")

# Configuration
# The fallback key is visible in source, so it is not a secret. It is kept so
# existing clients keep working, but running with it is logged as a warning.
_DEFAULT_API_KEY = "deepseek-ocr-2024-secret-key"
API_KEY = os.environ.get("OCR_API_KEY", _DEFAULT_API_KEY)
if API_KEY == _DEFAULT_API_KEY:
    logger.warning(
        "OCR_API_KEY is unset or equals the built-in default key; "
        "set a private key before exposing this API"
    )
# Base URL of the internal OCR service; OCR_SERVICE_URL overrides the default.
OCR_SERVICE_URL = os.environ.get("OCR_SERVICE_URL", "http://127.0.0.1:5003")
MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB max
# Enforce the limit: Flask answers 413 for request bodies larger than this.
app.config["MAX_CONTENT_LENGTH"] = MAX_FILE_SIZE
PROGRESS_DIR = "/tmp/ocr_progress"
TIMEOUT_SECONDS = 3600  # 60 minutes

# Ensure progress directory exists
os.makedirs(PROGRESS_DIR, exist_ok=True)

# In-memory job tracking: job_id -> progress record dict. It is mutated by
# request threads (the dev server runs with threaded=True) and by the PDF
# progress-polling thread.
active_jobs = {}

def require_api_key(f):
    """Decorate a view so it answers 401 unless the configured API key is sent.

    The key is taken from the ``X-API-Key`` header, falling back to the
    ``api_key`` query parameter, and must equal ``API_KEY``.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        # NOTE(review): keys sent as a query parameter tend to end up in
        # access logs and browser history; prefer the header where possible.
        api_key = request.headers.get("X-API-Key") or request.args.get("api_key")
        # compare_digest takes time independent of where the values differ,
        # so the key cannot be recovered by timing repeated guesses. Compare
        # bytes because compare_digest rejects non-ASCII str arguments.
        if not api_key or not hmac.compare_digest(
            api_key.encode("utf-8"), API_KEY.encode("utf-8")
        ):
            return jsonify({"error": "Invalid or missing API key"}), 401
        return f(*args, **kwargs)
    return decorated

def update_job_progress(job_id, **kwargs):
    """Create or update a job's progress record and persist it to disk.

    On first use, a record is created in ``active_jobs`` with default fields:
    status ``"initializing"``, zeroed counters and ``start_time`` set to now.
    ``kwargs`` are then merged in.

    ``elapsed_seconds`` is recomputed from ``start_time``. Once both
    ``total_pages`` and ``current_page`` are positive, ``estimated_remaining``
    is extrapolated from the mean time per page so far.

    The record is written to ``PROGRESS_DIR/<job_id>.json`` so it can still be
    read after the in-memory entry is gone. Persistence is best-effort:
    failures are logged, never raised.

    Args:
        job_id: Job identifier; also the progress file's base name.
        **kwargs: Fields to merge into the record (``status``, ``progress``,
            ``message``, ...).
    """
    job = active_jobs.get(job_id)
    if job is None:
        job = {
            "job_id": job_id,
            "status": "initializing",
            "progress": 0,
            "current_page": 0,
            "total_pages": 0,
            "chars_extracted": 0,
            "words_extracted": 0,
            "start_time": time.time(),
            "elapsed_seconds": 0,
            "estimated_remaining": 0,
            "filename": "",
            "message": "Initializing..."
        }
        active_jobs[job_id] = job

    job.update(kwargs)
    job["elapsed_seconds"] = round(time.time() - job["start_time"], 1)

    # Merged payloads may carry None for the page counters; treat that as
    # "unknown" (0) instead of letting the comparison raise TypeError.
    total_pages = job.get("total_pages") or 0
    current_page = job.get("current_page") or 0
    if total_pages > 0 and current_page > 0:
        seconds_per_page = job["elapsed_seconds"] / current_page
        job["estimated_remaining"] = round(seconds_per_page * (total_pages - current_page), 1)

    # job_id can be client-supplied: never let it name a path outside
    # PROGRESS_DIR (e.g. "../x" or "a/b").
    if os.path.basename(job_id) != job_id:
        logger.warning(f"Not persisting progress for unsafe job_id {job_id!r}")
        return

    # Save to file for persistence and RAG access
    progress_file = os.path.join(PROGRESS_DIR, f"{job_id}.json")
    # Unique temp name per write: two threads may persist the same job at once.
    tmp_file = f"{progress_file}.{uuid.uuid4().hex}.tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            # Serialize a copy so a concurrent update cannot mutate the
            # record mid-dump.
            json.dump(dict(job), f)
        # Atomic rename: readers see the old or the new file, never a
        # partially written one.
        os.replace(tmp_file, progress_file)
    except Exception as e:
        logger.error(f"Failed to save progress: {e}")
        try:
            os.unlink(tmp_file)
        except OSError:
            pass

def get_job_progress(job_id):
    """Return the progress record for *job_id*, or None if it is unknown.

    The in-memory record in ``active_jobs`` is preferred; it is returned as
    the live dict, not a copy. Otherwise ``PROGRESS_DIR/<job_id>.json`` is
    read. A missing, unreadable or corrupt file yields None.
    """
    # Single lookup: a separate "in" check could race with cleanup_old_jobs.
    job = active_jobs.get(job_id)
    if job is not None:
        return job

    # Ids containing path separators are never looked up on disk, so a
    # crafted id cannot read files outside PROGRESS_DIR.
    if os.path.basename(job_id) != job_id:
        return None

    progress_file = os.path.join(PROGRESS_DIR, f"{job_id}.json")
    try:
        with open(progress_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:
        # Corrupt or unreadable file: treat as unknown, but leave a trace.
        logger.warning(f"[{job_id}] Could not read progress file: {e}")
        return None

def cleanup_old_jobs(max_age_seconds=3600):
    """Evict jobs whose ``start_time`` is more than *max_age_seconds* ago.

    Each evicted job is removed from ``active_jobs`` and its progress file in
    ``PROGRESS_DIR`` is deleted if present. Jobs without a ``start_time`` are
    kept.

    Args:
        max_age_seconds: Age threshold in seconds (default: one hour).
    """
    now = time.time()
    # Iterate a snapshot: other request threads and the PDF progress poller
    # may add jobs meanwhile, which would make a live view raise
    # "dictionary changed size during iteration".
    stale = [
        job_id
        for job_id, job in list(active_jobs.items())
        if now - job.get("start_time", now) > max_age_seconds
    ]
    for job_id in stale:
        # pop() tolerates a concurrent cleanup having removed the entry.
        active_jobs.pop(job_id, None)
        progress_file = os.path.join(PROGRESS_DIR, f"{job_id}.json")
        try:
            os.unlink(progress_file)
        except FileNotFoundError:
            pass

@app.route("/", methods=["GET"])
def index():
    """Describe the service: version, served routes, auth scheme and timeout."""
    return jsonify({
        "service": "DeepSeek OCR API",
        "version": "2.0",
        # Lists only routes this module registers; there is no async PDF route.
        "endpoints": {
            "/health": "Check service status",
            "/ocr/pdf": "Extract text from PDF (POST, multipart/form-data)",
            "/ocr/progress/<job_id>": "Get job progress",
            "/ocr/result/<job_id>": "Get completed job result",
            "/ocr/image": "Extract text from image (POST, multipart/form-data)",
            "/ocr/url": "Extract text from webpage (POST, JSON)",
            "/ocr/status": "Get all active OCR jobs (for RAG)"
        },
        "authentication": "X-API-Key header or api_key query parameter",
        # Derived from TIMEOUT_SECONDS so the text cannot drift from the value.
        "timeout": f"{TIMEOUT_SECONDS} seconds ({TIMEOUT_SECONDS // 60} minutes)"
    })

@app.route("/health", methods=["GET"])
def health():
    """Public health check that proxies the internal OCR service's /health.

    Returns 200 with the internal JSON payload under ``ocr_service`` and the
    service's reported ``active_jobs`` count (0 if absent). If the service is
    unreachable or its reply is not JSON, returns 503 with the error text.
    """
    try:
        resp = requests.get(f"{OCR_SERVICE_URL}/health", timeout=5)
        data = resp.json()
        # Look up the string key "active_jobs"; fall back to 0 when the
        # payload is not an object or lacks the field.
        internal_jobs = data.get("active_jobs", 0) if isinstance(data, dict) else 0
        return jsonify({
            "status": "ok",
            "ocr_service": data,
            "api_ready": True,
            "active_jobs": internal_jobs
        })
    except Exception as e:
        return jsonify({
            "status": "error",
            "error": str(e),
            "api_ready": False
        }), 503

@app.route("/ocr/status", methods=["GET"])
def ocr_status():
    """List every in-memory OCR job (used by the RAG integration).

    Stale jobs are evicted first via ``cleanup_old_jobs``. The count and the
    list come from one snapshot, so they agree even while other threads add
    jobs concurrently.
    """
    # NOTE(review): this route has no @require_api_key, and completed PDF
    # records include the extracted text; confirm that exposure is intended.
    cleanup_old_jobs()
    jobs = list(active_jobs.values())
    return jsonify({
        "active_jobs": len(jobs),
        "jobs": jobs
    })

@app.route("/ocr/progress/<job_id>", methods=["GET"])
def ocr_progress(job_id):
    """Return progress for one job, preferring the internal OCR service.

    The internal ``/progress/<job_id>`` reply is passed through when it
    answers 200 with JSON. Otherwise the local record from
    ``get_job_progress`` is returned, or 404 if the job is unknown.
    """
    # NOTE(review): no @require_api_key here; completed PDF records include
    # the extracted text, so anyone who guesses an 8-char id can read it.
    try:
        resp = requests.get(f"{OCR_SERVICE_URL}/progress/{job_id}", timeout=5)
        if resp.status_code == 200:
            return jsonify(resp.json())
    except (requests.RequestException, ValueError) as e:
        # Service unreachable or reply not JSON: fall back to the local cache.
        logger.debug(f"[{job_id}] Internal progress lookup failed: {e}")

    progress = get_job_progress(job_id)
    if progress:
        return jsonify(progress)
    return jsonify({"error": "Job not found"}), 404

@app.route("/ocr/result/<job_id>", methods=["GET"])
def ocr_result(job_id):
    """Return the stored record of a finished OCR job.

    Responds 404 for unknown (or empty) records. Any job whose status is not
    ``"completed"`` gets 202 with its current status and progress; this
    includes jobs that ended as failed/error/timeout, so callers should
    inspect ``status`` rather than keep polling on 202 alone.
    """
    record = get_job_progress(job_id)
    if not record:
        return jsonify({"error": "Job not found"}), 404

    status = record.get("status")
    if status == "completed":
        return jsonify(record)

    pending = {
        "error": "Job not completed",
        "status": status,
        "progress": record.get("progress", 0),
    }
    return jsonify(pending), 202

# Client-supplied job ids become file names under PROGRESS_DIR, so they are
# limited to a conservative character set with no path separators.
_JOB_ID_PATTERN = re.compile(r"[A-Za-z0-9._-]{1,128}")


def _is_valid_job_id(job_id):
    """Return True if *job_id* is 1-128 ASCII letters, digits, '.', '_' or '-'."""
    return _JOB_ID_PATTERN.fullmatch(job_id) is not None


def _start_progress_poller(job_id, interval=2.0):
    """Start a daemon thread mirroring the OCR service's progress for *job_id*.

    Every *interval* seconds the thread GETs ``/progress/<job_id>`` from the
    internal service and merges a JSON-object reply into the local record via
    ``update_job_progress``. Returns ``(stop_event, thread)``; set the event,
    then join the thread, to stop polling.
    """
    stop_event = threading.Event()

    def poll():
        while not stop_event.is_set():
            try:
                poll_resp = requests.get(f"{OCR_SERVICE_URL}/progress/{job_id}", timeout=5)
                payload = poll_resp.json() if poll_resp.status_code == 200 else None
                # Re-check the flag after the request: a reply that arrives
                # after the caller stopped us is stale and must not overwrite
                # the final status recorded by the request thread.
                if isinstance(payload, dict) and not stop_event.is_set():
                    # A "job_id" key in the payload would collide with the
                    # positional argument (TypeError); our id is authoritative.
                    fields = {k: v for k, v in payload.items() if k != "job_id"}
                    update_job_progress(job_id, **fields)
            except Exception as e:
                # Thread boundary, best-effort: a missed poll only delays the
                # progress display, so log and keep polling.
                logger.debug(f"[{job_id}] Progress poll failed: {e}")
            # wait() rather than sleep() so stopping takes effect immediately.
            stop_event.wait(interval)

    thread = threading.Thread(target=poll, daemon=True)
    thread.start()
    return stop_event, thread


@app.route("/ocr/pdf", methods=["POST"])
@require_api_key
def ocr_pdf():
    """OCR an uploaded PDF synchronously, tracking progress under a job id.

    Form fields:
        file: The PDF (required; filename must end in ``.pdf``).
        job_id: Optional client-chosen id (1-128 letters, digits, ``.``,
            ``_`` or ``-``) so progress can be polled while this request
            blocks; a random 8-char id is generated otherwise.
        max_pages: Optional page cap forwarded to the OCR service (default 100).

    Returns:
        - 200 with text and counts on success.
        - 400 for bad input.
        - 422 with the service's payload when OCR fails.
        - 504 on timeout.
        - 500 on other errors.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400

    file = request.files["file"]
    if not file.filename:
        return jsonify({"error": "Empty filename"}), 400

    if not file.filename.lower().endswith(".pdf"):
        return jsonify({"error": "Only PDF files supported"}), 400

    job_id = request.form.get("job_id") or str(uuid.uuid4())[:8]
    # The id is used as a file name for progress persistence; reject anything
    # that could traverse outside the progress directory.
    if not _is_valid_job_id(job_id):
        return jsonify({
            "error": "Invalid job_id: use 1-128 letters, digits, '.', '_' or '-'"
        }), 400
    max_pages = request.form.get("max_pages", 100, type=int)

    update_job_progress(job_id,
        status="uploading",
        filename=file.filename,
        message=f"Uploading {file.filename}...",
        progress=5
    )

    try:
        # Forward the upload stream to the internal OCR service.
        files = {"file": (file.filename, file.stream, "application/pdf")}
        form_data = {"max_pages": max_pages, "job_id": job_id}

        update_job_progress(job_id,
            status="processing",
            message="Starting OCR processing...",
            progress=10
        )

        logger.info(f"[{job_id}] Processing PDF: {file.filename}, max_pages={max_pages}")

        stop_polling, poll_thread = _start_progress_poller(job_id)
        try:
            resp = requests.post(
                f"{OCR_SERVICE_URL}/ocr",
                files=files,
                data=form_data,
                timeout=TIMEOUT_SECONDS
            )
        finally:
            stop_polling.set()
            # Let an in-flight poll (bounded by its 5 s timeout) finish so it
            # cannot overwrite the final status recorded below.
            poll_thread.join(timeout=10)

        result = resp.json()

        if result.get("success"):
            text = result.get("text", "")
            words = text.split() if text else []
            char_count = len(text) if text else 0

            update_job_progress(job_id,
                status="completed",
                progress=100,
                message="OCR completed successfully!",
                text=text,
                word_count=len(words),
                char_count=char_count
            )

            # .get(): cleanup_old_jobs may evict a long-running job concurrently.
            elapsed = active_jobs.get(job_id, {}).get("elapsed_seconds", 0)
            return jsonify({
                "success": True,
                "job_id": job_id,
                "text": text,
                "word_count": len(words),
                "char_count": char_count,
                "filename": file.filename,
                "elapsed_seconds": elapsed
            })

        update_job_progress(job_id,
            status="failed",
            message=result.get("error", "OCR failed"),
            progress=0
        )
        return jsonify(result), 422

    except requests.Timeout:
        update_job_progress(job_id,
            status="timeout",
            message="OCR processing timed out",
            progress=0
        )
        return jsonify({"error": "OCR processing timed out", "job_id": job_id}), 504
    except Exception as e:
        update_job_progress(job_id,
            status="error",
            message=str(e),
            progress=0
        )
        logger.error(f"[{job_id}] OCR error: {e}")
        return jsonify({"error": str(e), "job_id": job_id}), 500

@app.route("/ocr/image", methods=["POST"])
@require_api_key
def ocr_image():
    """OCR a single uploaded image synchronously.

    The multipart field ``file`` must end in one of ``allowed_ext``. The image
    is forwarded to the internal ``/ocr_image`` endpoint (300 s timeout) and
    tracked under a fresh 8-char job id.

    Returns:
        - 200 with text and counts on success.
        - 400 for bad input.
        - 422 with the service's payload when it replies non-2xx or with
          ``success: false``.
        - 504 on timeout.
        - 500 (with ``job_id``) on other errors.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400

    file = request.files["file"]
    if not file.filename:
        return jsonify({"error": "Empty filename"}), 400

    allowed_ext = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"]
    if not file.filename.lower().endswith(tuple(allowed_ext)):
        return jsonify({"error": f"Supported formats: {allowed_ext}"}), 400

    job_id = str(uuid.uuid4())[:8]

    try:
        start_time = time.time()
        files = {"file": (file.filename, file.stream, file.content_type or "image/jpeg")}

        update_job_progress(job_id,
            status="processing",
            filename=file.filename,
            message="Processing image...",
            total_pages=1,
            current_page=1,
            progress=50
        )

        logger.info(f"[{job_id}] Processing image: {file.filename}")

        resp = requests.post(
            f"{OCR_SERVICE_URL}/ocr_image",
            files=files,
            timeout=300
        )

        result = resp.json()
        elapsed = time.time() - start_time

        # As in ocr_pdf, surface service failures instead of reporting an
        # empty success. A 2xx reply without a "success" key still counts as
        # success.
        if not resp.ok or result.get("success") is False:
            update_job_progress(job_id,
                status="failed",
                message=result.get("error", "OCR failed"),
                progress=0
            )
            return jsonify(result), 422

        text = result.get("text") or ""
        word_count = len(text.split())
        char_count = len(text)

        update_job_progress(job_id,
            status="completed",
            progress=100,
            message="Image OCR completed!",
            # Stored so /ocr/result/<job_id> can return it, as for PDFs.
            text=text,
            word_count=word_count,
            char_count=char_count
        )

        return jsonify({
            "success": True,
            "job_id": job_id,
            "text": text,
            "word_count": word_count,
            "char_count": char_count,
            "filename": file.filename,
            "elapsed_seconds": round(elapsed, 2)
        })

    except requests.Timeout:
        update_job_progress(job_id,
            status="timeout",
            message="OCR processing timed out",
            progress=0
        )
        return jsonify({"error": "OCR processing timed out", "job_id": job_id}), 504
    except Exception as e:
        update_job_progress(job_id, status="error", message=str(e), progress=0)
        logger.error(f"[{job_id}] Image OCR error: {e}")
        return jsonify({"error": str(e), "job_id": job_id}), 500

@app.route("/ocr/url", methods=["POST"])
@require_api_key
def ocr_url():
    """Have the OCR service screenshot a webpage and OCR it.

    Expects a JSON object body ``{"url": "..."}``. The URL is forwarded to the
    internal ``/scan_url`` endpoint (300 s timeout). The response echoes the
    service's ``success`` flag with HTTP 200. The job record is marked
    ``completed`` or ``failed`` accordingly.

    Returns:
        - 200 with text, counts and the service's ``success`` flag.
        - 400 for a missing or invalid URL.
        - 504 on timeout.
        - 500 (with ``job_id``) on other errors.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A non-object JSON body (list, number, ...) has no "url" field.
        data = {}
    url = data.get("url")

    if not url:
        return jsonify({"error": "No URL provided in JSON body"}), 400
    if not isinstance(url, str):
        return jsonify({"error": "URL must be a string"}), 400

    # The OCR service fetches this URL from inside our network, so refuse
    # non-web schemes such as file:// or javascript:. Scheme-less input is
    # still passed through for backward compatibility.
    # NOTE(review): http(s) URLs aimed at internal hosts are not blocked
    # (SSRF); restrict destinations if sensitive services are reachable.
    if urlparse(url).scheme.lower() not in ("", "http", "https"):
        return jsonify({"error": "Only http and https URLs are supported"}), 400

    job_id = str(uuid.uuid4())[:8]

    try:
        start_time = time.time()

        update_job_progress(job_id,
            status="processing",
            filename=url,
            message="Capturing webpage screenshot...",
            progress=30
        )

        logger.info(f"[{job_id}] Processing URL: {url}")

        resp = requests.post(
            f"{OCR_SERVICE_URL}/scan_url",
            json={"url": url},
            timeout=300
        )

        result = resp.json()
        elapsed = time.time() - start_time

        success = result.get("success", False)
        text = result.get("text") or ""
        word_count = len(text.split())
        char_count = len(text)

        if success:
            update_job_progress(job_id,
                status="completed",
                progress=100,
                message="URL OCR completed!",
                text=text,
                word_count=word_count,
                char_count=char_count
            )
        else:
            # Do not record a failed scan as completed.
            update_job_progress(job_id,
                status="failed",
                progress=0,
                message=result.get("error", "URL OCR failed")
            )

        return jsonify({
            "success": success,
            "job_id": job_id,
            "text": text,
            "word_count": word_count,
            "char_count": char_count,
            "url": url,
            "elapsed_seconds": round(elapsed, 2)
        })

    except requests.Timeout:
        update_job_progress(job_id,
            status="timeout",
            message="OCR processing timed out",
            progress=0
        )
        return jsonify({"error": "OCR processing timed out", "job_id": job_id}), 504
    except Exception as e:
        update_job_progress(job_id, status="error", message=str(e), progress=0)
        logger.error(f"[{job_id}] URL OCR error: {e}")
        return jsonify({"error": str(e), "job_id": job_id}), 500

if __name__ == "__main__":
    # Listen on loopback only; OCR_API_PORT overrides the default port 5004.
    listen_port = int(os.environ.get("OCR_API_PORT", 5004))
    startup_banner = f"Starting public OCR API v2.0 on port {listen_port}"
    logger.info(startup_banner)
    logger.info(f"Timeout: {TIMEOUT_SECONDS} seconds")
    app.run(
        host="127.0.0.1",
        port=listen_port,
        debug=False,
        threaded=True,
    )
