"""
Hedge-Mart Recovery Engine.

Port of Auto Hedge-Mart-V4.mq4 strategy to Python backtesting framework.

Architecture mapping from MQ4:
  Ors struct       → SubOrderPair (additional buy+sell pair at same level)
  Order struct     → RecoveryOrderPair (main hedged buy+sell at a recovery level)
  Odr class        → RecoveryThread (one recovery grid sequence)
  OnTick entries   → HedgeMartStrategy (entry signals + thread management)
"""

from dataclasses import dataclass, field
from typing import List, Optional
from collections import deque

from normalization import normalize_lot_size, get_pip_size
from order_engine import BacktestOrderEngine, OrderSide
from data_loader import OHLCV
from utils import get_logger, log_trade

from settings import (
    SYMBOL,
    LOT_SIZE, MAX_INITIAL_ORDERS,
    ENTRY_TP_PIPS, ENTRY_SL_PIPS,
    ENABLE_REVERSAL_ENTRIES, ENABLE_CONTINUATION_ENTRIES,
    FIBO_BARS_BACK, FIBO_BUY_ZONE, FIBO_SELL_ZONE, FIBO_RANGE_ZONE,
    RECOVERY_FIBO_ZONE,
    RECOVERY_GRID_STEPS, RECOVERY_TPS,
    TP_MULTIPLIER, TP_MULT_START_LEVEL, TP_MULT_STOP_LEVEL,
    AUTO_CALCULATE_LOTS, RECOVERY_PROFIT_TARGET, RECOVERY_LOT_MULTIPLIER,
    OUTSIDE_LOT_MULTIPLIER, INSIDE_LOT_MULTIPLIER,
    OUTSIDE_LOT_START, OUTSIDE_LOT_STOP,
    INSIDE_LOT_START, INSIDE_LOT_STOP,
    REPEAT_BOTH_SIDES, REPEAT_SINGLE_SIDES,
    MAX_LEVERAGE, MARGIN_CAP_ENABLED, MARGIN_RESERVE_PCT,
    PROGRESSIVE_RLM_ENABLED, PROGRESSIVE_RLM_START,
    PROGRESSIVE_RLM_RAMP, PROGRESSIVE_RLM_MAX,
    MARGIN_RELIEF_ENABLED, MARGIN_RELIEF_LEVEL, MARGIN_RELIEF_CLOSE_SUBS,
    MIN_ENTRY_SPACING,
    ANTI_MART_ENABLED, ANTI_MART_LOT_PCT, ANTI_MART_TP_PIPS,
    ANTI_MART_REOPEN, ANTI_MART_MAX_LEVELS, ANTI_MART_LOT_MULT,
    DYNAMIC_MO_ENABLED, DYNAMIC_MO_LOW_THRESH, DYNAMIC_MO_HIGH_THRESH, DYNAMIC_MO_MIN,
    INVENTORY_BIAS_ENABLED, INVENTORY_BIAS_THRESHOLD,
    BAR_REVERSAL_MODE, THREAD_PROFIT_ENABLED, THREAD_PROFIT_TARGET,
    ADAPTIVE_GRID, ADAPTIVE_ATR_BARS, REFERENCE_ATR,
    ADAPTIVE_MIN_SCALE, ADAPTIVE_MAX_SCALE,
)


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class SubOrderPair:
    """Extra buy+sell pair opened alongside a recovery level (MQ4 ``Ors``).

    A recovery level can carry several of these, one per additional recovery
    TP (recoveryTP2, TP3, TP4). Each pair has its own TP distance and its own
    repeat bookkeeping, independent of the level's main pair.
    """
    # Engine order ids of the two legs (0 when the leg was not placed).
    buy_order_id: int = 0
    sell_order_id: int = 0
    # TP distance in pips, and the reference price used by the repeat logic.
    tp_pips: float = 0.0
    level_price: float = 0.0
    # Per-leg "take-profit already taken" flags.
    buy_tp_hit: bool = False
    sell_tp_hit: bool = False
    # Once set, the pair is no longer TP-checked or repeated.
    is_closed: bool = False
    # Which leg took profit first: 0 = buy, 1 = sell.
    last_direction: int = 0
    # Lot sizes reused whenever a leg is re-opened.
    buy_lots: float = 0.0
    sell_lots: float = 0.0


@dataclass
class AntiMartHedge:
    """Small counter-trend "grinder" order attached to a recovery grid level.

    Ported from AntiMartingaleEA.mq4: a light order opposite to the thread's
    direction with a tight TP. When it takes profit it can be re-opened at the
    current price, harvesting small gains that offset recovery costs.
    """
    # Grid level this hedge belongs to, and its engine order id.
    level: int = 0
    order_id: int = 0
    # 0 = buy, 1 = sell (opposite to the owning thread's direction).
    direction: int = 0
    # Position size, fill price and TP distance in pips.
    lots: float = 0.0
    entry_price: float = 0.0
    tp_pips: float = 0.0
    # Live flag, number of times re-opened, and accumulated realized profit.
    is_active: bool = True
    reopen_count: int = 0
    total_profit: float = 0.0


@dataclass
class RecoveryOrderPair:
    """Main hedged buy+sell opened when a recovery grid level is crossed.

    Maps to the MQ4 ``Order`` struct. Both legs share one TP distance. The
    pair also remembers the previous level's tickets so that taking profit
    here can cascade-close them.
    """
    # Engine order ids of the two main legs.
    buy_order_id: int = 0
    sell_order_id: int = 0
    # Shared TP distance (pips), fill prices and lot sizes of each leg.
    tp_pips: float = 0.0
    buy_entry: float = 0.0
    sell_entry: float = 0.0
    buy_lots: float = 0.0
    sell_lots: float = 0.0
    # Cascade-close targets: the previous level's tickets (or the initial order).
    buy_close_id: int = 0
    sell_close_id: int = 0
    # Sub-order pairs opened at this level.
    sub_pairs: List[SubOrderPair] = field(default_factory=list)
    # The previous level's sub-pairs, closed together with the cascade.
    old_sub_pairs: List[SubOrderPair] = field(default_factory=list)


# =============================================================================
# RECOVERY THREAD (maps to MQ4 Odr class)
# =============================================================================

