#!/usr/bin/env python3
"""
Run multiple aggressive parameter variants against BTCUSD M5 data.
Settings are overridden in memory at runtime, so no file edits are needed.
Pass --max-rows=N to limit the number of rows loaded from BTCUSD.csv.
"""

import sys, os
os.chdir(os.path.dirname(os.path.abspath(__file__)))

import settings
from run_btc_backtest import load_ticks_to_candles
from backtest import BacktestEngine
from utils import setup_logging

# Tick data file. Relative path: the os.chdir() at the top of this module makes
# the script's own directory the working directory, so this presumably
# resolves to a file next to the script (depends on load_ticks_to_candles).
CSV_PATH = "BTCUSD.csv"

# ============================================================================
# VARIANT DEFINITIONS
# ============================================================================
# Each variant overrides specific settings values.
# Everything not listed inherits from the current (adaptive) baseline.
#
# Keys are attribute names on the ``settings`` module; apply_overrides() sets
# each one verbatim. Any "_REF_*" key additionally makes apply_overrides()
# recompute ENTRY_TP_PIPS / RECOVERY_GRID_STEPS / RECOVERY_TPS through
# settings._normalize_for_pair(), and "_REF_THREAD_PROFIT_TARGET" is copied
# as-is into THREAD_PROFIT_TARGET. The dict's insertion order is the order in
# which main() runs and prints the variants.
# NOTE(review): the "Nx lots" labels presumably assume a baseline LOT_SIZE of
# 0.01 — confirm against settings.

VARIANTS = {
    "BASELINE (current)": {
        # no overrides — current adaptive settings
    },
    "AGG-1: 2x lots, 20 threads": {
        "LOT_SIZE": 0.02,
        "MAX_INITIAL_ORDERS": 20,
    },
    "AGG-2: 2x lots, 20 threads, tight TP": {
        "LOT_SIZE": 0.02,
        "MAX_INITIAL_ORDERS": 20,
        "_REF_ENTRY_TP_PIPS": 92,        # ~$525 (25% tighter than $700)
        "_REF_RECOVERY_TPS": [92, 46, 23, 0],
    },
    "AGG-3: 3x lots, 15 threads, tight TP, $10 thread exit": {
        "LOT_SIZE": 0.03,
        "MAX_INITIAL_ORDERS": 15,
        "_REF_ENTRY_TP_PIPS": 92,
        "_REF_RECOVERY_TPS": [92, 46, 23, 0],
        "_REF_THREAD_PROFIT_TARGET": 10.0,
    },
    "AGG-4: 2x lots, 20 threads, $5 thread exit": {
        "LOT_SIZE": 0.02,
        "MAX_INITIAL_ORDERS": 20,
        "_REF_THREAD_PROFIT_TARGET": 5.0,
    },
    "AGG-5: 3x lots, 10 threads, tight TP, $3 thread exit": {
        "LOT_SIZE": 0.03,
        "MAX_INITIAL_ORDERS": 10,
        "_REF_ENTRY_TP_PIPS": 92,
        "_REF_RECOVERY_TPS": [92, 46, 23, 0],
        "_REF_THREAD_PROFIT_TARGET": 3.0,
    },
}


