#!/usr/bin/env python3
"""
DeepSeek-OCR Primary Service
GPU-accelerated OCR service using DeepSeek-OCR model.
Listens on port 5003 by default (override with the DEEPSEEK_OCR_PORT environment
variable) and can be shared by multiple client systems.

Endpoints:
- /health - Health check & model status
- /ocr - OCR an uploaded PDF file
- /ocr_image - OCR a single uploaded image
- /scan_url - Capture & OCR a webpage
- /extract_orders - Capture a webpage & extract order information
"""

import logging
import os
import subprocess
import sys
import tempfile
import threading
import time
from pathlib import Path

import requests
import torch
from flask import Flask, request, jsonify
from PIL import Image
from pdf2image import convert_from_path

# Logging: timestamped INFO-level output for the whole service.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Probe once at import time for the optional flash_attn package; load_model()
# only selects flash_attention_2 when this flag is True.
try:
    import flash_attn
except ImportError as import_err:
    FLASH_ATTN_AVAILABLE = False
    logger.warning(f"flash_attn not available: {import_err}")
else:
    FLASH_ATTN_AVAILABLE = True
    logger.info(f"flash_attn v{flash_attn.__version__} loaded successfully!")

app = Flask(__name__)

# Hugging Face model id; `model` and `tokenizer` stay None until load_model()
# populates them on first use.
MODEL_NAME = 'deepseek-ai/DeepSeek-OCR'
model = None
tokenizer = None

# Guards one-time model loading. Flask's server handles requests on multiple
# threads, so two first requests could otherwise both load the model (twice the
# GPU memory).
_MODEL_LOCK = threading.Lock()


def load_model():
    """Load the DeepSeek-OCR model and tokenizer once, on first use.

    Concurrent callers are serialized by a lock, so the model is loaded at most
    once. A failed load leaves the globals untouched, and a later call retries.

    Returns:
        True if the model is loaded (now or by an earlier call), False if
        loading failed (the error is logged with its traceback).
    """
    global model, tokenizer

    # Fast path once the model is ready: no locking needed.
    if model is not None:
        return True

    with _MODEL_LOCK:
        # Another thread may have finished loading while we waited for the lock.
        if model is not None:
            return True

        try:
            cuda_ok = torch.cuda.is_available()
            logger.info(f"Loading DeepSeek-OCR model: {MODEL_NAME}")
            logger.info(f"GPU: {torch.cuda.get_device_name(0) if cuda_ok else 'N/A'}")
            logger.info(f"CUDA available: {cuda_ok}")
            logger.info(f"flash_attn available: {FLASH_ATTN_AVAILABLE}")

            from transformers import AutoModel, AutoTokenizer

            new_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

            # Pick the fastest attention implementation available:
            # flash_attention_2 > SDPA > eager.
            if FLASH_ATTN_AVAILABLE:
                attn_impl = 'flash_attention_2'
                logger.info("Using flash_attention_2 for maximum speed")
            elif hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
                attn_impl = 'sdpa'
                logger.info("Using SDPA attention (flash_attn not available)")
            else:
                attn_impl = 'eager'
                logger.info("Using eager attention (fallback)")

            new_model = AutoModel.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.bfloat16,
                device_map='cuda',
                trust_remote_code=True,
                attn_implementation=attn_impl
            )
        except Exception as e:
            logger.exception(f"Failed to load model: {e}")
            return False

        # Publish the tokenizer before the model: other threads treat
        # `model is not None` as "ready to use".
        tokenizer = new_tokenizer
        model = new_model

        logger.info(f"Model loaded successfully with {attn_impl} attention!")
        return True

def parse_grounding_result(result_text):
    """Reduce grounding-format model output to plain text.

    Grounded output looks like ``<|ref|>text<|/ref|><|det|>[[coords]]<|/det|>``.
    Every ``<|ref|>...<|/ref|>`` payload is collected (each on a single line) and
    the payloads are joined with single spaces.

    Args:
        result_text: Model output; any non-str value is converted with str().

    Returns:
        "" for a falsy input, the joined ref payloads when any are present,
        otherwise the stringified input with surrounding whitespace removed.
    """
    import re

    if not result_text:
        return ""

    raw = str(result_text)
    refs = [m.group(1) for m in re.finditer(r'<\|ref\|>(.*?)<\|/ref\|>', raw)]

    # No grounding tags at all: treat the output as already-plain text.
    return ' '.join(refs) if refs else raw.strip()

# Serializes model inference across request threads. model.infer() reports its
# result by printing, and that output is captured by temporarily replacing the
# process-global sys.stdout. Two overlapping swaps would mix captures and could
# leave sys.stdout pointing at a stale buffer. If an earlier part of this module
# already created a lock under this name, reuse it so all inference shares one.
_INFERENCE_LOCK = globals().get('_INFERENCE_LOCK') or threading.Lock()


