"""
Optimized fast backtesting engine using numpy vectorization.

This version is optimized for processing large tick datasets with:
- Numpy-based price data handling
- Chunked processing to reduce memory usage
- Efficient position tracking
- Real-time progress reporting
"""

import time
import numpy as np
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Tuple
import pandas as pd

from config import (
    SYMBOL,
    INITIAL_BALANCE,
    COMMISSION_PERCENT,
    SLIPPAGE_PIPS,
    BASE_MULTIPLIER,
    RISK_PERCENT,
    AGGRESSION_LEVEL,
)
from utils import setup_logging, get_logger

# Optional override: when a `settings` module is importable, read
# REENTRY_ABANDON_DISTANCE from it (falling back to 5000.0 if the attribute is
# missing); when the module cannot be imported at all, use 5000.0.
try:
    import settings
    REENTRY_ABANDON_DISTANCE = getattr(settings, 'REENTRY_ABANDON_DISTANCE', 5000.0)
except ImportError:
    REENTRY_ABANDON_DISTANCE = 5000.0


@dataclass
class ThreadEvent:
    """Single event in a thread's lifecycle.

    Instances are appended to a ``ThreadLog`` via ``ThreadLog.add_event`` and
    later rendered row by row by the engine's thread-log writer.
    """
    timestamp: datetime  # Time of the tick that produced the event
    event_type: str  # OPEN, AVG, TP_HIT, TP_REENTRY, REENTRY, CLOSE, ABANDON
    level: int  # Averaging level of the position involved
    price: float  # Price at which the event happened
    lots: float  # Position size involved
    tp_price: float  # Take-profit price of the position
    profit: float = 0.0  # Realised profit for the event (0.0 when none)
    balance: float = 0.0  # Account balance recorded with the event
    equity: float = 0.0  # Account equity recorded with the event
    details: str = ""  # Free-form text shown in the log's "Details" column


@dataclass
class ThreadLog:
    """Event history and running totals for one trading thread."""
    thread_id: int
    start_time: datetime
    end_time: Optional[datetime] = None
    events: List[ThreadEvent] = field(default_factory=list)
    max_level: int = 0
    total_profit: float = 0.0
    total_trades: int = 0
    winning_trades: int = 0
    entry_price: float = 0.0

    def add_event(self, event: ThreadEvent):
        """Record *event* and raise ``max_level`` if the event's level exceeds it."""
        self.events.append(event)
        self.max_level = max(self.max_level, event.level)

    def summary(self) -> str:
        """Return a one-line, pipe-separated summary of the thread.

        The duration is reported as 0h 0m while ``end_time`` is unset.
        """
        if self.end_time:
            elapsed = (self.end_time - self.start_time).total_seconds()
        else:
            elapsed = 0

        whole_hours, leftover = divmod(elapsed, 3600)
        hours = int(whole_hours)
        mins = int(leftover // 60)

        columns = [
            f"Entry ${self.entry_price:,.2f}",
            f"MaxLvl {self.max_level}",
            f"Trades {self.total_trades}",
            f"Wins {self.winning_trades}",
            f"P/L ${self.total_profit:+,.2f}",
            f"Duration {hours}h {mins}m",
        ]
        return f"Thread #{self.thread_id}: " + " | ".join(columns)


@dataclass
class Position:
    """Lightweight position for fast backtesting.

    Stored by the engine in a dict keyed by ``id``; profit on close is
    ``lots * (exit - entry_price)`` for BUY and ``lots * (entry_price - exit)``
    for SELL.
    """
    id: int  # Unique id assigned from the engine's position counter
    entry_price: float  # Fill price at open
    lots: float  # Position size
    tp_price: float  # Take-profit price for this position
    entry_time: datetime  # Timestamp of the opening tick
    direction: int = 1  # 1 = BUY (long), -1 = SELL (short)
    is_averaging: bool = False  # True if opened as an averaging order
    avg_level: int = 0  # Averaging level; -1 marks a hedge position
    is_reentry: bool = False  # True if this is a re-entry position
    thread_id: int = 0  # Which thread this position belongs to


@dataclass
class PendingReentry:
    """Tracks a level waiting for price to return for re-entry.

    Held in ``TradingThread.pending_reentries``, keyed by an int
    (presumably the level; confirm against the code that populates it).
    """
    level: int  # Averaging level to be re-opened
    entry_price: float  # Price at which we should reopen
    lots: float         # Lot size for this level
    tp_distance: float  # TP distance from entry
    direction: int = 1  # 1 = BUY, -1 = SELL


@dataclass
class TradingThread:
    """Tracks state of a single trading thread.

    A thread groups the positions that share one ``thread_id``, together with
    its pending re-entries, occupied grid levels, optional hedge and log.
    """
    id: int  # Thread id (key in the engine's ``threads`` dict)
    direction: int  # 1 = BUY (averages DOWN), -1 = SELL (averages UP)
    entry_price: float  # Price of the thread's initial entry (used for grid spacing)
    entry_time: datetime  # Timestamp of the initial entry
    base_lots: float  # Lot size of the initial order
    tp_dist: float  # Take-profit distance for the thread
    entry_dist: float  # Base distance used for averaging entries (presumably; confirm in the averaging code)
    avg_level: int = 0  # Current averaging level of the thread
    active: bool = True  # Inactive threads are skipped by the per-tick processing

    # Weighted average tracking
    total_lots: float = 0.0  # Reduced by a position's lots when it closes
    # Reduced by entry_price * lots when a position closes, i.e. this looks like
    # the running sum of price * lots rather than the average itself
    # (NOTE(review): confirm against the code that adds to it).
    weighted_entry: float = 0.0

    # Pending re-entries for this thread
    pending_reentries: Dict[int, PendingReentry] = field(default_factory=dict)

    # Grid levels occupied by this thread
    occupied_grid_levels: set = field(default_factory=set)

    # Hedge tracking
    hedge_active: bool = False  # True while a hedge is being tracked for this thread
    hedge_position_id: int = 0  # Position id of the hedge
    hedge_entry_price: float = 0.0  # Entry price of the hedge
    hedge_lots: float = 0.0  # Size of the hedge

    # Logging
    log: Optional[ThreadLog] = None  # Per-thread event log; None when not logging


@dataclass
class FastBacktestResult:
    """Results from fast backtest.

    Plain data holder for the metrics produced by one engine run, plus helpers
    to print a console report and to persist the scalar metrics as JSON.
    """
    symbol: str
    start_time: datetime  # Timestamp of the first processed tick
    end_time: datetime  # Timestamp of the last processed tick
    total_ticks: int
    run_duration: float  # Wall-clock seconds spent running the backtest

    initial_balance: float
    final_balance: float
    final_equity: float
    net_profit: float
    total_return_pct: float

    total_trades: int
    winning_trades: int
    losing_trades: int
    win_rate: float  # Percentage (0-100)

    gross_profit: float
    gross_loss: float
    profit_factor: float  # gross_profit / gross_loss; may be float('inf')

    max_drawdown: float
    max_drawdown_pct: float

    total_threads: int
    max_avg_level: int
    total_commission: float
    reentries_abandoned: int = 0
    spread_filtered: int = 0  # Trades blocked due to high spread

    # Sampled (tick_count, value) pairs; not written by to_json().
    equity_curve: List[Tuple[int, float]] = field(default_factory=list)
    drawdown_curve: List[Tuple[int, float]] = field(default_factory=list)

    def print_summary(self):
        """Print formatted summary to stdout."""
        # A zero (or negative) duration would make the throughput division
        # raise ZeroDivisionError; report 0 ticks/second instead.
        if self.run_duration > 0:
            ticks_per_second = self.total_ticks / self.run_duration
        else:
            ticks_per_second = 0.0

        print("\n" + "=" * 70)
        print("FAST BACKTEST RESULTS")
        print("=" * 70)

        print(f"\n{'RUN INFO':=^50}")
        print(f"Symbol:              {self.symbol}")
        print(f"Period:              {self.start_time} to {self.end_time}")
        print(f"Total Ticks:         {self.total_ticks:,}")
        print(f"Run Duration:        {self.run_duration:.1f} seconds")
        print(f"Ticks/Second:        {ticks_per_second:,.0f}")

        print(f"\n{'PERFORMANCE':=^50}")
        print(f"Initial Balance:     ${self.initial_balance:,.2f}")
        print(f"Final Balance:       ${self.final_balance:,.2f}")
        print(f"Final Equity:        ${self.final_equity:,.2f}")
        print(f"Net Profit:          ${self.net_profit:,.2f}")
        print(f"Total Return:        {self.total_return_pct:+.2f}%")

        print(f"\n{'TRADE STATISTICS':=^50}")
        print(f"Total Trades:        {self.total_trades}")
        print(f"Winning Trades:      {self.winning_trades} ({self.win_rate:.1f}%)")
        print(f"Losing Trades:       {self.losing_trades}")
        print(f"Gross Profit:        ${self.gross_profit:,.2f}")
        print(f"Gross Loss:          ${self.gross_loss:,.2f}")
        print(f"Profit Factor:       {self.profit_factor:.2f}")

        print(f"\n{'DRAWDOWN':=^50}")
        print(f"Max Drawdown:        ${self.max_drawdown:,.2f} ({self.max_drawdown_pct:.2f}%)")

        print(f"\n{'AVERAGING':=^50}")
        print(f"Total Threads:       {self.total_threads}")
        print(f"Max Avg Level:       {self.max_avg_level}")
        print(f"Reentries Abandoned: {self.reentries_abandoned}")

        print(f"\n{'COSTS':=^50}")
        print(f"Total Commission:    ${self.total_commission:,.2f}")
        print(f"Spread Filtered:     {self.spread_filtered}")

        print("\n" + "=" * 70)

    def to_json(self, filepath: str = "backtest_results.json"):
        """Save the scalar results to a JSON file.

        The equity and drawdown curves are not included. Non-finite floats
        (for example an infinite ``profit_factor`` when there were no losing
        trades) are written as ``null`` so the file stays valid JSON.

        Args:
            filepath: Destination path; an existing file is overwritten.
        """
        import json
        import math

        data = {
            "symbol": self.symbol,
            "start_time": str(self.start_time),
            "end_time": str(self.end_time),
            "total_ticks": self.total_ticks,
            "run_duration": self.run_duration,
            "initial_balance": self.initial_balance,
            "final_balance": self.final_balance,
            "final_equity": self.final_equity,
            "net_profit": self.net_profit,
            "total_return_pct": self.total_return_pct,
            "total_trades": self.total_trades,
            "winning_trades": self.winning_trades,
            "losing_trades": self.losing_trades,
            "win_rate": self.win_rate,
            "gross_profit": self.gross_profit,
            "gross_loss": self.gross_loss,
            "profit_factor": self.profit_factor,
            "max_drawdown": self.max_drawdown,
            "max_drawdown_pct": self.max_drawdown_pct,
            "total_threads": self.total_threads,
            "max_avg_level": self.max_avg_level,
            "total_commission": self.total_commission,
            "spread_filtered": self.spread_filtered,
            "reentries_abandoned": self.reentries_abandoned
        }
        # json.dump would emit the bare tokens Infinity/NaN, which strict JSON
        # parsers reject; map them to null instead.
        data = {
            key: (None if isinstance(value, float) and not math.isfinite(value) else value)
            for key, value in data.items()
        }
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=2)
        print(f"Results saved to: {filepath}")


