#!/usr/bin/env python3
"""
DeepSeek-OCR Primary Service v2.0
GPU-accelerated OCR with real-time progress tracking.
Runs on port 5003.

Endpoints:
- /health - Health check & model status
- /ocr - OCR a PDF file with progress tracking
- /ocr_image - OCR a single image
- /progress/<job_id> - Get job progress
- /scan_url - Capture & OCR a webpage
"""

import os
import sys
import torch
import tempfile
import time
import json
import threading
from pathlib import Path
from flask import Flask, request, jsonify
from PIL import Image
from pdf2image import convert_from_path
import logging
import requests
import subprocess

# Configure logging: timestamped, level-tagged lines on the root logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)

# Check flash_attn availability at startup.
# flash_attn is optional: if present, load_model() selects the
# 'flash_attention_2' implementation; otherwise it falls back to SDPA/eager.
FLASH_ATTN_AVAILABLE = False
try:
    import flash_attn
    FLASH_ATTN_AVAILABLE = True
    logger.info(f"flash_attn v{flash_attn.__version__} loaded successfully!")
except ImportError as e:
    logger.warning(f"flash_attn not available: {e}")

app = Flask(__name__)

# Model configuration.
# model/tokenizer are lazily populated by load_model() on first use so the
# service starts fast even without a GPU warm-up.
MODEL_NAME = 'deepseek-ai/DeepSeek-OCR'
model = None
tokenizer = None

# Progress tracking.
# PROGRESS_DIR holds per-job JSON snapshots (without page text) so progress
# survives outside this process; active_jobs maps job_id -> progress dict
# and is guarded by jobs_lock (Flask runs threaded).
PROGRESS_DIR = "/tmp/ocr_progress"
os.makedirs(PROGRESS_DIR, exist_ok=True)
active_jobs = {}
jobs_lock = threading.Lock()

def update_progress(job_id, **kwargs):
    """Merge ``kwargs`` into a job's progress record (thread-safe).

    Creates the record on first sight of ``job_id``, recomputes elapsed
    time, percentage, and an ETA based on average seconds per page, then
    persists a JSON snapshot (minus the bulky per-page text) to
    PROGRESS_DIR for out-of-process readers.
    """
    with jobs_lock:
        job = active_jobs.get(job_id)
        if job is None:
            # First update for this job: seed a full record.
            job = active_jobs[job_id] = {
                "job_id": job_id,
                "status": "initializing",
                "progress": 0,
                "current_page": 0,
                "total_pages": 0,
                "chars_extracted": 0,
                "pages_text": {},
                "start_time": time.time(),
                "elapsed_seconds": 0,
                "estimated_remaining": 0,
                "message": "Initializing..."
            }

        job.update(kwargs)
        job["elapsed_seconds"] = round(time.time() - job["start_time"], 1)

        total = job["total_pages"]
        done = job["current_page"]

        # Derive percentage and ETA once we know how many pages exist.
        if total > 0:
            job["progress"] = round((done / total) * 100, 1)

            if done > 0:
                per_page = job["elapsed_seconds"] / done
                job["estimated_remaining"] = round(per_page * (total - done), 1)

        # Persist a snapshot for external readers; page text is omitted
        # because it can be arbitrarily large.
        progress_file = os.path.join(PROGRESS_DIR, f"{job_id}.json")
        try:
            snapshot = {k: v for k, v in job.items() if k != "pages_text"}
            with open(progress_file, "w") as f:
                json.dump(snapshot, f)
        except Exception as e:
            logger.error(f"Failed to save progress: {e}")

def get_progress(job_id):
    """Return a snapshot of a job's progress, or None if unknown.

    In-memory state (minus the bulky per-page text) is preferred; when
    the job is not in ``active_jobs`` (e.g. written by an earlier process
    run), the JSON snapshot in PROGRESS_DIR is consulted instead.
    """
    with jobs_lock:
        if job_id in active_jobs:
            return {k: v for k, v in active_jobs[job_id].items() if k != "pages_text"}

    progress_file = os.path.join(PROGRESS_DIR, f"{job_id}.json")
    if os.path.exists(progress_file):
        try:
            with open(progress_file, "r") as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            # Best-effort: a truncated or corrupt snapshot is treated as
            # missing rather than crashing the progress endpoint. (Was a
            # bare `except:` that also swallowed SystemExit/KeyboardInterrupt.)
            logger.warning(f"Unreadable progress file for {job_id}: {e}")
    return None