def run_ocr_on_image(image_path):
    """Run DeepSeek-OCR on one image file and return the recognized text.

    Uses the global ``model`` and ``tokenizer``, so callers must have checked
    load_model() first.

    Args:
        image_path: Path to the image file to OCR.

    Returns:
        If the printed output contains grounding tags, the ``<|ref|>`` payloads
        joined by spaces. Otherwise the stripped printed output. If nothing was
        printed, the (possibly grounding-parsed) return value of model.infer().
        None if OCR produced nothing or raised; errors are logged.
    """
    import contextlib
    import io

    prompt = "<image>\n<|grounding|>Extract all text from this image."

    try:
        captured_output = io.StringIO()

        # Temp directory for output (required by model.infer); removed afterwards.
        with tempfile.TemporaryDirectory() as output_dir:
            # Hold the lock for the whole stdout swap so concurrent requests
            # cannot interleave their redirections.
            with _INFERENCE_LOCK, contextlib.redirect_stdout(captured_output):
                result = model.infer(
                    tokenizer,
                    prompt=prompt,
                    image_file=image_path,
                    output_path=output_dir,
                    base_size=1024,
                    image_size=640,
                    crop_mode=True,
                    save_results=False,
                    test_compress=False
                )

        stdout_text = captured_output.getvalue()
        logger.info(f"OCR captured {len(stdout_text)} chars from stdout")

        # Prefer what model.infer printed.
        if stdout_text and '<|ref|>' in stdout_text:
            parsed = parse_grounding_result(stdout_text)
            logger.info(f"Parsed {len(parsed)} chars from grounding format")
            return parsed
        if stdout_text:
            return stdout_text.strip()

        # Fallback to the return value if nothing was printed.
        if result:
            logger.info(f"Using return value: {type(result)}")
            if '<|ref|>' in str(result):
                return parse_grounding_result(result)
            return result

        return None
    except Exception as e:
        logger.exception(f"OCR error: {e}")
        return None

def extract_text_from_pdf(pdf_path, max_pages=50):
    """Rasterize a PDF's pages and OCR each one.

    Pages are rendered at 200 DPI, starting at page 1. ``max_pages`` is passed
    to pdf2image as ``last_page``.

    Args:
        pdf_path: Path to the PDF on disk.
        max_pages: Highest page number to convert.

    Returns:
        ``(text, None)`` with the non-empty page texts joined by blank lines, or
        ``(None, error_message)`` if conversion failed or no page produced text.
    """
    try:
        logger.info(f"Converting PDF to images: {pdf_path}")

        images = convert_from_path(
            pdf_path,
            dpi=200,
            first_page=1,
            last_page=max_pages
        )
        page_count = len(images)
        logger.info(f"Converted {page_count} pages")

        all_text = []
        # One scratch directory for all page images. It is removed even if
        # saving a page or running OCR fails, so no temp PNGs are leaked.
        with tempfile.TemporaryDirectory() as page_dir:
            for page_no, image in enumerate(images, start=1):
                logger.info(f"OCR on page {page_no}/{page_count}")

                page_path = os.path.join(page_dir, f"page_{page_no:04d}.png")
                image.save(page_path, 'PNG')
                try:
                    text = run_ocr_on_image(page_path)
                finally:
                    # Free disk space page by page instead of at the very end.
                    os.unlink(page_path)

                if text:
                    all_text.append(text)
                else:
                    logger.warning(f"Page {page_no} OCR returned no text")

        if all_text:
            return '\n\n'.join(all_text), None
        return None, "OCR produced no text from any page"

    except Exception as e:
        logger.exception(f"PDF OCR error: {e}")
        return None, str(e)