class FastBacktestEngine:
    """
    Fast backtesting engine optimized for large datasets.
    """

    def __init__(
        self,
        symbol: str = SYMBOL,
        initial_balance: float = INITIAL_BALANCE,
        commission_pct: float = COMMISSION_PERCENT,
        risk_pct: float = RISK_PERCENT,
        base_multiplier: float = BASE_MULTIPLIER,
        aggression: float = AGGRESSION_LEVEL,
        report_interval: int = 100000,
    ):
        """Initialise engine configuration, account state and metric counters.

        Args:
            symbol: Instrument name, used in console output and results.
            initial_balance: Starting account balance.
            commission_pct: Commission as a percentage of traded notional
                (stored divided by 100).
            risk_pct: Percentage of balance used for lot sizing (stored
                divided by 100).
            base_multiplier: Stored as-is on the instance.
            aggression: Percentage growth per averaging level for the default
                lot-size formula (stored divided by 100).
            report_interval: Number of ticks between progress lines.

        All other settings below are plain attributes with defaults that
        callers may overwrite after construction.
        """
        self.symbol = symbol
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.commission_pct = commission_pct / 100  # Convert to decimal
        self.risk_pct = risk_pct / 100  # Convert to decimal
        self.base_multiplier = base_multiplier
        self.aggression = aggression / 100  # Convert to decimal
        self.report_interval = report_interval

        # Fibonacci sequence for averaging distances (default);
        # extended on demand by _get_fib
        self.fib = [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]

        # Custom martingale levels (from CustomMartingaleEA)
        # If set, these override Fibonacci distances
        # Format: distance multipliers from entry (e.g., 1.3x, 2.1x, 3.4x of base distance)
        self.custom_levels = None  # e.g., [1.3, 2.1, 3.4, 5.5, 8.9]
        self.custom_multipliers = None  # e.g., [1.5, 1.6, 1.7, 1.8, 1.9]
        self.use_cumulative_multipliers = False  # If True, multiply through all levels like EA

        # Spread safety (from CustomMartingaleEA)
        self.max_spread = 100.0  # Maximum allowed spread in dollars
        self.spread_filtered_count = 0  # Count trades blocked by spread filter

        # Position tracking
        self.positions: Dict[int, Position] = {}  # position id -> Position
        self._pos_counter = 0  # Last position id handed out

        # Multi-thread tracking (bidirectional trading)
        self.threads: Dict[int, TradingThread] = {}  # thread_id -> TradingThread
        self._thread_counter = 0
        self.threads_completed = 0

        # Bidirectional trading settings
        self.enable_longs = True  # BUY positions (average DOWN)
        self.enable_shorts = True  # SELL positions (average UP)
        self.grid_spacing = 500.0  # Min distance between main order entries
        self.max_concurrent_threads = 10  # Max simultaneous threads

        # Individual TP settings
        self.individual_tp_enabled = True  # Each position has its own TP
        self.individual_tp_distance = 200.0  # TP distance per position

        # Hedge settings (anti-martingale lock)
        self.hedge_enabled = False  # Enable hedging
        self.hedge_trigger_level = 5  # Open hedge at this averaging level
        self.hedge_close_profit = 50.0  # Close hedge when it profits this much
        self.hedge_lot_percent = 50  # Hedge size as % of total main lots

        # Track last entry prices for grid spacing (0.0 means "no entry yet")
        self.last_long_entry = 0.0
        self.last_short_entry = 0.0

        # Legacy single-thread tracking (for progress display)
        self.thread_avg_level = 0  # Current max level across all threads

        # TP Cascade / Re-entry tracking
        self.reentry_tolerance = 1.0  # Price tolerance for re-entry trigger (in dollars)

        # Stale re-entry cleanup: abandon pending re-entries if price moves too far
        self.reentry_abandon_distance = REENTRY_ABANDON_DISTANCE  # From settings or default
        self.reentries_abandoned = 0  # Track abandoned re-entries

        # Grid-based order management (from NAGA_GRID EA)
        self.grid_size = 50.0  # Grid size in dollars for BTC
        self.min_trade_interval = 0  # Minimum ticks between trades (0 = disabled)
        self.last_trade_tick = 0

        # Metrics (updated as positions are opened and closed)
        self.total_trades = 0
        self.winning_trades = 0
        self.losing_trades = 0
        self.gross_profit = 0.0
        self.gross_loss = 0.0
        self.total_commission = 0.0
        self.max_avg_level = 0

        # Drawdown (maintained by _update_drawdown)
        self.peak_equity = initial_balance
        self.max_drawdown = 0.0
        self.max_drawdown_pct = 0.0

        # Curves (sampled every _sample_interval ticks)
        self.equity_curve = []
        self.drawdown_curve = []
        self._sample_interval = 10000

        # Thread logging
        self.thread_logs: List[ThreadLog] = []
        self.current_thread_log: Optional[ThreadLog] = None
        self.enable_logging = True  # Set to False to disable for speed

    def _get_fib(self, n: int) -> int:
        """Get Fibonacci number, extend if needed."""
        while n >= len(self.fib):
            self.fib.append(self.fib[-1] + self.fib[-2])
        return self.fib[n]

    def write_thread_logs(self, filename: str = "thread_logs.txt"):
        """Write detailed thread logs to a file."""
        with open(filename, 'w') as f:
            f.write("=" * 100 + "\n")
            f.write("THREAD-BY-THREAD TRADING LOG\n")
            f.write("=" * 100 + "\n\n")

            # Summary statistics
            total_threads = len(self.thread_logs)
            profitable_threads = sum(1 for t in self.thread_logs if t.total_profit > 0)
            losing_threads = sum(1 for t in self.thread_logs if t.total_profit <= 0)
            total_profit = sum(t.total_profit for t in self.thread_logs)

            f.write(f"SUMMARY\n")
            f.write(f"-" * 50 + "\n")
            f.write(f"Total Threads:      {total_threads}\n")
            f.write(f"Profitable:         {profitable_threads} ({profitable_threads/total_threads*100:.1f}%)\n" if total_threads > 0 else "")
            f.write(f"Losing:             {losing_threads}\n")
            f.write(f"Total P/L:          ${total_profit:+,.2f}\n")
            f.write(f"\n")

            # Max level distribution
            level_dist = {}
            for t in self.thread_logs:
                level_dist[t.max_level] = level_dist.get(t.max_level, 0) + 1

            f.write(f"MAX LEVEL DISTRIBUTION\n")
            f.write(f"-" * 50 + "\n")
            for level in sorted(level_dist.keys()):
                count = level_dist[level]
                pct = count / total_threads * 100 if total_threads > 0 else 0
                bar = "#" * int(pct / 2)
                f.write(f"Level {level:>2}: {count:>4} ({pct:>5.1f}%) {bar}\n")
            f.write(f"\n")

            # Detailed thread logs
            f.write("=" * 100 + "\n")
            f.write("DETAILED THREAD LOGS\n")
            f.write("=" * 100 + "\n\n")

            for thread in self.thread_logs:
                f.write(f"\n{'='*80}\n")
                f.write(f"{thread.summary()}\n")
                f.write(f"{'='*80}\n")
                f.write(f"{'Time':<25} {'Event':<15} {'Lvl':>4} {'Price':>12} {'Lots':>10} {'TP':>12} {'Profit':>12} {'Balance':>12} {'Details'}\n")
                f.write(f"{'-'*130}\n")

                for event in thread.events:
                    time_str = event.timestamp.strftime("%Y-%m-%d %H:%M:%S")
                    profit_str = f"${event.profit:+,.2f}" if event.profit != 0 else ""
                    f.write(
                        f"{time_str:<25} {event.event_type:<15} {event.level:>4} "
                        f"${event.price:>10,.2f} {event.lots:>10.4f} "
                        f"${event.tp_price:>10,.2f} {profit_str:>12} "
                        f"${event.balance:>10,.2f} {event.details}\n"
                    )

                f.write(f"\n")

            # Top 10 most profitable threads
            f.write("\n" + "=" * 100 + "\n")
            f.write("TOP 10 MOST PROFITABLE THREADS\n")
            f.write("=" * 100 + "\n")
            sorted_by_profit = sorted(self.thread_logs, key=lambda t: t.total_profit, reverse=True)[:10]
            for i, thread in enumerate(sorted_by_profit, 1):
                f.write(f"{i:>2}. {thread.summary()}\n")

            # Top 10 worst threads
            f.write("\n" + "=" * 100 + "\n")
            f.write("TOP 10 WORST THREADS (MOST LOSING)\n")
            f.write("=" * 100 + "\n")
            sorted_by_loss = sorted(self.thread_logs, key=lambda t: t.total_profit)[:10]
            for i, thread in enumerate(sorted_by_loss, 1):
                f.write(f"{i:>2}. {thread.summary()}\n")

            # Threads that hit max averaging levels
            f.write("\n" + "=" * 100 + "\n")
            f.write("THREADS WITH HIGHEST AVERAGING LEVELS\n")
            f.write("=" * 100 + "\n")
            sorted_by_level = sorted(self.thread_logs, key=lambda t: t.max_level, reverse=True)[:20]
            for i, thread in enumerate(sorted_by_level, 1):
                f.write(f"{i:>2}. {thread.summary()}\n")

        print(f"\nThread logs written to: {filename}")

    def _get_grid_index(self, price: float) -> int:
        """Convert price to grid level index."""
        return int(round(price / self.grid_size))

    def _calc_lot_size(self, price: float) -> float:
        """Calculate lot size based on risk."""
        risk_amount = self.balance * self.risk_pct
        return risk_amount / price

    def _calc_avg_lot(self, base_lot: float, level: int) -> float:
        """Calculate averaging lot size.

        If use_cumulative_multipliers is True (like UniqueMartingaleEA):
            Level N lot = base_lot * mult[0] * mult[1] * ... * mult[N-1]
        Otherwise (direct):
            Level N lot = base_lot * mult[N-1]
        """
        if self.custom_multipliers and level > 0 and level <= len(self.custom_multipliers):
            if self.use_cumulative_multipliers:
                # CUMULATIVE: multiply through all levels (like the MQL4 EA)
                # Level 1: base * mult[0]
                # Level 2: base * mult[0] * mult[1]
                # Level 3: base * mult[0] * mult[1] * mult[2]
                result = base_lot
                for i in range(level):
                    result *= self.custom_multipliers[i]
                return result
            else:
                # DIRECT: use the multiplier for this specific level
                return base_lot * self.custom_multipliers[level - 1]
        else:
            # Default: exponential aggression
            return base_lot * ((1 + self.aggression) ** level)

    def _calc_avg_distance(self, base_dist: float, level: int) -> float:
        """Calculate averaging distance.

        If custom_levels is set, use custom distance multipliers.
        Otherwise, use Fibonacci sequence.
        """
        if self.custom_levels and level < len(self.custom_levels):
            return base_dist * self.custom_levels[level]
        else:
            # Default: Fibonacci
            return base_dist * self._get_fib(level)

    def _calc_commission(self, lots: float, price: float) -> float:
        """Calculate commission for trade."""
        return lots * price * self.commission_pct

    def _open_position(
        self,
        price: float,
        lots: float,
        tp_price: float,
        timestamp: datetime,
        direction: int = 1,
        thread_id: int = 0,
        is_avg: bool = False,
        avg_level: int = 0,
        is_reentry: bool = False
    ) -> int:
        """Register a new position, charge its opening commission and return its id.

        Args:
            price: Entry price.
            lots: Position size.
            tp_price: Take-profit price.
            timestamp: Time of the opening tick.
            direction: 1 for BUY, -1 for SELL.
            thread_id: Owning thread; when that thread is known to the engine
                the entry's grid cell is marked as occupied on it.
            is_avg: Whether this is an averaging order.
            avg_level: Averaging level of the order.
            is_reentry: Whether this is a re-entry order.

        Returns:
            The id assigned to the new position.
        """
        new_id = self._pos_counter + 1
        self._pos_counter = new_id

        self.positions[new_id] = Position(
            id=new_id,
            entry_price=price,
            lots=lots,
            tp_price=tp_price,
            entry_time=timestamp,
            direction=direction,
            is_averaging=is_avg,
            avg_level=avg_level,
            is_reentry=is_reentry,
            thread_id=thread_id
        )

        # Mark the entry's grid cell as taken on the owning thread, if any.
        if thread_id in self.threads:
            owner = self.threads[thread_id]
            owner.occupied_grid_levels.add(self._get_grid_index(price))

        # Opening commission comes straight out of the balance.
        fee = self._calc_commission(lots, price)
        self.balance -= fee
        self.total_commission += fee

        return new_id

    def _close_position(self, pos_id: int, price: float) -> float:
        """Close a position and return profit."""
        if pos_id not in self.positions:
            return 0.0

        pos = self.positions[pos_id]

        # Calculate profit based on direction
        if pos.direction == 1:  # BUY
            profit = pos.lots * (price - pos.entry_price)
        else:  # SELL
            profit = pos.lots * (pos.entry_price - price)

        # Update weighted average tracking in thread (remove this position's contribution)
        # Skip hedge positions (avg_level = -1) - they're not part of the weighted average
        if pos.thread_id in self.threads and pos.avg_level >= 0:
            thread = self.threads[pos.thread_id]
            thread.total_lots -= pos.lots
            thread.weighted_entry -= pos.entry_price * pos.lots

            # Remove grid level from thread's occupied set
            grid_idx = self._get_grid_index(pos.entry_price)
            thread.occupied_grid_levels.discard(grid_idx)

        # Deduct closing commission
        comm = self._calc_commission(pos.lots, price)
        self.balance -= comm
        self.total_commission += comm

        self.balance += profit

        if profit > 0:
            self.winning_trades += 1
            self.gross_profit += profit
        else:
            self.losing_trades += 1
            self.gross_loss += abs(profit)

        self.total_trades += 1
        del self.positions[pos_id]

        return profit

    def _get_equity(self, bid: float, ask: float = None) -> float:
        """Calculate current equity accounting for position direction."""
        if ask is None:
            ask = bid  # Approximate if not provided
        unrealized = 0.0
        for p in self.positions.values():
            if p.direction == 1:  # BUY - close at bid
                unrealized += p.lots * (bid - p.entry_price)
            else:  # SELL - close at ask
                unrealized += p.lots * (p.entry_price - ask)
        return self.balance + unrealized

    def _update_drawdown(self, equity: float):
        """Update drawdown metrics."""
        if equity > self.peak_equity:
            self.peak_equity = equity

        dd = self.peak_equity - equity
        dd_pct = dd / self.peak_equity * 100 if self.peak_equity > 0 else 0

        if dd > self.max_drawdown:
            self.max_drawdown = dd
            self.max_drawdown_pct = dd_pct

    def run_from_csv(
        self,
        csv_file: str,
        max_rows: Optional[int] = None,
        chunk_size: int = 500000
    ) -> FastBacktestResult:
        """
        Run backtest directly from CSV file in chunks.

        The CSV must provide the columns ``Timestamp``, ``Bid price`` and
        ``Ask price``. Every row is fed to ``_process_tick``; equity and
        drawdown curves are sampled every ``_sample_interval`` ticks and a
        progress line is printed every ``report_interval`` ticks. Positions
        still open after the last tick are liquidated at the final quote
        (BUY at bid, SELL at ask, matching ``_get_equity``). The summary is
        printed, results are written to ``backtest_results.json`` and, when
        logging is enabled and logs exist, thread logs go to
        ``thread_logs.txt``.

        Args:
            csv_file: Path to tick data CSV
            max_rows: Maximum rows to process
            chunk_size: Rows per chunk for memory efficiency

        Returns:
            A ``FastBacktestResult`` with the metrics of the run.
        """
        print(f"\nFast Backtest: {self.symbol}")
        print(f"Loading data from {csv_file}...")

        start_time = time.time()

        # Count total rows first (only used for the progress display).
        if max_rows:
            total_rows = max_rows
        else:
            # Context manager so the counting handle is closed promptly.
            with open(csv_file) as fh:
                total_rows = sum(1 for _ in fh) - 1  # Minus header

        print(f"Processing {total_rows:,} ticks in chunks of {chunk_size:,}")
        print("-" * 80)

        processed = 0
        data_start = None
        data_end = None
        # Last quote seen; stays None if the file yields no ticks at all.
        last_bid = None
        last_ask = None

        # Process in chunks
        for chunk in pd.read_csv(csv_file, chunksize=chunk_size, nrows=max_rows):
            if chunk.empty:
                # Nothing to index into (e.g. header-only file).
                continue

            # Parse timestamps
            timestamps = chunk['Timestamp'].apply(self._parse_timestamp)
            bids = chunk['Bid price'].values
            asks = chunk['Ask price'].values

            if data_start is None:
                data_start = timestamps.iloc[0]
            data_end = timestamps.iloc[-1]

            # Process each tick in chunk
            for tick_time, bid, ask in zip(timestamps, bids, asks):
                self._process_tick(tick_time, bid, ask)

                processed += 1
                last_bid = bid
                last_ask = ask

                # Sample curves (shorts are marked at ask, like _process_tick)
                if processed % self._sample_interval == 0:
                    equity = self._get_equity(bid, ask)
                    self.equity_curve.append((processed, equity))
                    self.drawdown_curve.append((processed, self.peak_equity - equity))

                # Progress report
                if processed % self.report_interval == 0:
                    equity = self._get_equity(bid, ask)
                    dd = self.peak_equity - equity
                    dd_pct = dd / self.peak_equity * 100 if self.peak_equity > 0 else 0
                    elapsed = time.time() - start_time
                    tps = processed / elapsed if elapsed > 0 else 0
                    pct_done = processed / total_rows * 100 if total_rows > 0 else 0.0

                    print(
                        f"\r[{processed:>12,}/{total_rows:,}] "
                        f"{pct_done:5.1f}% | "
                        f"{tps:>8,.0f} t/s | "
                        f"Bal: ${self.balance:>10,.2f} | "
                        f"Eq: ${equity:>10,.2f} | "
                        f"DD: ${dd:>8,.2f} ({dd_pct:>5.2f}%) | "
                        f"MaxDD: ${self.max_drawdown:>8,.2f} ({self.max_drawdown_pct:>5.2f}%) | "
                        f"Pos: {len(self.positions):>2} | "
                        f"AvgL: {self.thread_avg_level:>2}",
                        end="", flush=True
                    )

        print()  # Newline after progress
        print("-" * 80)

        # Close remaining positions at the final quote. Without any processed
        # tick there is no price to close at, so positions are left untouched.
        if self.positions and last_bid is not None:
            print(f"Closing {len(self.positions)} remaining positions...")
            for pos_id, pos in list(self.positions.items()):
                # A long exits by selling at bid, a short by buying at ask.
                exit_price = last_bid if pos.direction == 1 else last_ask
                self._close_position(pos_id, exit_price)

        run_duration = time.time() - start_time

        # Calculate final metrics
        final_equity = self.balance  # All positions closed
        profit_factor = self.gross_profit / self.gross_loss if self.gross_loss > 0 else float('inf')
        if self.initial_balance:
            total_return_pct = (self.balance / self.initial_balance - 1) * 100
        else:
            # A return relative to a zero starting balance is undefined.
            total_return_pct = 0.0

        result = FastBacktestResult(
            symbol=self.symbol,
            start_time=data_start,
            end_time=data_end,
            total_ticks=processed,
            run_duration=run_duration,
            initial_balance=self.initial_balance,
            final_balance=self.balance,
            final_equity=final_equity,
            net_profit=self.balance - self.initial_balance,
            total_return_pct=total_return_pct,
            total_trades=self.total_trades,
            winning_trades=self.winning_trades,
            losing_trades=self.losing_trades,
            win_rate=self.winning_trades / self.total_trades * 100 if self.total_trades > 0 else 0,
            gross_profit=self.gross_profit,
            gross_loss=self.gross_loss,
            profit_factor=profit_factor,
            max_drawdown=self.max_drawdown,
            max_drawdown_pct=self.max_drawdown_pct,
            total_threads=self.threads_completed,
            max_avg_level=self.max_avg_level,
            total_commission=self.total_commission,
            reentries_abandoned=self.reentries_abandoned,
            spread_filtered=self.spread_filtered_count,
            equity_curve=self.equity_curve,
            drawdown_curve=self.drawdown_curve
        )

        result.print_summary()
        result.to_json("backtest_results.json")

        # Write thread logs if logging is enabled
        if self.enable_logging and self.thread_logs:
            self.write_thread_logs("thread_logs.txt")

        return result

    def _parse_timestamp(self, ts_str: str) -> datetime:
        """Parse timestamp string."""
        import re
        match = re.match(r'(\d{8})\s+(\d{2}):(\d{2}):(\d{2}):(\d+)', str(ts_str))
        if match:
            date_part = match.group(1)
            return datetime(
                int(date_part[:4]),
                int(date_part[4:6]),
                int(date_part[6:8]),
                int(match.group(2)),
                int(match.group(3)),
                int(match.group(4)),
                int(match.group(5)) * 1000
            )
        return datetime.now()

    def _process_tick(self, timestamp: datetime, bid: float, ask: float):
        """Process a single tick - handles multiple threads and both directions.

        Order of operations: take-profit checks for all positions, per-thread
        maintenance (re-entries, averaging, hedge, stale re-entry cleanup,
        completion), opening of new threads, then display/drawdown updates.

        Args:
            timestamp: Time of the tick.
            bid: Bid price of the tick.
            ask: Ask price of the tick.
        """
        spread = ask - bid
        # While the spread exceeds max_spread, re-entries, averaging and new
        # threads are suppressed; TP checks, hedge checks and cleanup still run.
        spread_too_high = spread > self.max_spread

        # Check TP hits for all positions
        self._check_tp_hits(timestamp, bid, ask)

        # Process each active thread (iterate over a key snapshot so the dict
        # may change while threads are being processed)
        for thread_id in list(self.threads.keys()):
            thread = self.threads[thread_id]
            if not thread.active:
                continue

            # Check for re-entries in this thread
            if thread.pending_reentries and not spread_too_high:
                self._check_thread_reentries(thread, timestamp, bid, ask)

            # Check for averaging in this thread
            # Snapshot of this thread's non-hedge positions (avg_level >= 0).
            # It is taken once here and reused by the checks below, so
            # positions opened or closed later in this tick are not reflected.
            thread_positions = [p for p in self.positions.values() if p.thread_id == thread_id and p.avg_level >= 0]
            if thread_positions and not spread_too_high:
                self._check_thread_averaging(thread, timestamp, bid, ask)

            # Check hedge status
            if thread.hedge_active:
                self._check_hedge(thread, timestamp, bid, ask)

            # Cleanup stale re-entries (only once the thread holds no positions)
            if thread.pending_reentries and not thread_positions:
                self._cleanup_thread_reentries(thread, ask, timestamp)

            # Check if thread is complete: no positions and nothing pending
            if not thread_positions and not thread.pending_reentries:
                self._complete_thread(thread, timestamp, bid)

        # Check if we can open new threads (grid spacing)
        # NOTE(review): len(self.threads) counts every entry in the dict; whether
        # completed threads are removed from it is decided in _complete_thread.
        if not spread_too_high and len(self.threads) < self.max_concurrent_threads:
            # Try to open LONG thread
            if self.enable_longs:
                if self._can_open_long(ask):
                    self._open_thread(timestamp, bid, ask, spread, direction=1)

            # Try to open SHORT thread
            if self.enable_shorts:
                if self._can_open_short(bid):
                    self._open_thread(timestamp, bid, ask, spread, direction=-1)
        elif spread_too_high:
            # Counted once per tick on which the spread filter was active.
            self.spread_filtered_count += 1

        # Update max avg level for display
        self.thread_avg_level = max((t.avg_level for t in self.threads.values()), default=0)

        # Update drawdown
        equity = self._get_equity(bid, ask)
        self._update_drawdown(equity)

    def _can_open_long(self, ask: float) -> bool:
        """Check if we can open a new LONG thread based on grid spacing."""
        # Check if any LONG thread exists
        active_longs = [t for t in self.threads.values() if t.direction == 1 and t.active]

        if not active_longs:
            # No active longs - check spacing from last entry
            if self.last_long_entry == 0:
                return True
            return abs(ask - self.last_long_entry) >= self.grid_spacing

        # Check spacing from all active long entries
        for thread in active_longs:
            if abs(ask - thread.entry_price) < self.grid_spacing:
                return False
        return True

    def _can_open_short(self, bid: float) -> bool:
        """Check if we can open a new SHORT thread based on grid spacing."""
        # Check if any SHORT thread exists
        active_shorts = [t for t in self.threads.values() if t.direction == -1 and t.active]

        if not active_shorts:
            # No active shorts - check spacing from last entry
            if self.last_short_entry == 0:
                return True
            return abs(bid - self.last_short_entry) >= self.grid_spacing

        # Check spacing from all active short entries
        for thread in active_shorts:
            if abs(bid - thread.entry_price) < self.grid_spacing:
                return False
        return True

    def _check_tp_hits(self, timestamp: datetime, bid: float, ask: float):
        """Check and process TP hits for all positions."""
        for pos_id in list(self.positions.keys()):
            pos = self.positions[pos_id]
            tp_hit = False

            if pos.direction == 1:  # BUY - TP hit when bid >= tp_price
                tp_hit = bid >= pos.tp_price
                close_price = pos.tp_price if tp_hit else bid
            else:  # SELL - TP hit when ask <= tp_price
                tp_hit = ask <= pos.tp_price
                close_price = pos.tp_price if tp_hit else ask

            if tp_hit:
                # Calculate profit
                if pos.direction == 1:
                    profit = pos.lots * (pos.tp_price - pos.entry_price)
                else:
                    profit = pos.lots * (pos.entry_price - pos.tp_price)

                # Close the position
                self._close_position(pos_id, close_price)

                # Get thread for logging
                thread = self.threads.get(pos.thread_id)
                if thread and self.enable_logging and thread.log:
                    equity = self._get_equity(bid, ask)
                    event_type = "TP_REENTRY" if pos.is_reentry else "TP_HIT"
                    dir_str = "BUY" if pos.direction == 1 else "SELL"
                    thread.log.add_event(ThreadEvent(
                        timestamp=timestamp,
                        event_type=event_type,
                        level=pos.avg_level,
                        price=close_price,
                        lots=pos.lots,
                        tp_price=pos.tp_price,
                        profit=profit,
                        balance=self.balance,
                        equity=equity,
                        details=f"{dir_str} Entry=${pos.entry_price:.2f}, Profit=${profit:.2f}"
                    ))
                    thread.log.total_profit += profit
                    thread.log.total_trades += 1
                    if profit > 0:
                        thread.log.winning_trades += 1

                # TP Cascade: mark for re-entry if not already a re-entry
                if thread and thread.active and not pos.is_reentry:
                    thread.pending_reentries[pos.avg_level] = PendingReentry(
                        level=pos.avg_level,
                        entry_price=pos.entry_price,
                        lots=pos.lots,
                        tp_distance=thread.tp_dist,
                        direction=pos.direction
                    )

    def _check_thread_reentries(self, thread: TradingThread, timestamp: datetime, bid: float, ask: float):
        """Re-open pending re-entry levels of ``thread`` that price has returned to.

        A BUY re-entry triggers when ``ask`` is at or below its original entry
        price plus ``self.reentry_tolerance``; a SELL re-entry when ``bid`` is
        at or above its original entry price minus that tolerance. A triggered
        level is only re-opened when the grid index of the fill price is not
        already in ``thread.occupied_grid_levels``. Selected levels are
        processed in ascending order; each one updates the thread's lot total
        and weighted entry, opens a position flagged as a re-entry, logs a
        "REENTRY" event and is then removed from the pending map.
        """
        due_levels = []

        # Pass 1: decide which levels fire, using a snapshot of the pending map.
        for lvl, pending in list(thread.pending_reentries.items()):
            if pending.direction == 1:  # BUY - price has dropped back down
                fill_price = ask
                reached = ask <= pending.entry_price + self.reentry_tolerance
            else:  # SELL - price has risen back up
                fill_price = bid
                reached = bid >= pending.entry_price - self.reentry_tolerance

            if not reached:
                continue
            if self._get_grid_index(fill_price) not in thread.occupied_grid_levels:
                due_levels.append(lvl)

        # Pass 2: re-open the selected levels, lowest level first.
        due_levels.sort()
        for lvl in due_levels:
            pending = thread.pending_reentries[lvl]
            is_buy = pending.direction == 1
            fill_price = ask if is_buy else bid

            # Fold the new lots into the thread's running weighted average.
            thread.total_lots += pending.lots
            thread.weighted_entry += fill_price * pending.lots

            if self.individual_tp_enabled:
                # TP is measured from this position's own fill price.
                tp_base = fill_price
                tp_offset = self.individual_tp_distance
            else:
                # Unified TP is measured from the thread's average entry.
                if thread.total_lots > 0:
                    tp_base = thread.weighted_entry / thread.total_lots
                else:
                    tp_base = fill_price
                tp_offset = pending.tp_distance
            tp_price = tp_base + tp_offset if is_buy else tp_base - tp_offset

            self._open_position(
                fill_price, pending.lots, tp_price, timestamp,
                direction=pending.direction, thread_id=thread.id,
                is_avg=lvl > 0, avg_level=lvl, is_reentry=True,
            )

            if self.enable_logging and thread.log:
                equity = self._get_equity(bid, ask)
                if thread.total_lots > 0:
                    avg_entry = thread.weighted_entry / thread.total_lots
                else:
                    avg_entry = fill_price
                thread.log.add_event(ThreadEvent(
                    timestamp=timestamp,
                    event_type="REENTRY",
                    level=lvl,
                    price=fill_price,
                    lots=pending.lots,
                    tp_price=tp_price,
                    balance=self.balance,
                    equity=equity,
                    details=f"OrigEntry=${pending.entry_price:.2f}, AvgEntry=${avg_entry:.2f}"
                ))

            # The level is live again, so it is no longer pending.
            del thread.pending_reentries[lvl]

    def _check_thread_averaging(self, thread: TradingThread, timestamp: datetime, bid: float, ask: float):
        """Place the next averaging order for ``thread`` once its trigger price is reached.

        The trigger distance for level ``thread.avg_level + 1`` comes from
        ``_calc_avg_distance`` and is measured from the thread's entry price:
        below it for a BUY thread (tested against ``bid``), above it for a SELL
        thread (tested against ``ask``). The order fills at ``ask`` for BUY and
        at ``bid`` for SELL. Nothing happens when the fill price maps to a grid
        index already present in ``thread.occupied_grid_levels``.

        On a fill the thread's lot total and weighted entry are updated, the
        position is opened, a hedge is opened when hedging is enabled and the
        new level reaches ``self.hedge_trigger_level``, and an "AVG" event is
        logged.

        In unified-TP mode the recomputed TP is written to the thread's main
        positions only. Hedge positions (``avg_level < 0``) share the thread id
        but are created with ``tp_price=0`` meaning "no automatic TP";
        overwriting that value would let the TP check close the hedge at an
        unrelated price.
        """
        next_level = thread.avg_level + 1
        trigger_dist = self._calc_avg_distance(thread.entry_dist, next_level)

        if thread.direction == 1:  # BUY - average DOWN when price drops
            trigger_price = thread.entry_price - trigger_dist
            triggered = bid <= trigger_price
            entry_price = ask
        else:  # SELL - average UP when price rises
            trigger_price = thread.entry_price + trigger_dist
            triggered = ask >= trigger_price
            entry_price = bid

        if not triggered:
            return

        # Grid-based duplicate prevention
        grid_idx = self._get_grid_index(entry_price)
        if grid_idx in thread.occupied_grid_levels:
            return

        # Calculate lot size
        lots = self._calc_avg_lot(thread.base_lots, next_level)

        # Update weighted average
        thread.total_lots += lots
        thread.weighted_entry += entry_price * lots

        # Calculate TP
        if self.individual_tp_enabled:
            if thread.direction == 1:
                tp_price = entry_price + self.individual_tp_distance
            else:
                tp_price = entry_price - self.individual_tp_distance
        else:
            # Unified TP
            avg_entry = thread.weighted_entry / thread.total_lots
            if thread.direction == 1:
                tp_price = avg_entry + thread.tp_dist
            else:
                tp_price = avg_entry - thread.tp_dist

            # Move the thread's main positions to the unified TP; the hedge
            # (avg_level < 0) keeps its "no automatic TP" marker.
            for pos in self.positions.values():
                if pos.thread_id == thread.id and pos.avg_level >= 0:
                    pos.tp_price = tp_price

        # Open position
        self._open_position(entry_price, lots, tp_price, timestamp,
                          direction=thread.direction, thread_id=thread.id,
                          is_avg=True, avg_level=next_level)

        thread.avg_level = next_level
        if next_level > self.max_avg_level:
            self.max_avg_level = next_level

        # Check if we should open a hedge
        if self.hedge_enabled and not thread.hedge_active and next_level >= self.hedge_trigger_level:
            self._open_hedge(thread, timestamp, bid, ask)

        # Log
        if self.enable_logging and thread.log:
            equity = self._get_equity(bid, ask)
            drop = abs(entry_price - thread.entry_price)
            avg_entry = thread.weighted_entry / thread.total_lots
            thread.log.add_event(ThreadEvent(
                timestamp=timestamp,
                event_type="AVG",
                level=next_level,
                price=entry_price,
                lots=lots,
                tp_price=tp_price,
                balance=self.balance,
                equity=equity,
                details=f"Drop=${drop:.2f}, AvgEntry=${avg_entry:.2f}, TotalLots={thread.total_lots:.4f}"
            ))

    def _open_hedge(self, thread: TradingThread, timestamp: datetime, bid: float, ask: float):
        """Open one opposite-direction hedge position for ``thread``.

        Does nothing when the thread already has an active hedge. The hedge is
        sized as ``self.hedge_lot_percent`` percent of the thread's total lots
        and fills at ``ask`` when it is a BUY, at ``bid`` when it is a SELL.
        The position is stored with ``tp_price=0`` and ``avg_level=-1`` so it
        can be told apart from the thread's main positions. Commission is
        charged to the balance, the hedge is recorded on the thread, and a
        "HEDGE_OPEN" event is logged when logging is enabled.
        """
        if thread.hedge_active:
            return  # one hedge per thread at a time

        side = -thread.direction  # opposite of the thread being hedged
        size = thread.total_lots * (self.hedge_lot_percent / 100)
        fill_price = ask if side == 1 else bid  # BUY fills at ask, SELL at bid

        # Register the position directly; tp_price=0 means no automatic TP.
        self._pos_counter += 1
        hedge = Position(
            id=self._pos_counter,
            entry_price=fill_price,
            lots=size,
            tp_price=0,
            entry_time=timestamp,
            direction=side,
            is_averaging=False,
            avg_level=-1,  # hedge marker
            is_reentry=False,
            thread_id=thread.id
        )
        self.positions[hedge.id] = hedge

        # Charge the opening commission.
        fee = self._calc_commission(size, fill_price)
        self.balance -= fee
        self.total_commission += fee

        # Remember the hedge on the thread.
        thread.hedge_active = True
        thread.hedge_position_id = hedge.id
        thread.hedge_entry_price = fill_price
        thread.hedge_lots = size

        if not (self.enable_logging and thread.log):
            return

        equity = self._get_equity(bid, ask)
        dir_str = "BUY" if side == 1 else "SELL"
        thread.log.add_event(ThreadEvent(
            timestamp=timestamp,
            event_type="HEDGE_OPEN",
            level=thread.avg_level,
            price=fill_price,
            lots=size,
            tp_price=0,
            balance=self.balance,
            equity=equity,
            details=f"{dir_str} Hedge, TotalMainLots={thread.total_lots:.4f}"
        ))

    def _check_hedge(self, thread: TradingThread, timestamp: datetime, bid: float, ask: float):
        """Close the thread's hedge when it is profitable enough or no longer needed.

        If the thread has no active hedge, or the recorded hedge id is not in
        ``self.positions``, the hedge flag is cleared and nothing else happens.
        Otherwise the hedge is closed when its floating P&L (BUY valued at
        ``bid``, SELL valued at ``ask``) reaches ``self.hedge_close_profit``,
        or when the thread has no remaining main positions
        (``avg_level >= 0``).
        """
        if not (thread.hedge_active and thread.hedge_position_id in self.positions):
            thread.hedge_active = False
            return

        hedge = self.positions[thread.hedge_position_id]

        # Floating P&L at the price the hedge would be closed at.
        if hedge.direction == 1:
            hedge_pnl = hedge.lots * (bid - hedge.entry_price)
        else:
            hedge_pnl = hedge.lots * (hedge.entry_price - ask)

        if hedge_pnl >= self.hedge_close_profit:
            self._close_hedge(thread, timestamp, bid, ask, f"HedgeTP=${hedge_pnl:.2f}")
            return

        # A hedge with nothing left to hedge is closed as well.
        has_main_positions = any(
            p.thread_id == thread.id and p.avg_level >= 0
            for p in self.positions.values()
        )
        if not has_main_positions:
            self._close_hedge(thread, timestamp, bid, ask, "MainPositionsTP")

    def _close_hedge(self, thread: TradingThread, timestamp: datetime, bid: float, ask: float, reason: str = ""):
        """Close the thread's hedge position and reset its hedge bookkeeping.

        If the thread has no active hedge, or the recorded hedge id is not in
        ``self.positions``, only the hedge flag is cleared. Otherwise the hedge
        is closed at ``bid`` (BUY) or ``ask`` (SELL) via ``_close_position``,
        the hedge fields on the thread are zeroed, and a "HEDGE_CLOSE" event
        carrying ``reason`` is logged when logging is enabled.
        """
        if not (thread.hedge_active and thread.hedge_position_id in self.positions):
            thread.hedge_active = False
            return

        hedge = self.positions[thread.hedge_position_id]
        # BUY hedges exit at bid, SELL hedges exit at ask.
        exit_price = bid if hedge.direction == 1 else ask

        # The value returned by _close_position is used as the realised profit.
        profit = self._close_position(thread.hedge_position_id, exit_price)

        # Reset hedge tracking on the thread.
        thread.hedge_active = False
        thread.hedge_position_id = 0
        thread.hedge_entry_price = 0
        thread.hedge_lots = 0

        if not (self.enable_logging and thread.log):
            return

        equity = self._get_equity(bid, ask)
        thread.log.add_event(ThreadEvent(
            timestamp=timestamp,
            event_type="HEDGE_CLOSE",
            level=thread.avg_level,
            price=exit_price,
            lots=hedge.lots,
            tp_price=0,
            profit=profit,
            balance=self.balance,
            equity=equity,
            details=f"Reason={reason}, P&L=${profit:.2f}"
        ))
        thread.log.total_profit += profit
        thread.log.total_trades += 1
        if profit > 0:
            thread.log.winning_trades += 1

    def _cleanup_thread_reentries(self, thread: TradingThread, current_price: float, timestamp: datetime):
        """Drop pending re-entries whose original entry is too far from ``current_price``.

        A pending re-entry is abandoned when the absolute distance between
        ``current_price`` and its original entry price exceeds
        ``self.reentry_abandon_distance``. Each abandoned level is logged as an
        "ABANDON" event (when logging is enabled), removed from the thread's
        pending map, and counted in ``self.reentries_abandoned``.
        """
        # Measure every pending level first so the map is not mutated while
        # it is being iterated.
        gaps = {
            lvl: abs(current_price - pending.entry_price)
            for lvl, pending in thread.pending_reentries.items()
        }

        for lvl, distance in gaps.items():
            if not distance > self.reentry_abandon_distance:
                continue

            pending = thread.pending_reentries[lvl]
            if self.enable_logging and thread.log:
                thread.log.add_event(ThreadEvent(
                    timestamp=timestamp,
                    event_type="ABANDON",
                    level=lvl,
                    price=current_price,
                    lots=pending.lots,
                    tp_price=0,
                    balance=self.balance,
                    equity=self._get_equity(current_price),
                    details=f"Distance=${distance:.2f}, OrigEntry=${pending.entry_price:.2f}"
                ))

            del thread.pending_reentries[lvl]
            self.reentries_abandoned += 1

    def _complete_thread(self, thread: TradingThread, timestamp: datetime, bid: float):
        """Finish ``thread``: close its hedge, finalise its log and remove it.

        Any active hedge is closed first. The thread is then marked inactive
        and ``self.threads_completed`` is incremented. When logging is enabled
        a closing event is appended ("THREAD_ABANDONED" if one of the last
        five logged events is an "ABANDON", otherwise "THREAD_COMPLETE") and
        the log is moved to ``self.thread_logs``. Finally the thread is deleted
        from ``self.threads``.
        """
        if thread.hedge_active:
            # Only the bid is available here, so it is passed for both prices.
            self._close_hedge(thread, timestamp, bid, bid, "ThreadComplete")

        thread.active = False
        self.threads_completed += 1

        if self.enable_logging and thread.log:
            log = thread.log
            log.end_time = timestamp
            equity = self._get_equity(bid)

            # A recent ABANDON event marks the thread as abandoned.
            recently_abandoned = any(
                evt.event_type == "ABANDON" for evt in log.events[-5:]
            )
            event_type = "THREAD_ABANDONED" if recently_abandoned else "THREAD_COMPLETE"

            log.add_event(ThreadEvent(
                timestamp=timestamp,
                event_type=event_type,
                level=log.max_level,
                price=bid,
                lots=0,
                tp_price=0,
                profit=log.total_profit,
                balance=self.balance,
                equity=equity,
                details=f"TotalProfit=${log.total_profit:.2f}"
            ))
            self.thread_logs.append(log)

        # The thread is finished; drop it from the live map.
        del self.threads[thread.id]

    def _open_thread(self, timestamp: datetime, bid: float, ask: float, spread: float, direction: int = 1):
        """Start a new trading thread and open its first position.

        ``direction`` 1 opens a BUY thread filled at ``ask``; any other value
        opens a SELL thread filled at ``bid``. The fill price is remembered in
        ``self.last_long_entry`` / ``self.last_short_entry``. Both the TP
        distance and the averaging entry distance of the thread are
        ``spread * self.base_multiplier``; the first position's TP uses
        ``self.individual_tp_distance`` instead when individual TP is enabled.
        When logging is enabled a ``ThreadLog`` with an "OPEN" event is
        attached to the thread.
        """
        self._thread_counter += 1
        new_id = self._thread_counter

        scaled_spread = spread * self.base_multiplier
        tp_dist = scaled_spread
        entry_dist = scaled_spread

        is_buy = direction == 1
        if is_buy:
            entry_price = ask
            self.last_long_entry = entry_price
        else:
            entry_price = bid
            self.last_short_entry = entry_price

        lots = self._calc_lot_size(entry_price)

        # Individual mode uses the fixed distance, otherwise the spread-based one.
        tp_offset = self.individual_tp_distance if self.individual_tp_enabled else tp_dist
        tp_price = entry_price + tp_offset if is_buy else entry_price - tp_offset

        thread = TradingThread(
            id=new_id,
            direction=direction,
            entry_price=entry_price,
            entry_time=timestamp,
            base_lots=lots,
            tp_dist=tp_dist,
            entry_dist=entry_dist,
            total_lots=lots,
            weighted_entry=entry_price * lots
        )

        if self.enable_logging:
            dir_str = "BUY" if is_buy else "SELL"
            thread.log = ThreadLog(
                thread_id=new_id,
                start_time=timestamp,
                entry_price=entry_price
            )
            # Equity is sampled before the first position is opened.
            equity = self._get_equity(bid, ask)
            thread.log.add_event(ThreadEvent(
                timestamp=timestamp,
                event_type="OPEN",
                level=0,
                price=entry_price,
                lots=lots,
                tp_price=tp_price,
                balance=self.balance,
                equity=equity,
                details=f"{dir_str} Spread=${spread:.2f}, TP_dist=${tp_dist:.2f}"
            ))

        self.threads[new_id] = thread

        # Level-0 position of the new thread.
        self._open_position(
            entry_price, lots, tp_price, timestamp,
            direction=direction, thread_id=new_id,
            is_avg=False, avg_level=0,
        )