def load_model():
    """Load the DeepSeek-OCR model and tokenizer (lazy, idempotent).

    Invoked by the first request needing OCR; once the module-level
    ``model`` is populated, later calls return immediately.

    Returns:
        bool: True when the model is (already) loaded, False on failure.
    """
    global model, tokenizer

    if model is not None:
        return True

    try:
        logger.info(f"Loading DeepSeek-OCR model: {MODEL_NAME}")
        logger.info(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")
        logger.info(f"flash_attn available: {FLASH_ATTN_AVAILABLE}")

        from transformers import AutoModel, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

        # Pick the fastest attention implementation available. hasattr()
        # never raises, so no try/except is needed here (the previous
        # version wrapped this in a dead `except:` clause whose fallback
        # log was unreachable).
        if FLASH_ATTN_AVAILABLE:
            attn_impl = 'flash_attention_2'
            logger.info("Using flash_attention_2 for maximum speed")
        elif hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
            attn_impl = 'sdpa'
            logger.info("Using SDPA attention")
        else:
            attn_impl = 'eager'
            logger.info("Using eager attention (fallback)")

        model = AutoModel.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map='cuda',
            trust_remote_code=True,
            attn_implementation=attn_impl
        )

        logger.info(f"Model loaded successfully with {attn_impl} attention!")
        return True

    except Exception as e:
        # Broad catch is deliberate: a load failure surfaces to callers
        # as a 500 response instead of crashing the service.
        logger.error(f"Failed to load model: {e}")
        import traceback
        traceback.print_exc()
        return False

def parse_grounding_result(result_text):
    """Strip DeepSeek grounding markup and return plain text.

    The model emits spans like ``<|ref|>text<|/ref|>``; when any are
    present their contents are joined with single spaces. Otherwise the
    stringified input is returned stripped. Falsy input yields "".
    """
    import re

    if not result_text:
        return ""

    text = str(result_text)
    spans = re.findall(r'<\|ref\|>(.*?)<\|/ref\|>', text)

    return ' '.join(spans) if spans else text.strip()

def run_ocr_on_image(image_path, job_id=None, page_num=None, total_pages=None):
    """Run OCR on a single image file and return the extracted text.

    ``model.infer()`` prints its output to stdout rather than (only)
    returning it, so stdout is captured with contextlib.redirect_stdout —
    exception-safe, unlike the previous manual sys.stdout swap, and it
    restores the original stream even if infer() raises.

    Args:
        image_path: Path to the image to OCR.
        job_id, page_num, total_pages: Accepted for call-site
            compatibility; unused here (progress is reported by callers).

    Returns:
        Extracted text with grounding markup stripped, or None on
        failure / empty output.
    """
    import io
    from contextlib import redirect_stdout

    try:
        prompt = "<image>\n<|grounding|>Extract all text from this image."

        with tempfile.TemporaryDirectory() as output_dir:
            captured_output = io.StringIO()
            with redirect_stdout(captured_output):
                result = model.infer(
                    tokenizer,
                    prompt=prompt,
                    image_file=image_path,
                    output_path=output_dir,
                    base_size=1024,
                    image_size=640,
                    crop_mode=True,
                    save_results=False,
                    test_compress=False
                )

            stdout_text = captured_output.getvalue()

            logger.info(f"OCR captured {len(stdout_text)} chars from stdout")

            # Prefer the captured stdout; fall back to the return value.
            if stdout_text and '<|ref|>' in stdout_text:
                parsed = parse_grounding_result(stdout_text)
                logger.info(f"Parsed {len(parsed)} chars from grounding format")
                return parsed
            elif stdout_text:
                return stdout_text.strip()

            if result:
                logger.info(f"Using return value: {type(result)}")
                if '<|ref|>' in str(result):
                    return parse_grounding_result(result)
                return result

            return None
    except Exception as e:
        logger.error(f"OCR error: {e}")
        import traceback
        traceback.print_exc()
        return None