def capture_webpage_screenshot(url, output_path):
    """Write a full-page PNG screenshot of ``url`` to ``output_path``.

    Uses Playwright's headless Chromium when Playwright is installed. Otherwise
    it falls back to the ``google-chrome`` command line in headless mode.

    Only absolute http(s) URLs are accepted. The URL comes straight from API
    clients, so other schemes (e.g. ``file://``) would let a caller screenshot
    local files. A value starting with ``-`` would also be parsed by Chrome as
    a command-line flag.
    NOTE(review): internal/private hosts are still reachable (SSRF); add a host
    allow-list if this service is exposed beyond trusted callers.

    Args:
        url: Page to capture.
        output_path: Destination file. It may already exist as an empty file
            (the endpoints pass a freshly created temp file) and is overwritten.

    Returns:
        True if a screenshot was captured, False otherwise (reasons are logged).
    """
    from urllib.parse import urlparse

    if not isinstance(url, str):
        logger.warning(f"Rejected screenshot URL of type {type(url).__name__}")
        return False
    parsed = urlparse(url)
    if parsed.scheme not in ('http', 'https') or not parsed.netloc:
        logger.warning(f"Rejected screenshot URL (only http/https allowed): {url!r}")
        return False

    try:
        from playwright.sync_api import sync_playwright
    except ImportError:
        playwright_available = False
    else:
        playwright_available = True

    if playwright_available:
        try:
            with sync_playwright() as p:
                browser = p.chromium.launch(headless=True)
                try:
                    page = browser.new_page(viewport={'width': 1920, 'height': 1080})
                    page.goto(url, wait_until='networkidle', timeout=30000)
                    page.screenshot(path=output_path, full_page=True)
                finally:
                    browser.close()
            return True
        except Exception as e:
            logger.error(f"Screenshot capture failed: {e}")
            return False

    logger.warning("Playwright not installed, trying fallback methods")
    cmd = [
        'google-chrome',
        '--headless',
        '--disable-gpu',
        '--screenshot=' + output_path,
        '--window-size=1920,1080',
        url
    ]
    try:
        completed = subprocess.run(cmd, timeout=30, capture_output=True)
    except (OSError, subprocess.SubprocessError) as e:
        logger.error(f"Chrome headless failed: {e}")
        return False

    # Callers pre-create output_path, so mere existence proves nothing: require
    # Chrome to have actually written image data into it.
    try:
        wrote_screenshot = os.path.getsize(output_path) > 0
    except OSError:
        wrote_screenshot = False
    if not wrote_screenshot:
        logger.error(
            f"Chrome headless failed: no screenshot written "
            f"(exit code {completed.returncode})"
        )
    return wrote_screenshot

@app.route('/health', methods=['GET'])
def health():
    """Report liveness, whether the model is loaded, and GPU details.

    GPU name and total memory (GB, one decimal) are only reported when CUDA is
    available; otherwise they are None and 0.
    """
    has_gpu = torch.cuda.is_available()

    gpu_name = None
    gpu_memory_gb = 0
    if has_gpu:
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory_gb = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1)

    status = {
        'status': 'ok',
        'service': 'DeepSeek-OCR',
        'model_loaded': model is not None,
        'gpu_available': has_gpu,
        'gpu_name': gpu_name,
        'gpu_memory_gb': gpu_memory_gb,
        'flash_attn': FLASH_ATTN_AVAILABLE,
    }
    return jsonify(status)

@app.route('/ocr', methods=['POST'])
def ocr_pdf():
    """Run OCR on an uploaded PDF.

    Form fields:
        file: the PDF upload (required).
        max_pages: highest page number to OCR (optional, default 50). A value
            that cannot be parsed as an int falls back to the default.

    Returns:
        200 JSON with the extracted text. 400 for a missing file or a
        non-positive ``max_pages``. 422 if no text could be extracted. 500 if
        the model could not be loaded.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'Empty filename'}), 400

    max_pages = request.form.get('max_pages', 50, type=int)
    if max_pages < 1:
        # Without this check, pdf2image would convert zero pages and the client
        # would get a misleading "no text" error.
        return jsonify({'error': 'max_pages must be a positive integer'}), 400

    if not load_model():
        return jsonify({'error': 'Failed to load OCR model'}), 500

    # Only reserve a unique path here. The handle is closed before writing, so
    # the finally below owns cleanup even if saving the upload fails.
    with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as tmp:
        tmp_path = tmp.name

    try:
        file.save(tmp_path)

        start_time = time.time()
        text, error = extract_text_from_pdf(tmp_path, max_pages=max_pages)
        elapsed = time.time() - start_time

        if error:
            return jsonify({'error': error, 'text': None}), 422

        return jsonify({
            'success': True,
            'text': text,
            'chars': len(text) if text else 0,
            'filename': file.filename,
            'elapsed_seconds': round(elapsed, 2)
        })
    finally:
        os.unlink(tmp_path)

@app.route('/ocr_image', methods=['POST'])
def ocr_image():
    """Run OCR on a single uploaded image.

    Form fields:
        file: the image upload (required). Its extension is kept on the temp
            file, defaulting to ``.jpg`` when the filename has none.

    Returns:
        200 JSON with the recognized text (``text`` is null if OCR produced
        nothing). 400 for a missing file. 500 if the model could not be loaded.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'Empty filename'}), 400

    if not load_model():
        return jsonify({'error': 'Failed to load OCR model'}), 500

    suffix = Path(file.filename).suffix or '.jpg'
    # Only reserve a unique path here. The handle is closed before writing, so
    # the finally below owns cleanup even if saving the upload fails.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp_path = tmp.name

    try:
        file.save(tmp_path)

        start_time = time.time()
        text = run_ocr_on_image(tmp_path)
        elapsed = time.time() - start_time

        return jsonify({
            'success': True,
            'text': text,
            'chars': len(text) if text else 0,
            'filename': file.filename,
            'elapsed_seconds': round(elapsed, 2)
        })
    finally:
        os.unlink(tmp_path)