def run_fast_backtest(
    csv_file: str = "BTCUSD.csv",
    max_rows: Optional[int] = None
) -> FastBacktestResult:
    """Build a ``FastBacktestEngine`` from the ``settings`` module and run it.

    Args:
        csv_file: Path of the tick CSV handed to ``engine.run_from_csv``.
        max_rows: Optional row limit forwarded to ``engine.run_from_csv``.

    Returns:
        Whatever ``engine.run_from_csv`` returns for the given file.
    """
    from settings import (
        SYMBOL, INITIAL_BALANCE, COMMISSION_PERCENT,
        RISK_PERCENT, BASE_MULTIPLIER, AGGRESSION_LEVEL,
        CUSTOM_LEVELS, CUSTOM_MULTIPLIERS, MAX_SPREAD,
        USE_CUMULATIVE_MULTIPLIERS,
        ENABLE_LONGS, ENABLE_SHORTS, GRID_SPACING,
        MAX_CONCURRENT_THREADS, INDIVIDUAL_TP_ENABLED,
        INDIVIDUAL_TP_DISTANCE,
        HEDGE_ENABLED, HEDGE_TRIGGER_LEVEL, HEDGE_CLOSE_PROFIT, HEDGE_LOT_PERCENT
    )

    engine = FastBacktestEngine(
        symbol=SYMBOL,
        initial_balance=INITIAL_BALANCE,
        commission_pct=COMMISSION_PERCENT,
        risk_pct=RISK_PERCENT,
        base_multiplier=BASE_MULTIPLIER,
        aggression=AGGRESSION_LEVEL,
        report_interval=100000
    )

    # Custom martingale tables replace the engine's own only when configured.
    if CUSTOM_LEVELS is not None:
        engine.custom_levels = CUSTOM_LEVELS
    if CUSTOM_MULTIPLIERS is not None:
        engine.custom_multipliers = CUSTOM_MULTIPLIERS

    # Remaining settings are copied onto the engine unconditionally.
    overrides = (
        ("max_spread", MAX_SPREAD),
        ("use_cumulative_multipliers", USE_CUMULATIVE_MULTIPLIERS),
        # Bidirectional trading
        ("enable_longs", ENABLE_LONGS),
        ("enable_shorts", ENABLE_SHORTS),
        ("grid_spacing", GRID_SPACING),
        ("max_concurrent_threads", MAX_CONCURRENT_THREADS),
        # Individual TP
        ("individual_tp_enabled", INDIVIDUAL_TP_ENABLED),
        ("individual_tp_distance", INDIVIDUAL_TP_DISTANCE),
        # Hedge
        ("hedge_enabled", HEDGE_ENABLED),
        ("hedge_trigger_level", HEDGE_TRIGGER_LEVEL),
        ("hedge_close_profit", HEDGE_CLOSE_PROFIT),
        ("hedge_lot_percent", HEDGE_LOT_PERCENT),
    )
    for attr_name, value in overrides:
        setattr(engine, attr_name, value)

    return engine.run_from_csv(csv_file, max_rows=max_rows)


if __name__ == "__main__":
    import sys
    from settings import TICK_DATA_FILE

    # Usage: <script> [csv_file] [max_rows]
    # Without arguments the configured tick file is processed in full.
    cli_args = sys.argv[1:]
    csv_file = cli_args[0] if cli_args else TICK_DATA_FILE
    max_rows = int(cli_args[1]) if len(cli_args) > 1 else None

    result = run_fast_backtest(csv_file, max_rows)