class RecoveryThread:
    """One recovery grid sequence tracking an initial trade.

    When an initial trade is opened, a RecoveryThread is created. It builds
    a grid of price levels above and below the entry. When price crosses
    a level, it opens hedged buy+sell pairs. Each pair's lots are sized to
    cover the previous level's loss plus a fixed profit target.

    The grid levels are CUMULATIVE from entry using the step array:
      Level  1 = entry ± step[0] pips
      Level  2 = Level 1 ± step[1] pips
      Level  3 = Level 2 ± step[2] pips
      ...
    """

    def __init__(
        self,
        initial_order_id: int,
        entry_price: float,
        lot_size: float,
        direction: int,
        magic: int,
        recovery_tp_pips: float,
        symbol: str,
        pip_size: float,
        grid_steps: Optional[List[int]] = None,
        recovery_tps: Optional[List[int]] = None,
        entry_tp_pips: Optional[int] = None,
        vol_scale: float = 1.0,
    ):
        """Create a recovery thread anchored at an initial trade.

        Args:
            initial_order_id: Engine id of the trade this thread recovers.
            entry_price: Anchor price; grid levels are built outward from it.
            lot_size: Base lot size (also the floor for recovery lots).
            direction: 0 = buy, 1 = sell (direction of the initial trade).
            magic: Magic number stamped on every order of this thread.
            recovery_tp_pips: Starting TP distance for recovery pairs.
            symbol: Trading symbol, used for lot normalization.
            pip_size: Price value of one pip.
            grid_steps: Per-level step sizes in pips; defaults to
                RECOVERY_GRID_STEPS.
            recovery_tps: Per-thread recovery TPs; defaults to RECOVERY_TPS.
                Entries [1:] drive sub-pair creation.
            entry_tp_pips: Per-thread entry TP; defaults to ENTRY_TP_PIPS.
            vol_scale: Volatility scale locked in at creation.

        Both list arguments are copied, so the thread never aliases the
        caller's list or the module-level default.
        """
        self.initial_order_id = initial_order_id
        self.entry_price = entry_price
        self.lot_size = lot_size
        self.direction = direction  # 0=buy, 1=sell
        self.magic = magic
        self.take_profit_pips = recovery_tp_pips
        self.symbol = symbol
        self.pip_size = pip_size
        self.is_active = True
        self.vol_scale = vol_scale  # locked-in volatility scale at creation

        # Per-thread recovery TPs (adaptive or global), always a private copy.
        self.recovery_tps = list(recovery_tps if recovery_tps is not None else RECOVERY_TPS)
        # Per-thread entry TP (adaptive or global)
        self.entry_tp_pips = entry_tp_pips if entry_tp_pips is not None else ENTRY_TP_PIPS

        self.current_level = 0   # current level ID (positive=up, negative=down)
        self.level_count = 0     # total levels crossed (for TP multiplier logic)
        self.hold = False        # bar reversal hold state
        self.thread_mode = False # thread profit active after bar reversal close

        # Anti-martingale hedge grinders (ported from AntiMartingaleEA.mq4)
        self.anti_mart_hedges: List[AntiMartHedge] = []

        # Build grid levels (both up and down from entry). Levels are
        # cumulative: level k sits sum(step[0..k-1]) pips away from entry.
        if grid_steps is None:
            grid_steps = RECOVERY_GRID_STEPS
        self.grid_steps = list(grid_steps)  # private copy; used for max_iterations
        self.levels = {}  # id -> price
        p_up = entry_price
        p_dn = entry_price
        for i, step in enumerate(self.grid_steps):
            step_price = step * pip_size
            p_up += step_price
            p_dn -= step_price
            self.levels[i + 1] = round(p_up, 8)
            self.levels[-(i + 1)] = round(p_dn, 8)

        # Recovery order pairs at each level
        self.order_pairs: List[RecoveryOrderPair] = []

        self.logger = get_logger()

    def get_next_levels(self):
        """Get the next up and down level IDs and prices.

        Navigates the grid: from current_level, next up = current+1,
        next down = current-1, skipping level 0 (entry point).
        """
        if self.current_level == 0:
            up_id = 1
            dn_id = -1
        else:
            up_id = self.current_level + 1
            dn_id = self.current_level - 1
            # Skip level 0 (entry point)
            if up_id == 0:
                up_id = 1
            if dn_id == 0:
                dn_id = -1

        up_price = self.levels.get(up_id, 0.0)
        dn_price = self.levels.get(dn_id, 0.0)
        return up_id, up_price, dn_id, dn_price

    def get_lots(self, order_type: int, in_or_out: int, order_engine: BacktestOrderEngine) -> float:
        """Calculate lot size for a recovery order.

        MQ4 GetLots() port. Two modes:
        - autocalculatelots=true: multiplier-based (prev lots × multiplier)
        - autocalculatelots=false: loss-recovery (cover loss + fixed profit)

        Cross-reference: getLotSize(type) in MQ4 uses type==1→buyTicket, else→sellTicket.
        So when calculating BUY lots (type=0), it references the previous SELL lots,
        and vice versa. This is intentional for hedging.

        Args:
            order_type: 0=buy, 1=sell
            in_or_out: 0=inside (same dir as break), 1=outside (opposite)
            order_engine: For price/position lookups

        Returns:
            The normalized lot size, never below ``self.lot_size``. Returns 0.0
            when the margin cap is enabled and that final size does not fit;
            callers treat 0.0 as "skip this level".
        """
        lots = self._base_recovery_lots(order_type, in_or_out, order_engine)
        lots = self._apply_recovery_lot_multiplier(lots)
        # Clamp and normalize BEFORE the margin check, so the check sees the
        # size that will actually be sent to the engine.
        lots = normalize_lot_size(max(lots, self.lot_size), self.symbol)
        if self._exceeds_margin_cap(lots, order_engine):
            return 0.0
        return lots

    def _base_recovery_lots(self, order_type: int, in_or_out: int,
                            order_engine: BacktestOrderEngine) -> float:
        """Raw lot size before multipliers, clamping and the margin check."""
        prev_pair = self.order_pairs[-1] if self.order_pairs else None

        if AUTO_CALCULATE_LOTS:
            # Cross-referenced lots of the previous pair, or the base lot at the first level.
            base = (self._get_cross_lots(prev_pair, order_type)
                    if prev_pair is not None else self.lot_size)
            if in_or_out == 1 and OUTSIDE_LOT_START <= self.level_count < OUTSIDE_LOT_STOP:
                return base * OUTSIDE_LOT_MULTIPLIER
            if in_or_out == 0 and INSIDE_LOT_START <= self.level_count < INSIDE_LOT_STOP:
                return base * INSIDE_LOT_MULTIPLIER
            return base

        # Loss-recovery mode: size so a TP move covers the projected loss plus a target.
        tp_price_dist = self.take_profit_pips * self.pip_size
        if prev_pair is not None:
            loss = self._get_loss_at_tp(prev_pair, order_type, tp_price_dist, order_engine)
        elif order_type == self.direction:
            # First level, same side as the initial trade: plain base lot.
            return self.lot_size
        else:
            loss = self._get_initial_loss_at_tp(tp_price_dist, order_engine)
        if tp_price_dist <= 0:
            return self.lot_size
        return (loss + RECOVERY_PROFIT_TARGET) / tp_price_dist

    def _apply_recovery_lot_multiplier(self, lots: float) -> float:
        """Scale loss-recovery lots above the base lot by the progressive or flat multiplier."""
        if AUTO_CALCULATE_LOTS or lots <= self.lot_size:
            return lots
        if PROGRESSIVE_RLM_ENABLED:
            depth = self.level_count
            if depth < PROGRESSIVE_RLM_START:
                return lots  # shallow levels stay at 1.0x
            if PROGRESSIVE_RLM_RAMP:
                # Ramp: 1.0 at start, +0.1 per level, capped at PROGRESSIVE_RLM_MAX
                ramp_levels = depth - PROGRESSIVE_RLM_START
                rlm = min(1.0 + ramp_levels * 0.1, PROGRESSIVE_RLM_MAX)
            else:
                rlm = PROGRESSIVE_RLM_MAX
            return lots * rlm
        if RECOVERY_LOT_MULTIPLIER != 1.0:
            # Legacy: flat multiplier at all levels
            return lots * RECOVERY_LOT_MULTIPLIER
        return lots

    def _exceeds_margin_cap(self, lots: float, order_engine: BacktestOrderEngine) -> bool:
        """True when the margin cap is on and ``lots`` would not fit in free margin."""
        if not MARGIN_CAP_ENABLED or order_engine is None:
            return False
        # Non-positive equity leaves no usable margin. Previously this case
        # skipped the cap entirely, letting a blown account keep adding orders.
        usable_equity = max(0.0, order_engine.get_equity()) * (1.0 - MARGIN_RESERVE_PCT)
        margin_used = order_engine.get_margin_used()
        available_margin = max(0, usable_equity * MAX_LEVERAGE - margin_used)
        current_price = order_engine._current_bid if hasattr(order_engine, '_current_bid') else 100000
        if current_price <= 0:
            return False
        return lots * current_price > available_margin

    def _get_cross_lots(self, pair: RecoveryOrderPair, order_type: int) -> float:
        """Get lot size from previous pair with MQ4 cross-reference.

        MQ4 getLotSize(type): type==1 → buyTicket lots, else → sellTicket lots.
        So for calculating BUY (type=0) → returns SELL lots from previous.
        For calculating SELL (type=1) → returns BUY lots from previous.
        """
        if order_type == 1:
            return pair.buy_lots
        else:
            return pair.sell_lots

    def _get_loss_at_tp(self, pair: RecoveryOrderPair, order_type: int,
                        tp_dist: float, order_engine: BacktestOrderEngine) -> float:
        """Calculate projected loss of opposite side at TP price."""
        bid, ask = order_engine.get_current_price()
        if order_type == 0:  # calculating for buy
            tp_target = ask + tp_dist
            pos = order_engine.positions.get(pair.sell_order_id)
            if pos:
                return pos.lots * abs(pos.entry_price - tp_target)
        else:  # calculating for sell
            tp_target = bid - tp_dist
            pos = order_engine.positions.get(pair.buy_order_id)
            if pos:
                return pos.lots * abs(pos.entry_price - tp_target)
        return 0.0

    def _get_initial_loss_at_tp(self, tp_dist: float, order_engine: BacktestOrderEngine) -> float:
        """Calculate loss from the initial order at TP distance."""
        pos = order_engine.positions.get(self.initial_order_id)
        if pos:
            bid, ask = order_engine.get_current_price()
            if self.direction == 0:
                tp_target = ask + tp_dist
            else:
                tp_target = bid - tp_dist
            return pos.lots * abs(pos.entry_price - tp_target)
        return 0.0

    def add_order(self, break_direction: int, order_engine: BacktestOrderEngine,
                  fibo_value: float = 0.5):
        """Open hedged buy+sell pair at current level.

        MQ4 addOrder() port. Opens main pair + sub-order pairs (Ors).
        MQ4 line 906-920: sub-pairs only created when Fibo condition met.

        Sequence of effects:
          1. If level_count is in [TP_MULT_START_LEVEL, TP_MULT_STOP_LEVEL),
             multiply self.take_profit_pips by TP_MULTIPLIER (compounds on
             every call in that range).
          2. Size both legs with get_lots(). If either comes back 0.0, return
             without placing anything.
          3. Place the main BUY, then the main SELL.
          4. For each entry of self.recovery_tps[1:] that is > 0, and only when
             fibo_value lies inside RECOVERY_FIBO_ZONE, place an extra BUY+SELL
             sub-pair with the same lot sizes as the main pair.
          5. Link the new pair to the previous pair's main tickets (or to the
             initial order for the first pair) for cascade closing. Flag the
             previous pair's sub-pairs as is_closed.
          6. Append the pair, optionally open an anti-martingale hedge, and
             log the trade.

        Args:
            break_direction: 0=price went UP, 1=price went DOWN
            order_engine: For placing orders
            fibo_value: Current Fibo zone value for sub-pair gate
        """
        # Apply TP multiplier for levels in range (cumulative!)
        # NOTE(review): this runs before the margin-cap early return below, so a
        # level that ends up skipped still compounds take_profit_pips — confirm
        # that is intended.
        if (self.level_count >= TP_MULT_START_LEVEL
                and self.level_count < TP_MULT_STOP_LEVEL):
            self.take_profit_pips *= TP_MULTIPLIER

        bid, ask = order_engine.get_current_price()

        # Determine inside/outside for lot calculation
        # Inside = same direction as break, Outside = opposite
        buy_in_out = 0 if break_direction == 0 else 1
        sell_in_out = 0 if break_direction == 1 else 1

        # In loss-recovery mode get_lots() reads self.take_profit_pips, so the
        # multiplier above is already reflected in these sizes.
        buy_lots = self.get_lots(0, buy_in_out, order_engine)
        sell_lots = self.get_lots(1, sell_in_out, order_engine)

        # Skip if margin cap blocked the order
        if buy_lots == 0.0 or sell_lots == 0.0:
            return

        # Place hedged pair
        buy_order = order_engine.place_buy_order(
            lots=buy_lots, magic=self.magic,
            comment=f"REC_L{self.level_count}_BUY"
        )
        sell_order = order_engine.place_sell_order(
            lots=sell_lots, magic=self.magic,
            comment=f"REC_L{self.level_count}_SELL"
        )

        # A leg the engine refused is recorded with id 0 and the quoted price.
        pair = RecoveryOrderPair(
            buy_order_id=buy_order.order_id if buy_order else 0,
            sell_order_id=sell_order.order_id if sell_order else 0,
            tp_pips=self.take_profit_pips,
            buy_entry=buy_order.price if buy_order else ask,
            sell_entry=sell_order.price if sell_order else bid,
            buy_lots=buy_lots,
            sell_lots=sell_lots,
        )

        # Add sub-order pairs (Ors) for recoveryTP2, TP3, TP4
        # MQ4 line 906-920: sub-pairs gated by Fibo condition
        fibo_min, fibo_max = RECOVERY_FIBO_ZONE
        sub_fibo_ok = fibo_min <= fibo_value <= fibo_max
        # Sub-pairs remember the current bid as the reference for their repeat logic.
        level_price = bid
        for i, sub_tp in enumerate(self.recovery_tps[1:], 1):
            if sub_tp > 0 and sub_fibo_ok:
                sub_buy = order_engine.place_buy_order(
                    lots=buy_lots, magic=self.magic,
                    comment=f"REC_L{self.level_count}_S{i}_BUY"
                )
                sub_sell = order_engine.place_sell_order(
                    lots=sell_lots, magic=self.magic,
                    comment=f"REC_L{self.level_count}_S{i}_SELL"
                )
                pair.sub_pairs.append(SubOrderPair(
                    buy_order_id=sub_buy.order_id if sub_buy else 0,
                    sell_order_id=sub_sell.order_id if sub_sell else 0,
                    tp_pips=sub_tp,
                    level_price=level_price,
                    buy_lots=buy_lots,
                    sell_lots=sell_lots,
                ))

        # Link cascade closing to previous level
        s = len(self.order_pairs)
        if s > 0:
            prev = self.order_pairs[-1]
            pair.buy_close_id = prev.buy_order_id
            pair.sell_close_id = prev.sell_order_id
            pair.old_sub_pairs = list(prev.sub_pairs)
            # Superseded sub-pairs stop being TP-checked/repeated; their
            # positions stay open until the cascade close on this pair.
            for sp in prev.sub_pairs:
                sp.is_closed = True
        else:
            # First level: both cascade targets point at the initial trade.
            pair.buy_close_id = self.initial_order_id
            pair.sell_close_id = self.initial_order_id

        self.order_pairs.append(pair)

        # Open anti-martingale hedge grinder at this level
        if ANTI_MART_ENABLED and len(self.anti_mart_hedges) < ANTI_MART_MAX_LEVELS:
            self._open_anti_mart_hedge(self.level_count, order_engine, bid, ask)

        # Only the buy leg's price and lots are logged.
        log_trade("HEDGE_PAIR", self.symbol,
                  buy_order.price if buy_order else 0,
                  buy_lots, self.magic)

    # ---- Anti-Martingale Hedge Grinder ----

    def _open_anti_mart_hedge(self, level: int, order_engine: BacktestOrderEngine,
                              bid: float, ask: float, reopen_count: int = 0):
        """Place one anti-martingale hedge order against the thread's direction.

        Port of AntiMartingaleEA.mq4 OpenHedgeOrder(). The size is
        ``LOT_SIZE * ANTI_MART_LOT_PCT * ANTI_MART_LOT_MULT ** max(0, level - 1)``,
        normalized for the symbol. A BUY thread gets a SELL hedge and vice
        versa. When the engine returns an order, an AntiMartHedge record (TP
        distance ANTI_MART_TP_PIPS) is appended to ``self.anti_mart_hedges``.

        ``bid`` and ``ask`` are not read here; the fill price comes from the
        returned order.
        """
        growth = ANTI_MART_LOT_MULT ** max(0, level - 1)
        hedge_lots = normalize_lot_size(LOT_SIZE * ANTI_MART_LOT_PCT * growth, self.symbol)

        thread_is_buy = self.direction == 0
        if thread_is_buy:
            placed = order_engine.place_sell_order(
                lots=hedge_lots, magic=self.magic,
                comment=f"AMT_L{level}_SELL_R{reopen_count}",
            )
        else:
            placed = order_engine.place_buy_order(
                lots=hedge_lots, magic=self.magic,
                comment=f"AMT_L{level}_BUY_R{reopen_count}",
            )

        if not placed:
            return

        self.anti_mart_hedges.append(AntiMartHedge(
            level=level,
            order_id=placed.order_id,
            direction=1 if thread_is_buy else 0,
            lots=hedge_lots,
            entry_price=placed.price,
            tp_pips=ANTI_MART_TP_PIPS,
            is_active=True,
            reopen_count=reopen_count,
        ))

    def check_anti_mart_hedges(self, high: float, low: float,
                               order_engine: BacktestOrderEngine):
        """Check anti-mart hedge TPs on this bar and reopen if configured.

        Ported from AntiMartingaleEA.mq4 ProcessTPHits() + HandleHedgeTP().

        A hedge whose position is gone from the engine is marked inactive.
        A BUY hedge hits TP when the bar high reaches ``entry + tp``; a SELL
        hedge hits when the bar low reaches ``entry - tp``. A hit hedge is
        closed at its TP price, as check_tps and _check_sub_tp do. Any profit
        returned by the engine is accumulated. When ANTI_MART_REOPEN is on, a
        replacement at the same level is queued. Replacements are opened after
        the scan, each only while fewer than ANTI_MART_MAX_LEVELS hedges are
        active.
        """
        if not ANTI_MART_ENABLED or not self.anti_mart_hedges:
            return

        bid, ask = order_engine.get_current_price()
        reopens = []

        for hedge in self.anti_mart_hedges:
            if not hedge.is_active:
                continue

            if order_engine.positions.get(hedge.order_id) is None:
                hedge.is_active = False
                continue

            tp_dist = hedge.tp_pips * self.pip_size
            if hedge.direction == 0:  # BUY hedge: TP above entry
                tp_price = hedge.entry_price + tp_dist
                hit = high >= tp_price
            else:  # SELL hedge: TP below entry
                tp_price = hedge.entry_price - tp_dist
                hit = low <= tp_price

            if not hit:
                continue

            # Fill at the TP level the bar touched. The hit is detected from
            # high/low, so closing without an explicit price would not match
            # the level that triggered the exit. The other TP paths in this
            # class pass the TP price the same way.
            profit = order_engine.close_position(hedge.order_id, tp_price)
            if profit is not None:
                hedge.total_profit += profit
            hedge.is_active = False

            # Queue reopen
            if ANTI_MART_REOPEN:
                reopens.append((hedge.level, hedge.reopen_count + 1))

        # Reopen after processing (avoid modifying list during iteration)
        for level, reopen_count in reopens:
            active = sum(1 for h in self.anti_mart_hedges if h.is_active)
            if active < ANTI_MART_MAX_LEVELS:
                self._open_anti_mart_hedge(level, order_engine, bid, ask, reopen_count)

    # ---- Grid Level Processing ----

    def check_levels(self, high: float, low: float, order_engine: BacktestOrderEngine,
                     fibo_value: float = 0.5):
        """Check if price crossed any grid levels within the candle range.

        Loops to handle multiple level crossings within a single candle
        (common for crypto, where one candle can span many grid levels).

        The first crossing found (up preferred) fixes the direction for the
        rest of the candle. A bar's high/low says nothing about the intra-bar
        path. Without this lock, a candle whose range contains two adjacent
        levels would bounce between them (up, down, up, ...), because each
        crossing makes the other level a neighbour again. It would keep
        bouncing until the safety limit, opening a new hedged pair on every
        bounce.

        MQ4 V4 (line 710-720): recovery grid level crossings gated by
        CalcFibo() within RANGE_ZONE_MIN..RANGE_ZONE_MAX.

        Args:
            high: Candle high.
            low: Candle low.
            order_engine: Engine used to open orders.
            fibo_value: Current Fibo zone value; outside RECOVERY_FIBO_ZONE
                nothing is processed.
        """
        if not self.is_active:
            return

        # Fibo zone gate for recovery grid (MQ4 V4 range zone filter)
        fibo_min, fibo_max = RECOVERY_FIBO_ZONE
        if not (fibo_min <= fibo_value <= fibo_max):
            return

        # None until the first crossing, then 0 (up) or 1 (down) for this candle.
        locked_dir = None
        max_iterations = len(self.grid_steps) * 2  # safety limit
        for _ in range(max_iterations):
            up_id, up_price, dn_id, dn_price = self.get_next_levels()

            if locked_dir != 1 and up_price > 0 and high >= up_price:
                new_level, level_price, break_dir = up_id, up_price, 0
            elif locked_dir != 0 and dn_price > 0 and low <= dn_price:
                new_level, level_price, break_dir = dn_id, dn_price, 1
            else:
                break

            locked_dir = break_dir
            self.level_count += 1
            self.current_level = new_level
            # Margin relief: close losing subs at deep levels to free margin
            if MARGIN_RELIEF_ENABLED and self.level_count >= MARGIN_RELIEF_LEVEL:
                self._relieve_margin(order_engine, level_price)
            if not BAR_REVERSAL_MODE or self._count_active_orders(order_engine) < 2:
                self.add_order(break_dir, order_engine, fibo_value=fibo_value)

    def check_tps(self, high: float, low: float, order_engine: BacktestOrderEngine,
                  fibo_value: float = 0.5):
        """Process take-profit hits for every recovery pair on this bar.

        For each pair, the live (not ``is_closed``) sub-pairs are TP-checked
        and given a chance to repeat first. Then the BUY leg, followed by the
        SELL leg, is tested against its TP (entry ± tp_pips) using the bar
        high/low. A hit leg is closed at its TP price, which also:

        * closes the pair's cascade-link tickets,
        * closes the previous level's sub-pairs,
        * closes this level's sub-pairs if the opposite main leg is already gone.
        """
        if not self.is_active:
            return

        for pair in self.order_pairs:
            distance = pair.tp_pips * self.pip_size

            for sp in pair.sub_pairs:
                if not sp.is_closed:
                    self._check_sub_tp(sp, high, low, order_engine)
                    self._check_sub_repeat(sp, order_engine, fibo_value=fibo_value)

            buy_target = pair.buy_entry + distance
            sell_target = pair.sell_entry - distance
            legs = (
                (pair.buy_order_id, pair.sell_order_id, buy_target, high >= buy_target),
                (pair.sell_order_id, pair.buy_order_id, sell_target, low <= sell_target),
            )
            for leg_id, other_leg_id, target, touched in legs:
                # Position lookup happens here, after any earlier leg was processed.
                if not (order_engine.positions.get(leg_id) and touched):
                    continue
                order_engine.close_position(leg_id, target)
                self._close_if_open(pair.buy_close_id, order_engine)
                self._close_if_open(pair.sell_close_id, order_engine)
                self._close_old_subs(pair, order_engine)
                if other_leg_id not in order_engine.positions:
                    self._close_current_subs(pair, order_engine)

    def check_thread_profit(self, order_engine: BacktestOrderEngine):
        """Close the whole thread once its total P&L reaches THREAD_PROFIT_TARGET.

        The exit is armed when this thread is in ``thread_mode`` or when
        THREAD_PROFIT_ENABLED is set globally. On exit all positions are closed
        and the thread is deactivated.
        """
        if not self.is_active:
            return

        exit_armed = self.thread_mode or THREAD_PROFIT_ENABLED
        if not exit_armed:
            return

        if self.get_total_profit(order_engine) >= THREAD_PROFIT_TARGET:
            self.close_all(order_engine)
            self.is_active = False

    def check_deactivation(self, order_engine: BacktestOrderEngine):
        """Check if thread should deactivate."""
        if not self.is_active:
            return

        # Original order closed at profit with no recovery orders → done
        init_order = order_engine.orders.get(self.initial_order_id)
        if init_order and init_order.is_closed():
            if init_order.profit > 0 and len(self.order_pairs) == 0:
                self.is_active = False
                return

        # No active orders at all and we have recovery pairs → done
        if len(self.order_pairs) > 0 and not self._has_active_orders(order_engine):
            self.is_active = False

    def _check_sub_tp(self, sp: SubOrderPair, high: float, low: float,
                      order_engine: BacktestOrderEngine):
        """Check TP for a sub-order pair."""
        tp_dist = sp.tp_pips * self.pip_size

        # Buy TP
        buy_order = order_engine.orders.get(sp.buy_order_id)
        if sp.buy_order_id in order_engine.positions and buy_order:
            if high >= buy_order.price + tp_dist:
                order_engine.close_position(sp.buy_order_id, buy_order.price + tp_dist)
                sp.buy_tp_hit = True

        # Sell TP
        sell_order = order_engine.orders.get(sp.sell_order_id)
        if sp.sell_order_id in order_engine.positions and sell_order:
            if low <= sell_order.price - tp_dist:
                order_engine.close_position(sp.sell_order_id, sell_order.price - tp_dist)
                sp.sell_tp_hit = True

    def _check_sub_repeat(self, sp: SubOrderPair, order_engine: BacktestOrderEngine,
                          fibo_value: float = 0.5):
        """Re-open a sub-pair's legs after their TPs have been taken.

        Nothing happens for sub-pairs flagged ``is_closed``, or when
        ``fibo_value`` is outside RECOVERY_FIBO_ZONE. Two policies exist:

        * REPEAT_BOTH_SIDES: while exactly one leg has hit TP, record which
          one in ``last_direction`` (0 = buy, 1 = sell). Once both legs have
          hit, re-open BOTH when bid >= ``level_price`` (buy hit first) or
          bid <= ``level_price`` (sell hit first).
        * REPEAT_SINGLE_SIDES (only when REPEAT_BOTH_SIDES is off): re-open a
          buy leg that hit TP once bid <= ``level_price``, and a sell leg that
          hit TP once bid >= ``level_price``, independently.

        Re-opened legs reuse the stored lot sizes. An order id is replaced only
        when the engine returns an order.
        """
        if sp.is_closed:
            return

        # Fibo zone gate for sub-pair repeats
        zone_lo, zone_hi = RECOVERY_FIBO_ZONE
        if not (zone_lo <= fibo_value <= zone_hi):
            return

        bid, _ask = order_engine.get_current_price()

        if REPEAT_BOTH_SIDES:
            if sp.buy_tp_hit and sp.sell_tp_hit:
                if sp.last_direction == 0:
                    back_at_level = bid >= sp.level_price
                elif sp.last_direction == 1:
                    back_at_level = bid <= sp.level_price
                else:
                    back_at_level = False
                if back_at_level:
                    sp.buy_tp_hit = False
                    sp.sell_tp_hit = False
                    new_buy = order_engine.place_buy_order(
                        lots=sp.buy_lots, magic=self.magic,
                        comment="SUB_RPT_BUY"
                    )
                    new_sell = order_engine.place_sell_order(
                        lots=sp.sell_lots, magic=self.magic,
                        comment="SUB_RPT_SELL"
                    )
                    if new_buy:
                        sp.buy_order_id = new_buy.order_id
                    if new_sell:
                        sp.sell_order_id = new_sell.order_id
            elif sp.buy_tp_hit:
                sp.last_direction = 0
            elif sp.sell_tp_hit:
                sp.last_direction = 1
            return

        if not REPEAT_SINGLE_SIDES:
            return

        if sp.buy_tp_hit and bid <= sp.level_price:
            sp.buy_tp_hit = False
            new_buy = order_engine.place_buy_order(
                lots=sp.buy_lots, magic=self.magic,
                comment="SUB_RPT_BUY"
            )
            if new_buy:
                sp.buy_order_id = new_buy.order_id

        if sp.sell_tp_hit and bid >= sp.level_price:
            sp.sell_tp_hit = False
            new_sell = order_engine.place_sell_order(
                lots=sp.sell_lots, magic=self.magic,
                comment="SUB_RPT_SELL"
            )
            if new_sell:
                sp.sell_order_id = new_sell.order_id

    def get_total_profit(self, order_engine: BacktestOrderEngine) -> float:
        """Get total P&L across all positions in this thread (open + closed).

        MQ4 uses OrderProfit()+OrderSwap()+OrderCommission() which includes
        already-realized P&L from closed sub-pairs. We sum:
        1. Unrealized P&L from still-open positions
        2. Realized P&L from closed orders matching this thread's magic number
        """
        total = 0.0

        # Unrealized P&L from open positions
        # Initial order
        pos = order_engine.positions.get(self.initial_order_id)
        if pos:
            total += pos.unrealized_pnl

        # All order pairs
        for pair in self.order_pairs:
            for oid in (pair.buy_order_id, pair.sell_order_id):
                pos = order_engine.positions.get(oid)
                if pos:
                    total += pos.unrealized_pnl

            for sp in pair.sub_pairs:
                if sp.is_closed:
                    continue
                for oid in (sp.buy_order_id, sp.sell_order_id):
                    pos = order_engine.positions.get(oid)
                    if pos:
                        total += pos.unrealized_pnl

        # Anti-martingale hedge positions (unrealized)
        for hedge in self.anti_mart_hedges:
            if hedge.is_active:
                pos = order_engine.positions.get(hedge.order_id)
                if pos:
                    total += pos.unrealized_pnl

        # Realized P&L from closed orders with matching magic number
        for closed_order in order_engine.trade_history:
            if closed_order.magic == self.magic:
                total += closed_order.profit

        return total

    def close_all(self, order_engine: BacktestOrderEngine):
        """Close every position this thread may hold, then retire its hedges.

        Closes, in order: the initial order; then, for each recovery pair, its
        two main legs, its two cascade-link tickets and every sub-pair leg
        (regardless of ``is_closed``); finally each active anti-martingale
        hedge, which is also marked inactive.
        """
        to_close = [self.initial_order_id]
        for pair in self.order_pairs:
            to_close += [pair.buy_order_id, pair.sell_order_id,
                         pair.buy_close_id, pair.sell_close_id]
            for sp in pair.sub_pairs:
                to_close += [sp.buy_order_id, sp.sell_order_id]

        for order_id in to_close:
            self._close_if_open(order_id, order_engine)

        # Close anti-martingale hedges
        for hedge in (h for h in self.anti_mart_hedges if h.is_active):
            self._close_if_open(hedge.order_id, order_engine)
            hedge.is_active = False

    def _relieve_margin(self, order_engine: BacktestOrderEngine, current_price: float,
                        loss_threshold: float = -5):
        """Close losing sub-pair legs from earlier levels to free margin.

        Does nothing unless MARGIN_RELIEF_CLOSE_SUBS is set and there are at
        least two recovery pairs. The latest recovery pair is never touched.
        For every not-yet-closed sub-pair of the earlier pairs, each leg that
        is still open with ``unrealized_pnl < loss_threshold`` is closed at
        ``current_price``.

        NOTE(review): the caller is presumably what gates this on deep
        recovery levels (MARGIN_RELIEF_LEVEL); that call site is not shown
        here.

        Args:
            order_engine: engine holding the open positions.
            current_price: price at which losing legs are closed.
            loss_threshold: a leg is closed when its unrealized P&L is below
                this value (same units as ``pos.unrealized_pnl``). Defaults
                to the historical hard-coded -5.
        """
        if not MARGIN_RELIEF_CLOSE_SUBS or len(self.order_pairs) < 2:
            return
        closed_count = 0
        for pair in self.order_pairs[:-1]:  # Don't touch latest pair
            for sub in pair.sub_pairs:
                if sub.is_closed:
                    continue
                # Count closes per sub-pair. A counter shared across sub-pairs
                # would flag every later, untouched sub-pair as closed once any
                # earlier leg was closed, hiding its still-open positions.
                closed_here = 0
                for oid in (sub.buy_order_id, sub.sell_order_id):
                    if oid and oid in order_engine.positions:
                        pos = order_engine.positions[oid]
                        if pos.unrealized_pnl < loss_threshold:
                            order_engine.close_position(oid, current_price)
                            closed_here += 1
                if closed_here > 0:
                    # NOTE(review): the sub-pair is flagged closed even when
                    # only one leg was closed; the other leg stays open.
                    sub.is_closed = True
                    closed_count += closed_here
        if closed_count > 0:
            self.logger.debug(f"Margin relief: closed {closed_count} losing sub-positions at L{self.level_count}")

    def _close_if_open(self, order_id: int, order_engine: BacktestOrderEngine):
        """Close ``order_id`` when the engine still lists it as an open position.

        Falsy ids (unset slots) and ids absent from ``order_engine.positions``
        are ignored, so calling this on an already-closed order is a no-op.
        The close price is left to ``order_engine.close_position``'s default.
        """
        if not order_id or order_id not in order_engine.positions:
            return
        order_engine.close_position(order_id)

    def _close_old_subs(self, pair: RecoveryOrderPair, order_engine: BacktestOrderEngine):
        """Close every sub-pair in ``pair.old_sub_pairs`` and flag it closed.

        Each leg (buy first, then sell) goes through ``_close_if_open``, so
        legs that are already closed are skipped.
        """
        for old_sub in pair.old_sub_pairs:
            for leg_id in (old_sub.buy_order_id, old_sub.sell_order_id):
                self._close_if_open(leg_id, order_engine)
            old_sub.is_closed = True

    def _close_current_subs(self, pair: RecoveryOrderPair, order_engine: BacktestOrderEngine):
        """Close every sub-pair in ``pair.sub_pairs`` and flag it closed.

        Each leg (buy first, then sell) goes through ``_close_if_open``, so
        legs that are already closed are skipped.
        """
        for current_sub in pair.sub_pairs:
            for leg_id in (current_sub.buy_order_id, current_sub.sell_order_id):
                self._close_if_open(leg_id, order_engine)
            current_sub.is_closed = True

    def _has_active_orders(self, order_engine: BacktestOrderEngine) -> bool:
        """Return True if any tracked order of this thread is still open.

        Checks the initial order, each recovery pair's buy/sell legs, and the
        legs of sub-pairs not flagged ``is_closed``. Old sub-pairs, pair close
        orders and anti-martingale hedges are not considered here.
        """
        open_positions = order_engine.positions
        if self.initial_order_id in open_positions:
            return True
        return any(
            pair.buy_order_id in open_positions
            or pair.sell_order_id in open_positions
            or any(
                sub.buy_order_id in open_positions or sub.sell_order_id in open_positions
                for sub in pair.sub_pairs
                if not sub.is_closed
            )
            for pair in self.order_pairs
        )

    def _count_active_orders(self, order_engine: BacktestOrderEngine) -> int:
        """Return how many recovery-pair buy/sell legs are still open.

        Only the main legs of ``self.order_pairs`` are counted; the initial
        order, sub-pairs and hedges are excluded.
        """
        open_positions = order_engine.positions
        # bool + bool is an int, so each pair contributes 0, 1 or 2.
        return sum(
            (pair.buy_order_id in open_positions) + (pair.sell_order_id in open_positions)
            for pair in self.order_pairs
        )