def extract_text_from_pdf(pdf_path, max_pages=50, job_id=None):
    """Convert PDF pages to images and run OCR with progress tracking.

    Renders up to ``max_pages`` pages at 200 DPI via pdf2image, OCRs each
    page with run_ocr_on_image, and joins the per-page results with
    "--- Page N ---" headers. When ``job_id`` is given, each stage is
    pushed through update_progress so /progress/<job_id> can be polled.

    Returns:
        tuple: (text, error) — exactly one of the two is None.
    """
    try:
        if job_id:
            update_progress(job_id, 
                status="converting",
                message="Converting PDF to images..."
            )
        
        logger.info(f"Converting PDF to images: {pdf_path}")
        
        # 200 DPI trades OCR accuracy against render time/memory.
        images = convert_from_path(
            pdf_path,
            dpi=200,
            first_page=1,
            last_page=max_pages
        )
        
        total_pages = len(images)
        logger.info(f"Converted {total_pages} pages")
        
        if job_id:
            update_progress(job_id,
                status="processing",
                total_pages=total_pages,
                message=f"Starting OCR on {total_pages} pages..."
            )
        
        all_text = []
        total_chars = 0
        
        for i, image in enumerate(images):
            page_num = i + 1
            page_start_time = time.time()
            
            if job_id:
                update_progress(job_id,
                    current_page=page_num,
                    message=f"Processing page {page_num}/{total_pages}..."
                )
            
            logger.info(f"OCR on page {page_num}/{total_pages}")
            
            # The model wants a file path, so each PIL page is written to
            # a temp PNG; delete=False because it is removed explicitly
            # in the finally below.
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
                image.save(tmp.name, 'PNG')
                
                try:
                    text = run_ocr_on_image(tmp.name, job_id, page_num, total_pages)
                    page_chars = len(text) if text else 0
                    
                    if text:
                        all_text.append(f"--- Page {page_num} ---\n{text}")
                        total_chars += page_chars
                        logger.info(f"Page {page_num}: extracted {page_chars} chars")
                    else:
                        # Keep a placeholder so page numbering stays intact.
                        all_text.append(f"--- Page {page_num} ---\n[No text detected]")
                        logger.warning(f"Page {page_num} OCR returned no text")
                    
                    page_elapsed = time.time() - page_start_time
                    
                    if job_id:
                        update_progress(job_id,
                            chars_extracted=total_chars,
                            message=f"Page {page_num}/{total_pages} complete ({page_chars} chars, {page_elapsed:.1f}s)"
                        )
                    
                finally:
                    os.unlink(tmp.name)
        
        if job_id:
            update_progress(job_id,
                status="completed",
                progress=100,
                current_page=total_pages,
                chars_extracted=total_chars,
                message=f"OCR complete! Extracted {total_chars:,} characters from {total_pages} pages."
            )
        
        if all_text:
            return '\n\n'.join(all_text), None
        else:
            return None, "OCR produced no text from any page"
            
    except Exception as e:
        logger.error(f"PDF OCR error: {e}")
        import traceback
        traceback.print_exc()
        if job_id:
            update_progress(job_id, status="error", message=str(e))
        return None, str(e)

def capture_webpage_screenshot(url, output_path):
    """Capture a full-page screenshot of ``url`` into ``output_path``.

    Prefers Playwright (headless Chromium); if Playwright is not
    installed, falls back to shelling out to headless google-chrome.

    Returns:
        bool: True when a screenshot file was produced, False otherwise.
    """
    try:
        from playwright.sync_api import sync_playwright

        with sync_playwright() as pw:
            chromium = pw.chromium.launch(headless=True)
            tab = chromium.new_page(viewport={'width': 1920, 'height': 1080})
            tab.goto(url, wait_until='networkidle', timeout=30000)
            tab.screenshot(path=output_path, full_page=True)
            chromium.close()

        return True

    except ImportError:
        # No Playwright available: try headless Chrome on the CLI.
        logger.warning("Playwright not installed, trying fallback")

        try:
            chrome_cmd = [
                'google-chrome',
                '--headless',
                '--disable-gpu',
                '--screenshot=' + output_path,
                '--window-size=1920,1080',
                url,
            ]
            subprocess.run(chrome_cmd, timeout=30, capture_output=True)
            return os.path.exists(output_path)
        except Exception as e:
            logger.error(f"Chrome headless failed: {e}")
            return False

    except Exception as e:
        logger.error(f"Screenshot capture failed: {e}")
        return False

@app.route('/health', methods=['GET'])
def health():
    """Report service status, model state and GPU availability."""
    gpu_ok = torch.cuda.is_available()
    payload = {
        'status': 'ok',
        'service': 'DeepSeek-OCR',
        'version': '2.0',
        'model_loaded': model is not None,
        'gpu_available': gpu_ok,
        'gpu_name': torch.cuda.get_device_name(0) if gpu_ok else None,
        'gpu_memory_gb': round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1) if gpu_ok else 0,
        'flash_attn': FLASH_ATTN_AVAILABLE,
        'active_jobs': len(active_jobs),
    }
    return jsonify(payload)

@app.route('/progress/<job_id>', methods=['GET'])
def progress(job_id):
    """Return the progress record for one OCR job, or 404 if unknown."""
    data = get_progress(job_id)
    if not data:
        return jsonify({'error': 'Job not found'}), 404
    return jsonify(data)

