"""
Backtesting engine for the Auto Hedge-Mart V4 strategy.

Simulates the hedged martingale recovery strategy on historical data
and generates performance metrics.
"""

from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional, Callable

from config import (
    SYMBOL,
    INITIAL_BALANCE,
    COMMISSION_PERCENT,
    SLIPPAGE_PIPS,
    SAMPLE_DATA_CANDLES,
    SAMPLE_START_PRICE,
    SAMPLE_VOLATILITY,
    DATA_FILE_CSV,
    DATA_FILE_JSON,
    CSV_TIMESTAMP_COL,
    CSV_OPEN_COL,
    CSV_HIGH_COL,
    CSV_LOW_COL,
    CSV_CLOSE_COL,
    CSV_VOLUME_COL,
    USE_TICK_DATA,
    TICKS_PER_CANDLE,
)
from data_loader import DataLoader, OHLCV, TickData, create_sample_data
from normalization import get_spread, normalize_price, get_pip_size
from order_engine import BacktestOrderEngine, OrderSide
from hedge_recovery_engine import HedgeMartStrategy
from utils import (
    setup_logging,
    get_logger,
    log_trade,
    log_balance,
    timestamp_to_str
)


@dataclass
class BacktestResult:
    """Results from a backtest run.

    Monetary fields are in account currency. ``*_percent`` fields are on a
    0-100 scale (they are printed with a ``%`` suffix); ``win_rate`` is printed
    the same way and is presumably also 0-100 — confirm against the order
    engine's statistics.
    """
    initial_balance: float
    final_balance: float
    total_return: float
    total_return_percent: float

    total_trades: int
    winning_trades: int
    losing_trades: int
    win_rate: float

    gross_profit: float
    # Reported as a positive magnitude.
    gross_loss: float
    net_profit: float
    total_commission: float

    max_drawdown: float
    max_drawdown_percent: float
    max_equity: float
    min_equity: float

    # Thread statistics
    total_threads: int
    max_recovery_level: int
    total_recovery_pairs: int

    # Margin tracking. ``peak_margin`` is displayed as "Peak Notional" and is
    # divided by a leverage factor to estimate required margin, i.e. it is
    # treated as notional exposure rather than posted margin.
    peak_margin: float = 0.0
    peak_positions: int = 0
    peak_lots: float = 0.0

    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None
    # Calendar days between first and last candle (not exchange trading days).
    trading_days: int = 0

    # Per-step series and closed-trade records; excluded from ``to_dict``.
    equity_curve: List[float] = field(default_factory=list)
    balance_curve: List[float] = field(default_factory=list)
    trade_log: List[dict] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return the scalar summary metrics as a plain dict.

        Dates are rendered with ``timestamp_to_str`` (``None`` when unset).
        The bulky series (``equity_curve``, ``balance_curve``, ``trade_log``)
        are intentionally omitted.
        """
        return {
            "initial_balance": self.initial_balance,
            "final_balance": self.final_balance,
            "total_return": self.total_return,
            "total_return_percent": self.total_return_percent,
            "total_trades": self.total_trades,
            "winning_trades": self.winning_trades,
            "losing_trades": self.losing_trades,
            "win_rate": self.win_rate,
            "gross_profit": self.gross_profit,
            "gross_loss": self.gross_loss,
            "net_profit": self.net_profit,
            "total_commission": self.total_commission,
            "max_drawdown": self.max_drawdown,
            "max_drawdown_percent": self.max_drawdown_percent,
            "max_equity": self.max_equity,
            "min_equity": self.min_equity,
            # Margin metrics are shown by print_summary, so export them too.
            "peak_margin": self.peak_margin,
            "peak_positions": self.peak_positions,
            "peak_lots": self.peak_lots,
            "total_threads": self.total_threads,
            "max_recovery_level": self.max_recovery_level,
            "total_recovery_pairs": self.total_recovery_pairs,
            "start_date": timestamp_to_str(self.start_date) if self.start_date else None,
            "end_date": timestamp_to_str(self.end_date) if self.end_date else None,
            "trading_days": self.trading_days
        }

    def print_summary(self) -> None:
        """Print a human-readable report of the results to stdout.

        The margin section is shown only when ``peak_margin`` is positive, and
        the time-period section only when both dates are set.
        """
        print("\n" + "=" * 60)
        print("HEDGE-MART V4 BACKTEST RESULTS")
        print("=" * 60)

        print(f"\n{'PERFORMANCE':=^40}")
        print(f"Initial Balance:     ${self.initial_balance:,.2f}")
        print(f"Final Balance:       ${self.final_balance:,.2f}")
        print(f"Net Profit:          ${self.net_profit:,.2f}")
        print(f"Total Return:        {self.total_return_percent:.2f}%")

        print(f"\n{'TRADE STATISTICS':=^40}")
        print(f"Total Trades:        {self.total_trades}")
        print(f"Winning Trades:      {self.winning_trades}")
        print(f"Losing Trades:       {self.losing_trades}")
        print(f"Win Rate:            {self.win_rate:.2f}%")
        print(f"Gross Profit:        ${self.gross_profit:,.2f}")
        print(f"Gross Loss:          ${self.gross_loss:,.2f}")
        print(f"Total Commission:    ${self.total_commission:,.2f}")

        print(f"\n{'RISK METRICS':=^40}")
        print(f"Max Drawdown:        ${self.max_drawdown:,.2f}")
        print(f"Max Drawdown %:      {self.max_drawdown_percent:.2f}%")
        print(f"Max Equity:          ${self.max_equity:,.2f}")
        print(f"Min Equity:          ${self.min_equity:,.2f}")

        # peak_margin is a declared dataclass field, so no hasattr check needed.
        if self.peak_margin > 0:
            print(f"\n{'MARGIN ANALYSIS':=^40}")
            print(f"Peak Notional:       ${self.peak_margin:,.0f}")
            print(f"Peak Positions:      {self.peak_positions}")
            print(f"Peak Lots:           {self.peak_lots:.3f}")
            # Required margin at a few common leverage ratios.
            for lev in [50, 100, 200]:
                req = self.peak_margin / lev
                print(f"  Margin @ {lev}x:      ${req:,.0f}")

        print(f"\n{'RECOVERY STATISTICS':=^40}")
        print(f"Total Threads:       {self.total_threads}")
        print(f"Max Recovery Level:  {self.max_recovery_level}")
        print(f"Total Hedge Pairs:   {self.total_recovery_pairs}")

        if self.start_date and self.end_date:
            print(f"\n{'TIME PERIOD':=^40}")
            print(f"Start Date:          {timestamp_to_str(self.start_date)}")
            print(f"End Date:            {timestamp_to_str(self.end_date)}")
            print(f"Trading Days:        {self.trading_days}")

        print("\n" + "=" * 60)


class BacktestEngine:
    """Backtesting engine for the Hedge-Mart V4 strategy.

    Wires a ``BacktestOrderEngine`` (simulated balance, positions and fills)
    to a ``HedgeMartStrategy`` and replays historical data through them,
    either candle-by-candle (``process_candle``) or tick-by-tick
    (``process_tick``). Equity, drawdown, margin and recovery statistics are
    sampled after every step and compiled into a ``BacktestResult`` by
    ``run``.
    """

    def __init__(
        self,
        symbol: str = SYMBOL,
        initial_balance: float = INITIAL_BALANCE,
        commission_percent: float = COMMISSION_PERCENT,
        slippage_pips: float = SLIPPAGE_PIPS,
        ask_multiplier: float = 1.0001,
    ):
        """Create the order engine and strategy and reset all tracking state.

        Args:
            symbol: Instrument to trade.
            initial_balance: Starting account balance.
            commission_percent: Forwarded to the order engine.
            slippage_pips: Forwarded to the order engine.
            ask_multiplier: Candle mode only. Candles carry a single price, so
                the simulated ask is ``price * ask_multiplier`` while the bid
                is ``price``. The default of 1.0001 is roughly a 1 basis point
                spread.
        """
        self.symbol = symbol
        self.initial_balance = initial_balance
        self.ask_multiplier = ask_multiplier

        # Initialize order engine
        self.order_engine = BacktestOrderEngine(
            symbol=symbol,
            initial_balance=initial_balance,
            commission_percent=commission_percent,
            slippage_pips=slippage_pips
        )

        # Initialize strategy
        self.strategy = HedgeMartStrategy(
            symbol=symbol,
            order_engine=self.order_engine
        )

        # Tracking
        self.equity_curve: List[float] = []
        self.balance_curve: List[float] = []
        self.max_equity = initial_balance
        self.min_equity = initial_balance
        self.max_drawdown = 0.0
        # Worst (peak - equity) / peak seen at any step, in percent.
        self.max_drawdown_percent = 0.0
        self.peak_equity = initial_balance

        # Margin tracking
        self.peak_margin = 0.0
        self.peak_positions = 0
        self.peak_lots = 0.0

        # Thread tracking
        self.total_threads_created = 0
        self.max_recovery_level = 0
        self.total_recovery_pairs = 0

        self.current_time: Optional[datetime] = None
        self.logger = get_logger()

    def _set_synthetic_quote(self, price: float) -> None:
        """Quote ``price`` as the bid and ``price * ask_multiplier`` as the ask."""
        self.order_engine.update_prices(price, price * self.ask_multiplier)

    def _count_new_threads(self, threads_before: int) -> None:
        """Add any increase in the active-thread count since ``threads_before``.

        Only net growth is counted, so a thread closed and another opened in
        the same step is not detected.
        """
        threads_after = self.strategy.get_active_thread_count()
        if threads_after > threads_before:
            self.total_threads_created += threads_after - threads_before

    def _update_recovery_stats(self) -> None:
        """Record the deepest recovery level and the peak recovery-order count."""
        for thread in self.strategy.recovery_threads:
            if thread.level_count > self.max_recovery_level:
                self.max_recovery_level = thread.level_count
        # The strategy-wide total is independent of any single thread, so it
        # is queried once per step rather than once per thread.
        self.total_recovery_pairs = max(
            self.total_recovery_pairs,
            self.strategy.get_total_recovery_orders()
        )

    def process_candle(self, candle: OHLCV) -> None:
        """Process a single candle through the trading logic.

        Entry signals are evaluated with prices quoted at the candle open; P&L
        and all TP/SL/recovery handling are evaluated at the candle close.
        """
        self.current_time = candle.timestamp
        self.order_engine.set_time(candle.timestamp)

        # Set prices to candle open for entry decisions
        self._set_synthetic_quote(candle.open)

        # Check entry signals on new bar (uses previous completed bars)
        threads_before = self.strategy.get_active_thread_count()
        self.strategy.on_new_bar(candle)
        self._count_new_threads(threads_before)

        # Update prices to candle close for P&L calculations
        self._set_synthetic_quote(candle.close)

        # Process all strategy logic: grid levels → recovery TPs → initial TP/SL → thread profit
        # Strategy manages ALL TP/SL (not order_engine.process_tp_sl) to ensure correct ordering
        self.strategy.process_candle(candle)

        self._update_recovery_stats()
        self._update_metrics()

    def process_tick(self, tick: TickData) -> None:
        """Process a single tick through the trading logic at its bid/ask."""
        self.current_time = tick.timestamp
        self.order_engine.set_time(tick.timestamp)
        self.order_engine.update_prices(tick.bid, tick.ask)

        # Check TP/SL at tick level for initial orders.
        # NOTE(review): arguments are passed as (ask, bid) — confirm this
        # matches BacktestOrderEngine.process_tp_sl's parameter order.
        self.order_engine.process_tp_sl(tick.ask, tick.bid)

        # Run recovery threads at tick price
        threads_before = self.strategy.get_active_thread_count()
        self.strategy.on_tick(tick.bid, tick.ask)
        self._count_new_threads(threads_before)

        self._update_recovery_stats()
        self._update_metrics()

    def _update_metrics(self) -> None:
        """Append equity/balance samples and update drawdown and margin peaks."""
        equity = self.order_engine.get_equity()
        balance = self.order_engine.balance

        self.equity_curve.append(equity)
        self.balance_curve.append(balance)

        if equity > self.peak_equity:
            self.peak_equity = equity
        else:
            drawdown = self.peak_equity - equity
            self.max_drawdown = max(self.max_drawdown, drawdown)
            # Measure the percentage against the peak in force at this step.
            # Dividing the worst dollar drawdown by the final peak would
            # understate drawdowns that were followed by further growth.
            if self.peak_equity > 0:
                self.max_drawdown_percent = max(
                    self.max_drawdown_percent,
                    drawdown / self.peak_equity * 100
                )

        self.max_equity = max(self.max_equity, equity)
        self.min_equity = min(self.min_equity, equity)

        # Track margin usage (sum() of no positions is already 0).
        positions = self.order_engine.positions
        self.peak_margin = max(self.peak_margin, self.order_engine.get_margin_used())
        self.peak_positions = max(self.peak_positions, len(positions))
        self.peak_lots = max(self.peak_lots, sum(p.lots for p in positions.values()))

    def _close_all_positions(self) -> None:
        """Close every open position at the last quote: longs at bid, shorts at ask."""
        bid, ask = self.order_engine.get_current_price()
        # Snapshot the positions because closing presumably mutates the dict.
        for position in list(self.order_engine.positions.values()):
            exit_price = bid if position.side == OrderSide.BUY else ask
            self.order_engine.close_position(position.position_id, exit_price)

    def run(
        self,
        data_loader: DataLoader,
        use_ticks: bool = False,
        ticks_per_candle: int = 4,
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> BacktestResult:
        """Run the backtest over ``data_loader`` and return compiled results.

        Args:
            data_loader: Source of candles (and synthetic ticks in tick mode).
            use_ticks: Replay ``ticks_per_candle`` ticks per candle via
                ``process_tick`` instead of whole candles via ``process_candle``.
            ticks_per_candle: Ticks generated per candle in tick mode.
            progress_callback: Called as ``(processed, total)`` every 1000
                ticks / 100 candles, plus once at the end if the last count
                was not already reported.

        Any positions still open at the end are closed at the last quote.
        """
        total_candles = len(data_loader)
        self.logger.info(f"Starting Hedge-Mart V4 backtest on {total_candles} candles")

        processed = 0

        if use_ticks:
            total_steps = total_candles * ticks_per_candle
            report_every = 1000
            # NOTE(review): strategy.on_new_bar is never called in tick mode, so
            # bar-based entry signals only occur if strategy.on_tick produces
            # them — confirm this is intended.
            for tick in data_loader.iterate_ticks(ticks_per_candle):
                self.process_tick(tick)
                processed += 1
                if progress_callback and processed % report_every == 0:
                    progress_callback(processed, total_steps)
        else:
            total_steps = total_candles
            report_every = 100
            # Feed bars one at a time
            for candle in data_loader.iterate_candles():
                self.process_candle(candle)
                processed += 1
                if progress_callback and processed % report_every == 0:
                    progress_callback(processed, total_steps)

        # Report the final count so progress displays finish at 100%.
        if progress_callback and processed and total_steps and processed % report_every:
            progress_callback(processed, total_steps)

        # Close remaining positions at last price
        self._close_all_positions()

        self.logger.info("Backtest complete")

        return self._compile_results(data_loader)

    def _compile_results(self, data_loader: DataLoader) -> BacktestResult:
        """Assemble a ``BacktestResult`` from order-engine stats and tracked metrics."""
        stats = self.order_engine.get_statistics()

        final_balance = self.order_engine.balance
        total_return = final_balance - self.initial_balance
        total_return_percent = (
            total_return / self.initial_balance * 100 if self.initial_balance else 0.0
        )

        start_date, end_date = data_loader.get_date_range()
        trading_days = 0
        if start_date and end_date:
            # Calendar days spanned by the data, not exchange trading days.
            trading_days = (end_date - start_date).days

        trade_log = [
            {
                "order_id": t.order_id,
                "symbol": t.symbol,
                "side": t.side.value,
                "lots": t.lots,
                "entry_price": t.price,
                "close_price": t.close_price,
                "profit": t.profit,
                "open_time": timestamp_to_str(t.open_time) if t.open_time else None,
                "close_time": timestamp_to_str(t.close_time) if t.close_time else None,
                "comment": t.comment,
            }
            for t in self.order_engine.trade_history
        ]

        return BacktestResult(
            initial_balance=self.initial_balance,
            final_balance=final_balance,
            total_return=total_return,
            total_return_percent=total_return_percent,
            total_trades=stats["total_trades"],
            winning_trades=stats["winning_trades"],
            losing_trades=stats["losing_trades"],
            win_rate=stats["win_rate"],
            # Gross figures are reconstructed from the engine's per-trade averages.
            gross_profit=stats["average_win"] * stats["winning_trades"],
            gross_loss=abs(stats["average_loss"] * stats["losing_trades"]),
            net_profit=total_return,
            total_commission=stats["total_commission"],
            max_drawdown=self.max_drawdown,
            max_drawdown_percent=self.max_drawdown_percent,
            max_equity=self.max_equity,
            min_equity=self.min_equity,
            peak_margin=self.peak_margin,
            peak_positions=self.peak_positions,
            peak_lots=self.peak_lots,
            total_threads=self.total_threads_created,
            max_recovery_level=self.max_recovery_level,
            total_recovery_pairs=self.total_recovery_pairs,
            start_date=start_date,
            end_date=end_date,
            trading_days=trading_days,
            equity_curve=self.equity_curve,
            balance_curve=self.balance_curve,
            trade_log=trade_log
        )


def load_data_from_settings() -> DataLoader:
    """Build a DataLoader from whichever data source is configured.

    Precedence: ``DATA_FILE_CSV`` when set, otherwise ``DATA_FILE_JSON`` when
    set, otherwise ``SAMPLE_DATA_CANDLES`` synthetic candles produced by
    ``create_sample_data``. A status line naming the source is printed.
    """
    loader = DataLoader(SYMBOL)

    if DATA_FILE_CSV:
        print(f"Loading data from CSV: {DATA_FILE_CSV}")
        column_map = {
            "timestamp_col": CSV_TIMESTAMP_COL,
            "open_col": CSV_OPEN_COL,
            "high_col": CSV_HIGH_COL,
            "low_col": CSV_LOW_COL,
            "close_col": CSV_CLOSE_COL,
            "volume_col": CSV_VOLUME_COL,
        }
        loader.load_csv(DATA_FILE_CSV, **column_map)
    elif DATA_FILE_JSON:
        print(f"Loading data from JSON: {DATA_FILE_JSON}")
        loader.load_json(DATA_FILE_JSON)
    else:
        # No file configured: fall back to generated sample candles.
        print(f"Generating {SAMPLE_DATA_CANDLES} candles of sample data...")
        loader = create_sample_data(
            symbol=SYMBOL,
            num_candles=SAMPLE_DATA_CANDLES,
            start_price=SAMPLE_START_PRICE,
            volatility=SAMPLE_VOLATILITY
        )

    return loader


def run_backtest() -> BacktestResult:
    """Run a complete backtest configured by the constants imported from ``config``.

    Sets up console logging, loads data, runs a ``BacktestEngine`` with an
    in-place percentage progress line, prints the summary and returns the
    ``BacktestResult``.
    """
    setup_logging(console=True, file=False)

    candles = load_data_from_settings()
    print(f"Loaded {len(candles)} candles for {SYMBOL}")

    engine = BacktestEngine(
        symbol=SYMBOL,
        initial_balance=INITIAL_BALANCE,
        commission_percent=COMMISSION_PERCENT,
        slippage_pips=SLIPPAGE_PIPS,
    )

    def report_progress(done, total):
        # "\r" rewrites the same console line on every update.
        print(f"\rProgress: {done / total * 100:.1f}%", end="", flush=True)

    outcome = engine.run(
        candles,
        use_ticks=USE_TICK_DATA,
        ticks_per_candle=TICKS_PER_CANDLE,
        progress_callback=report_progress
    )
    # Terminate the carriage-return progress line before the summary.
    print()

    outcome.print_summary()
    return outcome


def run_sample_backtest():
    """Backward-compatible alias: delegate to ``run_backtest`` and return its result."""
    outcome = run_backtest()
    return outcome


if __name__ == "__main__":
    # Executed only when this file is run directly, not when it is imported.
    run_backtest()