def apply_overrides(overrides: dict):
    """Patch the live ``settings`` module with *overrides*.

    Every key/value pair is written onto ``settings`` verbatim. If any key
    starts with ``_REF_``, the derived ENTRY_TP_PIPS, RECOVERY_GRID_STEPS and
    RECOVERY_TPS are regenerated from ``settings._normalize_for_pair`` for the
    current SYMBOL. A ``_REF_THREAD_PROFIT_TARGET`` override is also copied
    unchanged into THREAD_PROFIT_TARGET. Finally ``hedge_recovery_engine`` is
    reloaded.

    Note: ``_REF_THREAD_PROFIT_TARGET`` itself carries the ``_REF_`` prefix,
    so a thread-target-only override also triggers the renormalization.
    """
    for attr_name, new_value in overrides.items():
        setattr(settings, attr_name, new_value)

    # Reference ("_REF_") values feed the per-pair normalization, so the
    # derived pip values are regenerated whenever one of them is overridden.
    needs_renormalize = any(key.startswith("_REF_") for key in overrides)
    if needs_renormalize:
        normalized = settings._normalize_for_pair(settings.SYMBOL)
        settings.ENTRY_TP_PIPS = normalized["entry_tp_pips"]
        settings.RECOVERY_GRID_STEPS = normalized["grid_steps"]
        settings.RECOVERY_TPS = normalized["recovery_tps"]

    thread_key = "_REF_THREAD_PROFIT_TARGET"
    if thread_key in overrides:
        # Copied as-is here; the normalized values above do not include it.
        settings.THREAD_PROFIT_TARGET = overrides[thread_key]

    # Re-execute hedge_recovery_engine's module code against the patched
    # settings (presumably it reads settings at import time — TODO confirm).
    import importlib
    import hedge_recovery_engine
    importlib.reload(hedge_recovery_engine)


def save_settings_snapshot():
    """Return the current values of the settings this script may modify.

    Only the fixed attribute names listed below are captured, in that order.
    Values are stored by reference (no copy is made), so a list mutated in
    place afterwards would change inside the snapshot too. Raises
    AttributeError if ``settings`` lacks any of the listed names.
    """
    tracked_names = (
        "LOT_SIZE",
        "MAX_INITIAL_ORDERS",
        "ENTRY_TP_PIPS",
        "ENTRY_SL_PIPS",
        "RECOVERY_GRID_STEPS",
        "RECOVERY_TPS",
        "THREAD_PROFIT_TARGET",
        "_REF_ENTRY_TP_PIPS",
        "_REF_RECOVERY_TPS",
        "_REF_THREAD_PROFIT_TARGET",
    )
    snapshot = {}
    for attr_name in tracked_names:
        snapshot[attr_name] = getattr(settings, attr_name)
    return snapshot


def restore_settings(snapshot: dict):
    """Write every name/value pair in *snapshot* back onto ``settings``.

    Afterwards ``hedge_recovery_engine`` is reloaded so its module code runs
    again against the restored values.
    """
    for attr_name in snapshot:
        setattr(settings, attr_name, snapshot[attr_name])

    import importlib
    import hedge_recovery_engine
    importlib.reload(hedge_recovery_engine)


def run_variant(name, overrides, data):
    """Run a single variant and return results dict.

    ``settings`` is patched with *overrides* (via apply_overrides) only for the
    duration of this call. The previous values are restored in ``finally``, so
    later variants start from the baseline even if the backtest raises.

    Args:
        name: label stored under ``"name"`` in the returned dict.
        overrides: mapping of ``settings`` attribute name -> value (one entry
            of VARIANTS).
        data: candle data handed unchanged to ``BacktestEngine.run``.

    Returns:
        dict with the run's headline metrics (net profit, return %, trades,
        win rate, gross profit/loss, commission, profit factor, max drawdown %,
        minimum equity, max recovery level) plus the effective LOT_SIZE,
        MAX_INITIAL_ORDERS, entry TP and THREAD_PROFIT_TARGET that were used.
    """
    missing = object()  # marks override keys that did not exist on settings
    snapshot = save_settings_snapshot()
    # save_settings_snapshot() only tracks a fixed list of names. Also remember
    # any overridden name outside that list, so it cannot leak into later
    # variants once this run is finished.
    extra = {
        key: getattr(settings, key, missing)
        for key in overrides
        if key not in snapshot
    }
    try:
        apply_overrides(overrides)

        engine = BacktestEngine(
            symbol=settings.SYMBOL,
            initial_balance=settings.INITIAL_BALANCE,
            commission_percent=settings.COMMISSION_PERCENT,
            slippage_pips=settings.SLIPPAGE_PIPS,
        )

        result = engine.run(data, progress_callback=None)

        return {
            "name": name,
            "net_profit": result.total_return,
            "pct": result.total_return_percent,
            "trades": result.total_trades,
            "win_rate": result.win_rate,
            "gross_profit": result.gross_profit,
            "gross_loss": result.gross_loss,
            "commission": result.total_commission,
            # 999 is a sentinel used when gross_loss is zero (no division possible).
            "pf": result.gross_profit / result.gross_loss if result.gross_loss > 0 else 999,
            "max_dd_pct": result.max_drawdown_percent,
            "min_equity": result.min_equity,
            "max_rec": result.max_recovery_level,
            "lot": settings.LOT_SIZE,
            "max_orders": settings.MAX_INITIAL_ORDERS,
            "entry_tp": settings.ENTRY_TP_PIPS * 0.1,  # pips to $
            "thread_tp": settings.THREAD_PROFIT_TARGET,
        }
    finally:
        # Undo untracked overrides first, so the module reload inside
        # restore_settings() sees a fully restored settings module.
        for key, old_value in extra.items():
            if old_value is not missing:
                setattr(settings, key, old_value)
            elif hasattr(settings, key):
                delattr(settings, key)
        restore_settings(snapshot)


