#!/usr/bin/env python3
"""
Threaded backtest runner for TheCryptoClaw bot.

Features:
- Worker thread for backtest execution (non-blocking)
- Progress file updated every 500 candles (bot can poll)
- CPU throttling via os.nice() and periodic sleeps
- Hard timeout via signal alarm
- Chunked CSV loading for memory efficiency
"""

import sys
import os
import json
import time
import signal
import threading
import argparse
from datetime import datetime

# Run relative to this script's directory so relative paths (data/, results/,
# local module imports) resolve the same no matter where the caller's CWD is.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ".")

# Lower CPU priority immediately so the backtest doesn't starve the live bot.
try:
    os.nice(10)
except OSError:
    # nice() can fail (e.g. insufficient privileges) — just run at normal priority.
    pass

# JSON file the bot polls for live backtest progress (written atomically).
PROGRESS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "backtest_progress.json")
# Directory where timestamped result JSONs and trade-log CSVs are written.
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")

def write_progress(data):
    """Best-effort atomic write of the progress JSON (tmp file + os.replace)."""
    tmp_path = PROGRESS_FILE + ".tmp"
    try:
        with open(tmp_path, "w") as fh:
            json.dump(data, fh, indent=2, default=str)
        os.replace(tmp_path, PROGRESS_FILE)
    except Exception:
        # Progress reporting must never crash the backtest itself.
        pass

def timeout_handler(signum, frame):
    """SIGALRM handler: record the timeout status, then kill the process.

    Uses os._exit so no cleanup code can block the forced shutdown.
    """
    final = {
        "status": "timeout",
        "message": "Hard timeout reached",
        "updated_at": datetime.now().isoformat(),
    }
    write_progress(final)
    print("\nHARD TIMEOUT - backtest killed")
    os._exit(1)