# =============================================================================
# MAIN STRATEGY ENGINE
# =============================================================================

class HedgeMartStrategy:
    """Main strategy engine for Auto Hedge-Mart V4.

    Entry logic (evaluated once per bar in ``on_new_bar``):
    1. Continuation (ENABLE_CONTINUATION_ENTRIES): two consecutive
       same-direction bars with the Fibo position inside FIBO_BUY_ZONE /
       FIBO_SELL_ZONE → trade in that direction.
    2. Reversal (ENABLE_REVERSAL_ENTRIES): bar[2] in one direction and bar[1]
       in the other (or doji), with the Fibo position inside FIBO_RANGE_ZONE
       → trade in bar[1]'s direction.

    Recovery:
    - Each initial trade gets a RecoveryThread
    - Threads manage grid levels, hedged pairs, TP management, lot sizing
    """

    def __init__(self, symbol: str, order_engine: BacktestOrderEngine):
        self.symbol = symbol
        self.order_engine = order_engine
        self.pip_size = get_pip_size(symbol)
        self.recovery_threads: List[RecoveryThread] = []
        # Keep enough bars for the Fibo window and (when adaptive) the ATR
        # window, plus a small margin, and never fewer than 30.
        atr_lookback = max(FIBO_BARS_BACK, ADAPTIVE_ATR_BARS) if ADAPTIVE_GRID else FIBO_BARS_BACK
        self.bar_history: deque = deque(maxlen=max(atr_lookback + 5, 30))
        self.magic_counter = 0
        self.current_vol_scale = 1.0  # current volatility scaling factor
        self.logger = get_logger()

    def _next_magic(self) -> int:
        """Return a fresh magic number (100001, 100002, ...) for a new thread."""
        self.magic_counter += 1
        return 100000 + self.magic_counter

    # ---- Fibonacci Zone Calculation ----

    def calc_fibo(self, bars_back: int = FIBO_BARS_BACK) -> float:
        """Calculate Fibonacci zone position.

        MQ4: (Ask - Low[N bars]) / (High[N bars] - Low[N bars])
        0.0 is the range low and 1.0 the range high. The result is not
        clamped, so an ask outside the N-bar range gives a value outside
        [0, 1].

        Returns 0.5 (neutral) when ``bars_back`` is not positive, when fewer
        than ``bars_back`` bars are available, or when the range is empty.
        Falls back to the last bar's close if the engine reports no ask.
        """
        # A non-positive window would slice as [-0:], i.e. the whole history.
        if bars_back <= 0 or len(self.bar_history) < bars_back:
            return 0.5

        recent = list(self.bar_history)[-bars_back:]
        low = min(b.low for b in recent)
        high = max(b.high for b in recent)

        if high <= low:
            return 0.5

        _, ask = self.order_engine.get_current_price()
        if ask <= 0:
            ask = recent[-1].close

        return (ask - low) / (high - low)

    def calc_atr(self, bars_back: int = ADAPTIVE_ATR_BARS) -> float:
        """Average bar range over the last ``bars_back`` bars.

        This is the simple mean of (high - low), in price units. It is not
        Wilder's true range: gaps from the previous close are not included.
        Returns REFERENCE_ATR when ``bars_back`` is not positive or not enough
        bars are available.
        """
        # A non-positive window would slice the whole history (or divide by
        # zero on an empty one).
        if bars_back <= 0 or len(self.bar_history) < bars_back:
            return REFERENCE_ATR  # fallback to baseline

        recent = list(self.bar_history)[-bars_back:]
        ranges = [b.high - b.low for b in recent]
        return sum(ranges) / len(ranges)

    def get_vol_scale(self) -> float:
        """Get current volatility scale factor.

        scale = current_atr / reference_atr, clamped to
        [ADAPTIVE_MIN_SCALE, ADAPTIVE_MAX_SCALE]. Always 1.0 when ADAPTIVE_GRID
        is off; the ratio falls back to 1.0 (before clamping) if
        REFERENCE_ATR <= 0.
        > 1.0 means market is hotter than baseline → widen grid/TP
        < 1.0 means market is quieter → tighten grid/TP
        """
        if not ADAPTIVE_GRID:
            return 1.0

        atr = self.calc_atr()
        scale = atr / REFERENCE_ATR if REFERENCE_ATR > 0 else 1.0
        scale = max(ADAPTIVE_MIN_SCALE, min(ADAPTIVE_MAX_SCALE, scale))
        return scale

    @staticmethod
    def _scale_pips(pips: float, vol_scale: float) -> int:
        """Scale a pip distance by ``vol_scale`` and round to whole pips.

        A positive distance never rounds down to 0. For recovery TPs 0 means
        "disabled", and a 0-pip grid step or entry TP would put levels and
        targets on the entry price. Non-positive inputs are only rounded.
        """
        scaled = round(pips * vol_scale)
        return max(1, scaled) if pips > 0 else scaled

    def get_scaled_params(self, vol_scale: float) -> dict:
        """Scale grid steps, recovery TPs, and entry TP by volatility factor.

        Returns a dict with keys ``grid_steps`` (list), ``recovery_tps`` (list,
        0 kept as 0) and ``entry_tp_pips`` (int). All values are whole pips,
        and positive inputs stay at least 1 pip.
        """
        scaled_grid = [self._scale_pips(s, vol_scale) for s in RECOVERY_GRID_STEPS]
        scaled_rec_tps = [self._scale_pips(t, vol_scale) if t > 0 else 0 for t in RECOVERY_TPS]
        scaled_entry_tp = self._scale_pips(ENTRY_TP_PIPS, vol_scale)
        return {
            "grid_steps": scaled_grid,
            "recovery_tps": scaled_rec_tps,
            "entry_tp_pips": scaled_entry_tp,
        }

    # ---- Bar Pattern Detection ----

    def _bar_direction(self, offset: int) -> int:
        """Get bar direction at offset from end. 0=bull, 1=bear, 2=doji.

        offset=1 → last completed bar (MQ4 bar[1])
        offset=2 → two bars ago (MQ4 bar[2])
        Returns 2 (doji) when the history is too short.

        NOTE(review): this assumes ``bar_history[-1]`` is the bar still being
        built. If callers append already-completed candles, offset=1 actually
        reads the second-most-recent completed bar. Confirm against the
        driver loop.
        """
        idx = -(offset + 1)  # +1 because current bar is being built
        if abs(idx) > len(self.bar_history):
            return 2
        bar = self.bar_history[idx]
        if bar.close > bar.open:
            return 0
        elif bar.close < bar.open:
            return 1
        return 2

    def _count_initial_orders(self, side: int) -> int:
        """Count active threads whose initial order is still open on ``side`` (0=buy, 1=sell)."""
        count = 0
        for t in self.recovery_threads:
            if not t.is_active:
                continue
            pos = self.order_engine.positions.get(t.initial_order_id)
            if pos:
                if side == 0 and pos.side == OrderSide.BUY:
                    count += 1
                elif side == 1 and pos.side == OrderSide.SELL:
                    count += 1
        return count

    # ---- Entry Logic ----

    def _effective_max_orders(self) -> int:
        """Dynamic exposure throttle for the initial-order cap.

        Returns MAX_INITIAL_ORDERS, unless DYNAMIC_MO_ENABLED:
        - open positions >= DYNAMIC_MO_HIGH_THRESH → DYNAMIC_MO_MIN;
        - open positions >= DYNAMIC_MO_LOW_THRESH → MAX_INITIAL_ORDERS scaled
          down linearly between the thresholds, floored at DYNAMIC_MO_MIN.
        """
        effective_mo = MAX_INITIAL_ORDERS
        if DYNAMIC_MO_ENABLED:
            n_positions = len(self.order_engine.positions)
            if n_positions >= DYNAMIC_MO_HIGH_THRESH:
                effective_mo = DYNAMIC_MO_MIN
            elif n_positions >= DYNAMIC_MO_LOW_THRESH:
                # Linear scale between thresholds
                ratio = (n_positions - DYNAMIC_MO_LOW_THRESH) / max(1, DYNAMIC_MO_HIGH_THRESH - DYNAMIC_MO_LOW_THRESH)
                effective_mo = max(DYNAMIC_MO_MIN, int(MAX_INITIAL_ORDERS * (1 - ratio)))
        return effective_mo

    def _biased_max_orders(self, effective_mo: int) -> tuple:
        """Inventory-aware entry bias (from Foucault et al. 2013).

        Returns ``(buy_mo, sell_mo)``. With INVENTORY_BIAS_ENABLED, when the
        share of net active threads on one side exceeds
        INVENTORY_BIAS_THRESHOLD, that side's cap is halved (floored at
        DYNAMIC_MO_MIN). Otherwise both equal ``effective_mo``.
        """
        buy_mo = effective_mo
        sell_mo = effective_mo
        if not INVENTORY_BIAS_ENABLED:
            return buy_mo, sell_mo

        # Count only INITIAL entry threads, not recovery positions.
        # Recovery is supposed to be imbalanced (that's the grid grinding).
        buy_threads = 0
        sell_threads = 0
        for thread in self.recovery_threads:
            if thread.is_active:
                if thread.direction == 0:
                    buy_threads += 1
                else:
                    sell_threads += 1
        total_threads = buy_threads + sell_threads
        net_threads = buy_threads - sell_threads
        imbalance_ratio = abs(net_threads) / max(total_threads, 1)
        if imbalance_ratio > INVENTORY_BIAS_THRESHOLD:
            if net_threads > 0:
                buy_mo = max(DYNAMIC_MO_MIN, effective_mo // 2)
            else:
                sell_mo = max(DYNAMIC_MO_MIN, effective_mo // 2)
        return buy_mo, sell_mo

    def on_new_bar(self, candle: OHLCV):
        """Process new bar — check for entry signals.

        MQ4 OnTick → IsNewBar section. Appends ``candle`` to the history,
        refreshes the volatility scale, then evaluates the reversal and
        continuation entry patterns. Nothing is evaluated until at least 3
        bars are stored.
        """
        self.bar_history.append(candle)

        if len(self.bar_history) < 3:
            return

        # Update adaptive volatility scale each bar
        self.current_vol_scale = self.get_vol_scale()

        fibo = self.calc_fibo()
        buy_count = self._count_initial_orders(0)
        sell_count = self._count_initial_orders(1)

        effective_mo = self._effective_max_orders()
        buy_mo, sell_mo = self._biased_max_orders(effective_mo)

        # NOTE(review): every entry below requires BOTH counts to be under
        # their caps. Halving one side's cap through the inventory bias
        # therefore also throttles entries on the opposite side. Confirm this
        # is intended.

        # Entry type 1: Bar reversal (MQ4 Magic 11111)
        if ENABLE_REVERSAL_ENTRIES:
            prev = self._bar_direction(2)
            curr = self._bar_direction(1)

            in_range = FIBO_RANGE_ZONE[0] <= fibo <= FIBO_RANGE_ZONE[1]

            if (in_range
                    and sell_count < sell_mo
                    and buy_count < buy_mo
                    and prev == 1
                    and curr in (0, 2)):
                self._open_initial_trade(0)

            if (in_range
                    and buy_count < buy_mo
                    and sell_count < sell_mo
                    and prev == 0
                    and curr in (1, 2)):
                self._open_initial_trade(1)

        # Entry type 2: Continuation (MQ4 Magic 22222)
        if ENABLE_CONTINUATION_ENTRIES:
            bar1 = self._bar_direction(1)
            bar2 = self._bar_direction(2)

            if (bar1 == 0 and bar2 == 0
                    and sell_count < sell_mo
                    and buy_count < buy_mo
                    and FIBO_BUY_ZONE[0] <= fibo <= FIBO_BUY_ZONE[1]):
                self._open_initial_trade(0)

            if (bar1 == 1 and bar2 == 1
                    and buy_count < buy_mo
                    and sell_count < sell_mo
                    and FIBO_SELL_ZONE[0] <= fibo <= FIBO_SELL_ZONE[1]):
                self._open_initial_trade(1)

    def _open_initial_trade(self, direction: int):
        """Open an initial trade (0=buy, 1=sell) and create its recovery thread.

        The entry is skipped when MIN_ENTRY_SPACING > 0 and the current bid is
        within that distance of any active thread's entry price.

        When ADAPTIVE_GRID is enabled, the current volatility scale is locked
        into the thread at creation time. The thread's grid, TPs, and entry TP
        all reflect the market conditions when the trade was opened.
        """
        bid, _ = self.order_engine.get_current_price()

        # Min entry spacing: skip if too close to any existing thread entry
        if MIN_ENTRY_SPACING > 0 and any(
            thread.is_active and abs(bid - thread.entry_price) < MIN_ENTRY_SPACING
            for thread in self.recovery_threads
        ):
            return  # Too close, skip this entry

        # Resolve adaptive params (locked to current volatility) BEFORE placing
        # the order. An empty RECOVERY_TPS then raises here rather than
        # leaving an open position with no recovery thread managing it.
        vs = self.current_vol_scale
        sp = self.get_scaled_params(vs)
        first_recovery_tp = sp["recovery_tps"][0]

        magic = self._next_magic()

        if direction == 0:
            order = self.order_engine.place_buy_order(
                lots=LOT_SIZE, magic=magic, comment="ENTRY_BUY"
            )
        else:
            order = self.order_engine.place_sell_order(
                lots=LOT_SIZE, magic=magic, comment="ENTRY_SELL"
            )

        if order:
            # TP price is only reported to the trade log; the initial TP itself
            # is enforced by process_candle from the thread's entry_tp_pips.
            tp_dist = sp["entry_tp_pips"] * self.pip_size
            if direction == 0:
                tp_price = order.price + tp_dist
            else:
                tp_price = order.price - tp_dist

            thread = RecoveryThread(
                initial_order_id=order.order_id,
                entry_price=order.price,
                lot_size=LOT_SIZE,
                direction=direction,
                magic=magic,
                recovery_tp_pips=first_recovery_tp,
                symbol=self.symbol,
                pip_size=self.pip_size,
                grid_steps=sp["grid_steps"],
                recovery_tps=sp["recovery_tps"],
                entry_tp_pips=sp["entry_tp_pips"],
                vol_scale=vs,
            )
            self.recovery_threads.append(thread)

            log_trade(
                "ENTRY_BUY" if direction == 0 else "ENTRY_SELL",
                self.symbol, order.price, LOT_SIZE, magic, tp=tp_price
            )

    # ---- Recovery Processing ----

    def process_candle(self, candle: OHLCV):
        """Process recovery threads against candle data.

        Order of operations within a candle, per active thread:
        1. Check grid level crossings (recovery triggers). Skipped while the
           initial order is in profit and no recovery pair exists yet.
        1b. Check anti-martingale hedge TPs and reopens.
        2. Check recovery pair TPs.
        3. Check initial order TP/SL (after recovery, so grid can trigger
           first). TP is tested before SL, so a candle touching both
           resolves as a TP.
        4. Check thread profit exit.
        5. Deactivation check.
        Inactive threads are then dropped from ``recovery_threads``.
        """
        fibo = self.calc_fibo()
        # The initial-order SL distance is the same for every thread (0 = no
        # SL), so compute it once per candle.
        sl_dist_init = ENTRY_SL_PIPS * self.pip_size if ENTRY_SL_PIPS > 0 else 0

        for thread in list(self.recovery_threads):
            if not thread.is_active:
                continue

            # 1. Check grid level crossings (uses high/low)
            init_pos = self.order_engine.positions.get(thread.initial_order_id)
            run_recovery = True
            if init_pos and init_pos.unrealized_pnl > 0 and len(thread.order_pairs) == 0:
                run_recovery = False

            if run_recovery:
                thread.check_levels(candle.high, candle.low, self.order_engine,
                                    fibo_value=fibo)

            # 1b. Check anti-martingale hedge TPs and reopens
            thread.check_anti_mart_hedges(candle.high, candle.low, self.order_engine)

            # 2. Check recovery pair TPs
            thread.check_tps(candle.high, candle.low, self.order_engine,
                             fibo_value=fibo)

            # 3. Check initial order TP/SL (strategy-managed, uses per-thread TP).
            # Re-read the position: steps 1-2 may have changed the book.
            tp_dist_init = thread.entry_tp_pips * self.pip_size
            init_pos = self.order_engine.positions.get(thread.initial_order_id)
            if init_pos:
                if thread.direction == 0:  # buy: TP when high hits, SL when low hits
                    tp_target = thread.entry_price + tp_dist_init
                    sl_target = thread.entry_price - sl_dist_init if sl_dist_init > 0 else 0
                    if candle.high >= tp_target:
                        self.order_engine.close_position(thread.initial_order_id, tp_target)
                    elif sl_dist_init > 0 and candle.low <= sl_target:
                        self.order_engine.close_position(thread.initial_order_id, sl_target)
                else:  # sell: TP when low hits, SL when high hits
                    tp_target = thread.entry_price - tp_dist_init
                    sl_target = thread.entry_price + sl_dist_init if sl_dist_init > 0 else 0
                    if candle.low <= tp_target:
                        self.order_engine.close_position(thread.initial_order_id, tp_target)
                    elif sl_dist_init > 0 and candle.high >= sl_target:
                        self.order_engine.close_position(thread.initial_order_id, sl_target)

            # 4. Thread profit exit
            thread.check_thread_profit(self.order_engine)

            # 5. Deactivation check
            thread.check_deactivation(self.order_engine)

        # Cleanup inactive threads
        self.recovery_threads = [t for t in self.recovery_threads if t.is_active]

    def on_tick(self, bid: float, ask: float):
        """Process tick — run recovery threads at exact prices.

        The tick's ask/bid are passed where ``process_candle`` passes the
        candle's high/low.

        NOTE(review): unlike ``process_candle``, this does not check
        anti-martingale hedges or the initial order's TP/SL. Confirm that
        ``process_candle`` also runs in tick mode, otherwise those never
        trigger.
        """
        fibo = self.calc_fibo()

        for thread in list(self.recovery_threads):
            if not thread.is_active:
                continue

            init_pos = self.order_engine.positions.get(thread.initial_order_id)
            run_recovery = True
            if init_pos and init_pos.unrealized_pnl > 0 and len(thread.order_pairs) == 0:
                run_recovery = False

            if run_recovery:
                thread.check_levels(ask, bid, self.order_engine, fibo_value=fibo)

            thread.check_tps(ask, bid, self.order_engine, fibo_value=fibo)
            thread.check_thread_profit(self.order_engine)
            thread.check_deactivation(self.order_engine)

        self.recovery_threads = [t for t in self.recovery_threads if t.is_active]

    def get_active_thread_count(self) -> int:
        """Number of threads currently held (inactive ones are pruned each pass)."""
        return len(self.recovery_threads)

    def get_total_recovery_orders(self) -> int:
        """Total recovery pairs across all held threads."""
        return sum(len(t.order_pairs) for t in self.recovery_threads)