@app.route('/jobs', methods=['GET'])
def list_jobs():
    """List all active jobs - for RAG system"""
    # Snapshot under the lock, stripping the bulky per-page text.
    with jobs_lock:
        snapshot = []
        for job in active_jobs.values():
            snapshot.append({k: v for k, v in job.items() if k != "pages_text"})
    return jsonify({
        'active_jobs': len(snapshot),
        'jobs': snapshot
    })

@app.route('/ocr', methods=['POST'])
def ocr_pdf():
    """Run OCR on an uploaded PDF with progress tracking.

    Form fields:
        file: the PDF to OCR (required).
        job_id: optional caller-supplied id for /progress polling.
        max_pages: optional page cap (default 50).

    Returns JSON with the extracted text and the job_id to poll.
    """
    import uuid

    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'Empty filename'}), 400

    if not load_model():
        return jsonify({'error': 'Failed to load OCR model'}), 500

    # Use the caller's job_id when provided; otherwise generate one. A
    # random suffix is appended because the server runs threaded and a
    # plain epoch-second id collides when two uploads arrive in the same
    # second, making concurrent jobs overwrite each other's progress.
    job_id = request.form.get('job_id') or f"ocr_{int(time.time())}_{uuid.uuid4().hex[:8]}"
    max_pages = request.form.get('max_pages', 50, type=int)

    update_progress(job_id,
        status="uploading",
        message=f"Receiving {file.filename}...",
        filename=file.filename
    )

    # delete=False: the path is removed explicitly in the finally below.
    with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as tmp:
        file.save(tmp.name)

        try:
            start_time = time.time()
            text, error = extract_text_from_pdf(tmp.name, max_pages=max_pages, job_id=job_id)
            elapsed = time.time() - start_time

            if error:
                update_progress(job_id, status="error", message=error)
                return jsonify({'error': error, 'text': None, 'job_id': job_id}), 422

            return jsonify({
                'success': True,
                'text': text,
                'chars': len(text) if text else 0,
                'filename': file.filename,
                'elapsed_seconds': round(elapsed, 2),
                'job_id': job_id
            })

        finally:
            os.unlink(tmp.name)

@app.route('/ocr_image', methods=['POST'])
def ocr_image():
    """Run OCR on a single uploaded image and return the text as JSON."""
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    upload = request.files['file']
    if not upload.filename:
        return jsonify({'error': 'Empty filename'}), 400

    if not load_model():
        return jsonify({'error': 'Failed to load OCR model'}), 500

    # Preserve the upload's extension so the image format is detected.
    ext = Path(upload.filename).suffix or '.jpg'
    with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
        upload.save(tmp.name)

        try:
            started = time.time()
            extracted = run_ocr_on_image(tmp.name)
            duration = time.time() - started

            return jsonify({
                'success': True,
                'text': extracted,
                'chars': len(extracted) if extracted else 0,
                'filename': upload.filename,
                'elapsed_seconds': round(duration, 2)
            })

        finally:
            os.unlink(tmp.name)

@app.route('/scan_url', methods=['POST'])
def scan_url():
    """Screenshot a webpage and OCR the capture; returns text as JSON."""
    payload = request.get_json(silent=True) or {}
    url = payload.get('url') or request.form.get('url')

    if not url:
        return jsonify({'error': 'No URL provided'}), 400

    if not load_model():
        return jsonify({'error': 'Failed to load OCR model'}), 500

    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        shot_path = tmp.name
        try:
            started = time.time()

            if not capture_webpage_screenshot(url, shot_path):
                return jsonify({'error': 'Failed to capture webpage screenshot'}), 422

            extracted = run_ocr_on_image(shot_path)

            return jsonify({
                'success': True,
                'text': extracted,
                'chars': len(extracted) if extracted else 0,
                'url': url,
                'elapsed_seconds': round(time.time() - started, 2)
            })

        finally:
            # The screenshot helper may have failed before writing the file.
            if os.path.exists(shot_path):
                os.unlink(shot_path)

if __name__ == '__main__':
    # Optional eager model load so the first request doesn't pay the
    # multi-second startup cost.
    if '--preload' in sys.argv:
        logger.info("Preloading model at startup...")
        load_model()
    
    # Bind address/port are overridable via environment variables;
    # defaults to localhost-only on port 5003.
    host = os.environ.get('DEEPSEEK_OCR_HOST', '127.0.0.1')
    port = int(os.environ.get('DEEPSEEK_OCR_PORT', 5003))
    logger.info(f"Starting DeepSeek-OCR v2.0 service on port {port}")
    logger.info(f"flash_attn available: {FLASH_ATTN_AVAILABLE}")
    # threaded=True: requests are handled concurrently, so shared progress
    # state must go through jobs_lock.
    app.run(host=host, port=port, debug=False, threaded=True)