def _save_final_result(data: dict):
    """Persist final data to a timestamped results JSON and the progress file.

    Invoked on normal completion AND on graceful shutdown (SIGTERM/SIGINT),
    so results survive even when the process is killed.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    result_path = os.path.join(RESULTS_DIR, f"result_{stamp}.json")
    try:
        with open(result_path, "w") as fh:
            json.dump(data, fh, indent=2)
        print(f"Results saved to {result_path}")
    except Exception as exc:
        print(f"WARNING: Failed to save results: {exc}")
    # Mirror the final data into the polled progress file as well.
    write_progress(data)


# Last progress snapshot kept in memory so signal handlers can persist the
# most recent state if the process is killed mid-run (see _graceful_shutdown).
_last_progress = {}
# Guards _last_progress: it is written by the backtest (possibly a worker
# thread) and read by signal handlers running in the main thread.
_last_progress_lock = threading.Lock()

def _update_last_progress(data: dict):
    """Store a defensive copy of the latest progress under the shared lock."""
    global _last_progress
    with _last_progress_lock:
        _last_progress = dict(data)

def _graceful_shutdown(signum, frame):
    """SIGTERM/SIGINT handler: persist the latest known progress, then exit."""
    with _last_progress_lock:
        snapshot = dict(_last_progress)
    # Only a run that was actually in flight gets marked as interrupted.
    if snapshot.get("status") == "running":
        snapshot["status"] = "interrupted"
        snapshot["message"] = f"Process killed by signal {signum}"
        snapshot["updated_at"] = datetime.now().isoformat()
        _save_final_result(snapshot)
        print(f"\nInterrupted by signal {signum} - progress saved")
    os._exit(1)

def resolve_csv_paths(csv_arg: str) -> list:
    """
    Map a CSV argument onto an ordered list of CSV file paths.

    Accepted forms:
      - Direct file path: "BTCUSD_PAST6MONTHS.csv" or "data/monthly/BTCUSD_202601.csv"
      - Month shorthand: "1month" → latest 1 month, "3month" → latest 3 months
      - Specific month: "202601" → data/monthly/BTCUSD_202601.csv
      - Month range: "202510-202512" → Oct, Nov, Dec 2025
      - Legacy: "6month" → full 6-month file, "5year" → full 5-year file

    Returns the list of CSV files to load, in chronological order.
    """
    import re

    base_dir = os.path.dirname(os.path.abspath(__file__))
    monthly_dir = os.path.join(base_dir, "data", "monthly")
    index_file = os.path.join(monthly_dir, "index.json")

    def load_index():
        # index.json maps "YYYYMM" keys to {"path": ...} entries.
        with open(index_file) as fh:
            return json.load(fh)

    # An existing path wins outright.
    if os.path.isfile(csv_arg):
        return [csv_arg]

    have_index = os.path.isfile(index_file)

    # Legacy aliases — prefer the split monthly files over the huge monoliths.
    if csv_arg in ("6month", "BTCUSD_PAST6MONTHS.csv"):
        if have_index:
            index = load_index()
            months = sorted(index.keys())
            # Last 6+ months (7 entries covers 6 full months plus a partial one)
            chosen = months[-7:] if len(months) >= 7 else months
            return [index[m]["path"] for m in chosen]
        return ["BTCUSD_PAST6MONTHS.csv"]

    if csv_arg in ("5year", "BTCUSD.csv"):
        if have_index:
            # All monthly files instead of the giant monolithic CSV.
            index = load_index()
            return [index[m]["path"] for m in sorted(index.keys())]
        return ["BTCUSD.csv"]

    # Everything below requires the monthly index.
    if not have_index:
        print(f"WARNING: Monthly index not found at {index_file}, falling back to full file")
        return ["BTCUSD_PAST6MONTHS.csv"]

    index = load_index()
    months = sorted(index.keys())  # e.g. ["202508", "202509", ...]

    # Single month, e.g. "202601"
    if csv_arg.isdigit() and len(csv_arg) == 6:
        if csv_arg in index:
            return [index[csv_arg]["path"]]
        raise FileNotFoundError(f"No monthly data for {csv_arg}")

    # Inclusive month range, e.g. "202510-202512" (exactly 13 chars)
    if len(csv_arg) == 13 and "-" in csv_arg:
        lo, hi = csv_arg.split("-")
        selected = [index[m]["path"] for m in months if lo <= m <= hi]
        if selected:
            return selected
        raise FileNotFoundError(f"No data in range {csv_arg}")

    # "Nmonth" shorthand, e.g. "3month" → latest 3 monthly files
    shorthand = re.match(r"^(\d+)month$", csv_arg)
    if shorthand:
        count = int(shorthand.group(1))
        if count >= len(months):
            # More months requested than we have split out → use the full file
            return ["BTCUSD_PAST6MONTHS.csv"]
        return [index[key]["path"] for key in months[-count:]]

    # Last resort: treat the argument as a filename inside the monthly dir.
    candidate = os.path.join(monthly_dir, csv_arg)
    if os.path.isfile(candidate):
        return [candidate]

    raise FileNotFoundError(f"Cannot resolve CSV: {csv_arg}")


class ThrottledBacktest:
    """Run a backtest with CPU throttling and file-based progress reporting.

    Loads tick CSVs (resolved via resolve_csv_paths), resamples them into
    OHLCV candles, feeds them through the project's BacktestEngine, and
    periodically writes a JSON progress snapshot for an external poller.
    """

    def __init__(self, csv_path, max_rows, timeframe="5min", throttle_ms=2):
        # File path or period shorthand ("1month", "202601", ...).
        self.csv_path = csv_path
        # Cap on tick rows loaded across all files (None/0 = unlimited).
        self.max_rows = max_rows
        # pandas resample rule used for candle aggregation (e.g. "5min").
        self.timeframe = timeframe
        self.throttle_ms = throttle_ms  # sleep this many ms every 100 candles
        self.result = None      # reserved: final result object
        self.error = None       # reserved: captured error
        self.start_time = None  # wall-clock start, set in run()

    def load_data(self):
        """Load tick data and resample to candles. Supports multi-file loading.

        The tick CSVs must provide the columns used below: "Timestamp",
        "Bid price", "Ask price", and "Bid volume". Returns a DataLoader
        whose .data is the list of resampled OHLCV candles.
        """
        import pandas as pd
        from data_loader import DataLoader, OHLCV

        # Resolve to file list
        csv_files = resolve_csv_paths(self.csv_path)
        file_desc = ", ".join(os.path.basename(f) for f in csv_files)

        write_progress({
            "status": "loading",
            "message": f"Loading {file_desc} (max {self.max_rows} rows)...",
            "files": [os.path.basename(f) for f in csv_files],
            "updated_at": datetime.now().isoformat()
        })

        # Load all files, concatenate. rows_left tracks how many more ticks
        # we may load before hitting the max_rows cap (None = no cap).
        frames = []
        total_loaded = 0
        rows_left = self.max_rows if (self.max_rows and self.max_rows > 0) else None

        for csv_file in csv_files:
            read_params = {"filepath_or_buffer": csv_file}
            if rows_left is not None and rows_left > 0:
                # Let pandas stop reading at the cap instead of loading it all.
                read_params["nrows"] = rows_left

            chunk = pd.read_csv(**read_params)
            frames.append(chunk)
            total_loaded += len(chunk)

            if rows_left is not None:
                rows_left -= len(chunk)
                if rows_left <= 0:
                    break

            write_progress({
                "status": "loading",
                "message": f"Loaded {os.path.basename(csv_file)} ({len(chunk):,} ticks). Total: {total_loaded:,}",
                "updated_at": datetime.now().isoformat()
            })

        df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
        del frames
        tick_count = len(df)

        write_progress({
            "status": "resampling",
            "message": f"Resampling {tick_count:,} ticks to {self.timeframe} candles...",
            "ticks_loaded": tick_count,
            "updated_at": datetime.now().isoformat()
        })

        # Parse timestamps. The regex turns a trailing ":mmm" millisecond
        # separator into ".mmm" so strptime's %f can parse it.
        df["datetime"] = pd.to_datetime(
            df["Timestamp"].str.replace(r":(\d{3})$", r".\1", regex=True),
            format="%Y%m%d %H:%M:%S.%f"
        )
        # Mid price = average of bid and ask; candles are built from mid.
        df["mid"] = (df["Bid price"] + df["Ask price"]) / 2.0
        df = df.set_index("datetime")

        # Resample mid-price ticks into OHLC candles; volume is summed bid
        # volume per bucket. NOTE(review): ask volume is ignored — confirm
        # that is intentional.
        ohlcv = df["mid"].resample(self.timeframe).ohlc()
        ohlcv.columns = ["open", "high", "low", "close"]
        ohlcv["volume"] = df["Bid volume"].resample(self.timeframe).sum()
        ohlcv = ohlcv.dropna()

        # Free tick data memory
        del df
        import gc
        gc.collect()

        # Convert to DataLoader
        import settings
        loader = DataLoader(settings.SYMBOL)
        candles = []
        for ts, row in ohlcv.iterrows():
            candles.append(OHLCV(
                timestamp=ts.to_pydatetime(),
                open=float(row["open"]),
                high=float(row["high"]),
                low=float(row["low"]),
                close=float(row["close"]),
                volume=float(row["volume"]),
            ))
        loader.data = candles

        write_progress({
            "status": "ready",
            "message": f"Ready: {len(candles):,} candles from {tick_count:,} ticks",
            "ticks_loaded": tick_count,
            "candles": len(candles),
            "date_start": str(ohlcv.index[0]),
            "date_end": str(ohlcv.index[-1]),
            "price_low": float(ohlcv["low"].min()),
            "price_high": float(ohlcv["high"].max()),
            "updated_at": datetime.now().isoformat()
        })

        return loader

    def _export_trade_log(self, engine):
        """Export full trade-by-trade log to CSV for post-run analysis.

        Columns: trade_id, side, lots, entry_price, entry_time, exit_price,
                 exit_time, profit, commission, magic (thread_id), comment
                 (recovery level info), duration_sec, return_pct

        Also writes a per-thread (magic number) summary CSV. No-op when
        there are no trades.
        """
        import csv

        trades = engine.order_engine.trade_history
        if not trades:
            return

        os.makedirs(RESULTS_DIR, exist_ok=True)
        run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_path = os.path.join(RESULTS_DIR, f"trades_{run_id}.csv")

        with open(log_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                "trade_id", "side", "lots", "entry_price", "entry_time",
                "exit_price", "exit_time", "profit", "magic_thread",
                "comment", "duration_min", "return_pct"
            ])

            for t in trades:
                # Calculate duration (minutes) when both timestamps exist.
                duration_min = None
                if t.open_time and t.close_time:
                    duration_min = round((t.close_time - t.open_time).total_seconds() / 60, 1)

                # Calculate return % on trade relative to entry cost basis.
                return_pct = None
                if t.price and t.price > 0 and t.lots > 0:
                    cost_basis = t.lots * t.price
                    return_pct = round((t.profit / cost_basis) * 100, 4) if cost_basis > 0 else 0

                writer.writerow([
                    t.order_id,
                    t.side.value if hasattr(t.side, 'value') else str(t.side),
                    round(t.lots, 4),
                    round(t.price, 2) if t.price else None,
                    str(t.open_time) if t.open_time else None,
                    round(t.close_price, 2) if t.close_price else None,
                    str(t.close_time) if t.close_time else None,
                    round(t.profit, 2),
                    t.magic,
                    t.comment,
                    duration_min,
                    return_pct,
                ])

        print(f"Trade log exported: {log_path} ({len(trades)} trades)")

        # Also save thread-level summary, keyed by the trade's magic number.
        thread_stats = {}
        for t in trades:
            mid = t.magic
            if mid not in thread_stats:
                thread_stats[mid] = {"trades": 0, "profit": 0, "max_level": 0}
            thread_stats[mid]["trades"] += 1
            thread_stats[mid]["profit"] += t.profit
            # Extract recovery level from comment (REC_L3_BUY → 3)
            if t.comment and "REC_L" in t.comment:
                try:
                    level = int(t.comment.split("REC_L")[1].split("_")[0])
                    thread_stats[mid]["max_level"] = max(thread_stats[mid]["max_level"], level)
                except (ValueError, IndexError):
                    # Malformed comment — skip the level extraction.
                    pass

        thread_path = os.path.join(RESULTS_DIR, f"threads_{run_id}.csv")
        with open(thread_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["thread_id", "trades", "total_profit", "max_recovery_level"])
            for mid, stats in sorted(thread_stats.items()):
                writer.writerow([mid, stats["trades"], round(stats["profit"], 2), stats["max_level"]])

        print(f"Thread summary: {thread_path} ({len(thread_stats)} threads)")

    def run(self):
        """Run the backtest with throttling and progress updates.

        Returns the engine's compiled result object; also writes the final
        summary via _save_final_result and exports the trade log CSVs.
        """
        import settings
        from backtest import BacktestEngine
        from utils import setup_logging

        setup_logging(console=False, file=True)
        self.start_time = time.time()

        # Load data
        data = self.load_data()
        total = len(data)

        # Create engine
        engine = BacktestEngine(
            symbol=settings.SYMBOL,
            initial_balance=settings.INITIAL_BALANCE,
            commission_percent=settings.COMMISSION_PERCENT,
            slippage_pips=settings.SLIPPAGE_PIPS,
        )

        # Process candles with throttling and progress
        processed = 0
        for candle in data.iterate_candles():
            engine.process_candle(candle)
            processed += 1

            # Throttle: small sleep every 100 candles to yield CPU
            if self.throttle_ms > 0 and processed % 100 == 0:
                time.sleep(self.throttle_ms / 1000.0)

            # Progress update every 500 candles
            if processed % 500 == 0 or processed == total:
                elapsed = time.time() - self.start_time
                equity = engine.order_engine.get_equity()
                balance = engine.order_engine.balance
                pct_done = processed / total * 100
                rate = processed / elapsed if elapsed > 0 else 0
                eta = (total - processed) / rate if rate > 0 else 0

                progress_data = {
                    "status": "running",
                    "processed": processed,
                    "total": total,
                    "pct": round(pct_done, 1),
                    "balance": round(balance, 2),
                    "equity": round(equity, 2),
                    "return_pct": round((balance - settings.INITIAL_BALANCE) / settings.INITIAL_BALANCE * 100, 2),
                    "max_equity": round(engine.max_equity, 2),
                    "min_equity": round(engine.min_equity, 2),
                    "max_dd_pct": round((engine.max_drawdown / engine.peak_equity * 100) if engine.peak_equity > 0 else 0, 2),
                    "trades": len(engine.order_engine.trade_history),
                    "open_positions": len(engine.order_engine.positions),
                    "current_price": round(candle.close, 2),
                    "current_date": str(candle.timestamp),
                    "elapsed_sec": round(elapsed, 1),
                    "candles_per_sec": round(rate, 1),
                    "eta_sec": round(eta, 1),
                    "updated_at": datetime.now().isoformat()
                }
                write_progress(progress_data)
                # Keep a copy for the signal handlers (crash recovery).
                _update_last_progress(progress_data)

        # Close remaining positions at the final bid/ask so results include
        # unrealized P&L: longs close at bid, shorts at ask.
        from order_engine import OrderSide
        bid, ask = engine.order_engine.get_current_price()
        for position in list(engine.order_engine.positions.values()):
            if position.side == OrderSide.BUY:
                engine.order_engine.close_position(position.position_id, bid)
            else:
                engine.order_engine.close_position(position.position_id, ask)

        # Export trade log CSV
        self._export_trade_log(engine)

        # Compile results
        result = engine._compile_results(data)
        elapsed = time.time() - self.start_time

        # Write final results. profit_factor uses 999 as a sentinel when
        # there were no losses (division by zero guard).
        final = {
            "status": "complete",
            "elapsed_sec": round(elapsed, 1),
            "net_profit": round(result.total_return, 2),
            "return_pct": round(result.total_return_percent, 2),
            "total_trades": result.total_trades,
            "win_rate": round(result.win_rate, 2),
            "profit_factor": round(result.gross_profit / result.gross_loss, 2) if result.gross_loss > 0 else 999,
            "max_dd_pct": round(result.max_drawdown_percent, 2),
            "min_equity": round(result.min_equity, 2),
            "max_equity": round(result.max_equity, 2),
            "gross_profit": round(result.gross_profit, 2),
            "gross_loss": round(result.gross_loss, 2),
            "commission": round(result.total_commission, 2),
            "total_threads": result.total_threads,
            "max_recovery_level": result.max_recovery_level,
            "date_start": str(result.start_date) if result.start_date else None,
            "date_end": str(result.end_date) if result.end_date else None,
            "candles": total,
            "updated_at": datetime.now().isoformat()
        }
        _save_final_result(final)

        return result


def main():
    """Parse CLI args, install timeout/shutdown handlers, run the backtest.

    Fix: the SIGALRM hard-timeout alarm is now cancelled (signal.alarm(0))
    once the backtest finishes, so a pending alarm can no longer fire while
    the summary is being printed/saved and kill the process with a spurious
    "timeout" status.
    """
    parser = argparse.ArgumentParser(description="Throttled backtest runner")
    parser.add_argument("--csv", default="1month", help="CSV file or period (1month, 3month, 6month, 202601, etc)")
    parser.add_argument("--max-rows", type=int, default=2000000, help="Max tick rows (0=all)")
    parser.add_argument("--tf", default="5min", help="Candle timeframe")
    parser.add_argument("--throttle", type=int, default=2, help="Sleep ms per 100 candles")
    parser.add_argument("--timeout", type=int, default=480, help="Hard timeout seconds")
    parser.add_argument("--background", action="store_true", help="Run in background thread")
    args = parser.parse_args()

    # Hard timeout: SIGALRM fires after --timeout seconds and kills the run.
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(args.timeout)

    # Graceful shutdown: persist last known progress on SIGTERM/SIGINT.
    signal.signal(signal.SIGTERM, _graceful_shutdown)
    signal.signal(signal.SIGINT, _graceful_shutdown)

    # 0 (or negative) --max-rows means "no row limit".
    max_rows = args.max_rows if args.max_rows > 0 else None

    runner = ThrottledBacktest(
        csv_path=args.csv,
        max_rows=max_rows,
        timeframe=args.tf,
        throttle_ms=args.throttle
    )

    if args.background:
        # Worker thread does the heavy lifting; the main thread stays free to
        # receive signals (SIGALRM/SIGTERM are delivered to the main thread).
        thread = threading.Thread(target=_run_worker, args=(runner,), daemon=True)
        thread.start()
        print(f"Backtest started in background thread. Monitor: cat {PROGRESS_FILE}")
        # Wait for the worker to finish, then cancel the pending alarm.
        thread.join()
        signal.alarm(0)
    else:
        result = runner.run()
        # Cancel the pending alarm before the (potentially slow) summary.
        signal.alarm(0)
        result.print_summary()


def _run_worker(runner):
    """Worker function for threaded execution."""
    try:
        result = runner.run()
        print("\n" + "=" * 60)
        result.print_summary()
    except Exception as e:
        write_progress({"status": "error", "error": str(e), "updated_at": datetime.now().isoformat()})
        print(f"ERROR: {e}")


if __name__ == "__main__":
    main()