def _summary_lines(results):
    """Build the end-of-run summary table as a list of printable lines.

    The variant-name column is at least 45 characters wide and grows to fit
    the longest name, so long names do not push later columns out of line
    with the header. The ``=``/``-`` rulers match the header's width. The
    returned list starts with a blank line and a ruler, and ends with a
    blank line and a ruler.
    """
    name_w = max(45, max((len(r["name"]) for r in results), default=0))
    header = (
        f"{'VARIANT':<{name_w}} {'Net $':>10} {'Return':>8} {'Trades':>7} {'Win%':>6} "
        f"{'PF':>5} {'MaxDD%':>7} {'MinEq':>10} {'Comm$':>10} {'Lot':>5} {'MaxOrd':>7} {'ThTP$':>6}"
    )
    rule = "=" * len(header)
    rows = [
        f"{r['name']:<{name_w}} ${r['net_profit']:>9,.0f} {r['pct']:>+7.1f}% {r['trades']:>7,} {r['win_rate']:>5.1f}% "
        f"{r['pf']:>5.2f} {r['max_dd_pct']:>6.1f}% ${r['min_equity']:>9,.0f} ${r['commission']:>9,.0f} "
        f"{r['lot']:>5.3f} {r['max_orders']:>7} ${r['thread_tp']:>5.1f}"
        for r in results
    ]
    return ["", rule, header, "-" * len(header), *rows, "", rule]


def main():
    """Load BTCUSD M5 candles, run every entry in VARIANTS and print a summary.

    Recognised command-line option:
        --max-rows=N   forwarded to ``load_ticks_to_candles`` as ``max_rows``.
    Any other argument is ignored.
    """
    setup_logging(console=False, file=True)

    max_rows = None
    for arg in sys.argv[1:]:
        if arg.startswith("--max-rows="):
            max_rows = int(arg.split("=", 1)[1])

    print("Loading data...")
    data = load_ticks_to_candles(CSV_PATH, timeframe="5min", max_rows=max_rows)
    print(f"\nRunning {len(VARIANTS)} variants on {len(data)} candles\n")
    print("=" * 120)

    results = []
    total = len(VARIANTS)
    for i, (name, overrides) in enumerate(VARIANTS.items(), start=1):
        print(f"[{i}/{total}] {name}...", end=" ", flush=True)
        r = run_variant(name, overrides, data)
        results.append(r)
        print(f"${r['net_profit']:>+10,.0f} ({r['pct']:>+.1f}%)  DD:{r['max_dd_pct']:.1f}%  PF:{r['pf']:.2f}")

    # Summary table, sized to the actual variant names.
    print("\n".join(_summary_lines(results)))


# Run the variant sweep only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