@app.route('/scan_url', methods=['POST'])
def scan_url():
    """Screenshot the page at a URL and OCR the screenshot.

    The URL is read from the JSON body's ``url`` key, falling back to the
    ``url`` form field. Responses: 400 without a URL, 500 if the model could
    not be loaded, 422 if the screenshot failed, otherwise 200 with the text.
    """
    payload = request.get_json(silent=True) or {}
    target_url = payload.get('url') or request.form.get('url')

    if not target_url:
        return jsonify({'error': 'No URL provided'}), 400
    if not load_model():
        return jsonify({'error': 'Failed to load OCR model'}), 500

    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as shot:
        shot_path = shot.name
        try:
            t0 = time.time()

            captured = capture_webpage_screenshot(target_url, shot_path)
            if not captured:
                return jsonify({'error': 'Failed to capture webpage screenshot'}), 422

            ocr_text = run_ocr_on_image(shot_path)
            took = time.time() - t0

            response = {
                'success': True,
                'text': ocr_text,
                'chars': len(ocr_text) if ocr_text else 0,
                'url': target_url,
                'elapsed_seconds': round(took, 2),
            }
            return jsonify(response)
        finally:
            # The screenshot step may have failed before anything was written.
            if os.path.exists(shot_path):
                os.unlink(shot_path)

# Serializes model inference across request threads. The model's printed output
# is captured by temporarily replacing the process-global sys.stdout, and those
# swaps must not overlap between requests. If an earlier part of this module
# already created a lock under this name, reuse it so all inference shares one.
_INFERENCE_LOCK = globals().get('_INFERENCE_LOCK') or threading.Lock()


@app.route('/extract_orders', methods=['POST'])
def extract_orders():
    """Screenshot a URL and ask the model to extract order information.

    The URL is read from the JSON body's ``url`` key, falling back to the
    ``url`` form field.

    Returns:
        200 JSON whose ``orders_raw`` is the model's raw output text. The prompt
        asks for JSON, but the output is not parsed or validated here. 400
        without a URL. 422 if the screenshot failed. 500 if the model could not
        be loaded or inference raised.
    """
    import contextlib
    import io

    data = request.get_json(silent=True) or {}
    url = data.get('url') or request.form.get('url')

    if not url:
        return jsonify({'error': 'No URL provided'}), 400

    if not load_model():
        return jsonify({'error': 'Failed to load OCR model'}), 500

    prompt = "<image>\n<|grounding|>Extract all order information from this page. Format as JSON with fields: order_id, customer_name, items, quantities, prices, total, status, date."

    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        try:
            start_time = time.time()

            if not capture_webpage_screenshot(url, tmp.name):
                return jsonify({'error': 'Failed to capture webpage screenshot'}), 422

            # model.infer() needs an output directory and prints its result
            # instead of returning it. Give it a scratch dir and capture stdout
            # under the inference lock.
            captured_output = io.StringIO()
            try:
                with tempfile.TemporaryDirectory() as output_dir:
                    with _INFERENCE_LOCK, contextlib.redirect_stdout(captured_output):
                        result = model.infer(
                            tokenizer,
                            prompt=prompt,
                            image_file=tmp.name,
                            output_path=output_dir,
                            base_size=1024,
                            image_size=640,
                            crop_mode=True,
                            save_results=False,
                            test_compress=False
                        )
            except Exception as e:
                logger.exception(f"Order extraction error: {e}")
                return jsonify({'error': f'Order extraction failed: {e}'}), 500

            elapsed = time.time() - start_time

            # Prefer the printed output; fall back to the return value if
            # nothing was printed.
            orders_raw = captured_output.getvalue().strip() or result

            return jsonify({
                'success': True,
                'orders_raw': orders_raw,
                'url': url,
                'elapsed_seconds': round(elapsed, 2)
            })

        finally:
            if os.path.exists(tmp.name):
                os.unlink(tmp.name)

if __name__ == '__main__':
    # `--preload` loads the model before serving. Without it, each endpoint
    # calls load_model() and the first request pays the load cost.
    should_preload = '--preload' in sys.argv
    if should_preload:
        logger.info("Preloading model at startup...")
        load_model()

    listen_port = int(os.getenv('DEEPSEEK_OCR_PORT', 5003))
    logger.info(f"Starting DeepSeek-OCR service on port {listen_port}")
    logger.info(f"flash_attn available: {FLASH_ATTN_AVAILABLE}")

    # Bind on all interfaces with Flask's debug mode off.
    app.run(debug=False, host='0.0.0.0', port=listen_port)
