#!/usr/bin/env python3
"""
Market Regime Detector for Auto Hedge-Mart V4

Classifies market conditions into 4 regimes and provides parameter adjustments:
  🔴 HIGH_VOL   — Volatile/crisis: widen grid, reduce new threads
  🟢 LOW_VOL    — Quiet/grinding: tighten grid, maximize sub-pair repeats
  🔵 TRENDING   — Strong directional: favor trend entries, wider TPs
  🟡 RANGING    — Sideways: full send on sub-pair grinding

Uses three indicators:
  - ATR ratio  (current ATR / reference ATR) → volatility level
  - ADX        (Average Directional Index)   → trend strength
  - BB Width   (Bollinger Bandwidth)         → squeeze/expansion detection

Each regime returns parameter adjustments (mostly multipliers) that the engine can apply to:
  - Grid spacing scale
  - Entry TP scale  
  - Recovery TP scale
  - Thread profit target scale
  - Max new threads (0 = pause)
  - Sub-pair repeat aggressiveness

Author: TheCryptoClaw 🦾
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Tuple
from collections import deque
import math


class MarketRegime(Enum):
    """
    The four market regimes the detector can report.

    Each member's value equals its name, so ``MarketRegime("TRENDING")``
    round-trips from the plain string form stored in regime history.
    """

    # 🔴 Volatile — widen everything, reduce exposure
    HIGH_VOL = "HIGH_VOL"
    # 🟢 Quiet — tighten grid, grind mode
    LOW_VOL = "LOW_VOL"
    # 🔵 Directional — wider TPs, trend entries
    TRENDING = "TRENDING"
    # 🟡 Sideways — sub-pair paradise
    RANGING = "RANGING"


@dataclass
class RegimeParams:
    """
    Engine parameter adjustments associated with one market regime.

    The ``*_scale`` fields are multipliers on the engine's base settings
    (above 1 widens, below 1 tightens). ``confidence`` is attached per call
    by the detector; the module presets carry 0.0 as a placeholder.
    """
    regime: MarketRegime            # Regime these adjustments belong to
    confidence: float               # 0-1 strength of the classification
    grid_scale: float               # Grid spacing multiplier (>1 = wider)
    entry_tp_scale: float           # Entry take-profit multiplier
    recovery_tp_scale: float        # Recovery take-profit multiplier
    thread_target_scale: float      # Thread profit-target multiplier
    max_new_threads: int            # Cap on new threads (0 = paused)
    sub_pair_aggressive: bool       # True = enable extra sub-pair repeat modes
    description: str                # Human-readable summary

    def __repr__(self):
        # Compact one-line display: badge, regime, confidence, then key knobs.
        badges = {"HIGH_VOL": "🔴", "LOW_VOL": "🟢", "TRENDING": "🔵", "RANGING": "🟡"}
        badge = badges.get(self.regime.value, "⚪")
        head = f"{badge} {self.regime.value} (conf={self.confidence:.0%})"
        tail = (
            f"grid={self.grid_scale:.2f}x tp={self.entry_tp_scale:.2f}x "
            f"threads={self.max_new_threads} sub_agg={self.sub_pair_aggressive}"
        )
        return f"{head} | {tail}"


# =============================================================================
# Regime parameter presets
# =============================================================================

# Each preset names its own regime, and the lookup table is keyed on that
# field, so a preset can never be filed under the wrong regime.
REGIME_PRESETS = {
    preset.regime: preset
    for preset in (
        RegimeParams(
            regime=MarketRegime.HIGH_VOL,
            confidence=0.0,                # Placeholder — set dynamically
            grid_scale=1.8,                # Much wider grid — avoid false recovery triggers
            entry_tp_scale=1.5,            # Wider TP — let winners run in vol
            recovery_tp_scale=1.3,         # Slightly wider recovery TPs
            thread_target_scale=2.0,       # Higher thread exit — more grinding before close
            max_new_threads=3,             # Reduce new thread exposure
            sub_pair_aggressive=False,     # Don't add extra sub-pairs in chaos
            description="High volatility — widened grid, reduced thread count, larger targets",
        ),
        RegimeParams(
            regime=MarketRegime.LOW_VOL,
            confidence=0.0,
            grid_scale=0.6,                # Tighter grid — catch smaller moves
            entry_tp_scale=0.7,            # Tighter TP — take what the market gives
            recovery_tp_scale=0.7,         # Tighter recovery TPs
            thread_target_scale=0.8,       # Lower exit — faster thread turnover
            max_new_threads=10,            # Full thread count
            sub_pair_aggressive=True,      # Extra sub-pairs for grinding
            description="Low volatility — tightened grid, full threads, aggressive grinding",
        ),
        RegimeParams(
            regime=MarketRegime.TRENDING,
            confidence=0.0,
            grid_scale=1.2,                # Slightly wider — trend pullbacks are deeper
            entry_tp_scale=1.8,            # Much wider TP — ride the trend
            recovery_tp_scale=1.0,         # Normal recovery
            thread_target_scale=1.5,       # Higher exit — trends give more
            max_new_threads=8,             # Slightly fewer — be selective
            sub_pair_aggressive=False,     # Sub-pairs less useful in trends
            description="Trending market — wider TPs to ride momentum, moderate grid",
        ),
        RegimeParams(
            regime=MarketRegime.RANGING,
            confidence=0.0,
            grid_scale=0.9,                # Slightly tighter grid
            entry_tp_scale=0.8,            # Tighter TP — range-bound moves
            recovery_tp_scale=0.8,         # Tighter recovery TPs
            thread_target_scale=1.0,       # Normal exit target
            max_new_threads=10,            # Full threads — this is our sweet spot
            sub_pair_aggressive=True,      # Sub-pair grinding paradise
            description="Ranging market — tighter targets, full threads, max sub-pair grinding",
        ),
    )
}


@dataclass
class OHLC:
    """
    Minimal price bar used internally by the regime detector.

    Only open/high/low/close are kept; none of the detector's indicators
    read volume, so it is not stored.
    """

    open: float   # bar open price
    high: float   # bar high price
    low: float    # bar low price
    close: float  # bar close price


class RegimeDetector:
    """
    Classifies market regime using ATR, ADX, and Bollinger Bandwidth.

    Candles are fed one at a time through ``update()`` and all indicators are
    maintained incrementally:
      - ATR and the +DM/-DM averages behind +DI/-DI use Wilder's smoothing
        over ``atr_period`` bars.
      - ADX is DX smoothed the same way over ``adx_period`` bars.
      - BB width is (upper - lower) / SMA over the last ``bb_period`` closes,
        using the population standard deviation.

    A new classification only becomes ``current_regime`` after it has been
    produced on ``regime_confirm_bars`` consecutive bars (debouncing).
    ``regime_history`` keeps the most recent 100 confirmed changes;
    ``total_regime_changes`` counts all of them.

    Usage:
        detector = RegimeDetector(reference_atr=132.1)
        for candle in candles:
            detector.update(candle)
            regime = detector.get_regime()
            if regime:
                print(regime)
    """

    def __init__(
        self,
        reference_atr: float = 132.1,
        atr_period: int = 14,
        adx_period: int = 14,
        bb_period: int = 20,
        bb_std: float = 2.0,
        lookback: int = 50,
        # Thresholds
        high_vol_atr_ratio: float = 1.8,   # ATR >= 1.8x reference = high vol
        low_vol_atr_ratio: float = 0.6,    # ATR <= 0.6x reference = low vol
        trending_adx: float = 30.0,        # ADX >= 30 = trending
        ranging_adx: float = 20.0,         # ADX <= 20 = ranging
        bb_squeeze_pct: float = 0.02,      # BB width <= 2% = squeeze (low vol)
        bb_expansion_pct: float = 0.06,    # BB width >= 6% = expansion (high vol)
        regime_confirm_bars: int = 3,      # Need N consecutive bars to confirm regime change
    ):
        """
        Args:
            reference_atr: Baseline ATR; volatility is judged by
                current ATR / reference_atr. A non-positive value makes the
                ratio a neutral 1.0.
            atr_period: Wilder smoothing period for ATR and for +DM/-DM.
            adx_period: Wilder smoothing period for ADX (smoothed DX).
            bb_period: Number of closes in the Bollinger window.
            bb_std: Band half-width in standard deviations.
            lookback: Minimum number of candles retained.
            high_vol_atr_ratio .. bb_expansion_pct: Classification thresholds
                (see ``_classify``).
            regime_confirm_bars: Consecutive identical classifications needed
                before the current regime switches.
        """
        self.reference_atr = reference_atr
        self.atr_period = atr_period
        self.adx_period = adx_period
        self.bb_period = bb_period
        self.bb_std = bb_std
        self.lookback = lookback

        # Thresholds
        self.high_vol_atr_ratio = high_vol_atr_ratio
        self.low_vol_atr_ratio = low_vol_atr_ratio
        self.trending_adx = trending_adx
        self.ranging_adx = ranging_adx
        self.bb_squeeze_pct = bb_squeeze_pct
        self.bb_expansion_pct = bb_expansion_pct
        self.regime_confirm_bars = regime_confirm_bars

        # Internal state.
        # The TR and DM buffers are only read to seed Wilder's smoothing with a
        # plain average of the last atr_period values, so the DM buffers must
        # be able to hold at least atr_period entries as well.
        dm_len = max(atr_period, adx_period) + 5
        self.candles: deque = deque(maxlen=max(lookback, bb_period + 10, adx_period * 3))
        self.true_ranges: deque = deque(maxlen=atr_period + 5)
        self.plus_dm: deque = deque(maxlen=dm_len)
        self.minus_dm: deque = deque(maxlen=dm_len)
        self.dx_values: deque = deque(maxlen=adx_period + 5)

        # Smoothed values (Wilder's method)
        self._smoothed_tr: Optional[float] = None
        self._smoothed_plus_dm: Optional[float] = None
        self._smoothed_minus_dm: Optional[float] = None
        self._smoothed_adx: Optional[float] = None

        # Current values
        self.current_atr: Optional[float] = None
        self.current_adx: Optional[float] = None
        self.current_bb_width: Optional[float] = None

        # Regime tracking
        self.current_regime: Optional[MarketRegime] = None
        self.candidate_regime: Optional[MarketRegime] = None
        self.candidate_count: int = 0
        self.regime_history: deque = deque(maxlen=100)  # last 100 changes only
        self.total_regime_changes: int = 0              # uncapped count
        self.bars_processed: int = 0

    def update(self, candle) -> Optional[RegimeParams]:
        """
        Process a new candle and update regime classification.

        Args:
            candle: Object with .open, .high, .low, .close attributes
                (each convertible with float()).

        Returns:
            RegimeParams if this candle confirmed a regime change, otherwise
            None (including while the indicators are still warming up).
        """
        ohlc = OHLC(
            open=float(candle.open),
            high=float(candle.high),
            low=float(candle.low),
            close=float(candle.close),
        )
        self.candles.append(ohlc)
        self.bars_processed += 1

        # True range and directional movement both need the previous candle.
        if len(self.candles) < 2:
            return None
        prev = self.candles[-2]

        # --- True Range ---
        tr = max(
            ohlc.high - ohlc.low,
            abs(ohlc.high - prev.close),
            abs(ohlc.low - prev.close),
        )
        self.true_ranges.append(tr)

        # --- Directional Movement ---
        # Only the larger of the two moves counts, and only when it is positive.
        up_move = ohlc.high - prev.high
        down_move = prev.low - ohlc.low
        plus_dm = up_move if (up_move > down_move and up_move > 0) else 0.0
        minus_dm = down_move if (down_move > up_move and down_move > 0) else 0.0
        self.plus_dm.append(plus_dm)
        self.minus_dm.append(minus_dm)

        # --- Wilder's Smoothing for ATR and DM ---
        # Seed with a simple average of the first n values, then apply
        # smoothed = (smoothed * (n - 1) + value) / n on every later bar.
        # NOTE: +DM/-DM share atr_period with ATR (not adx_period).
        n = self.atr_period
        if self._smoothed_tr is None:
            if len(self.true_ranges) >= n:
                self._smoothed_tr = sum(list(self.true_ranges)[-n:]) / n
                self._smoothed_plus_dm = sum(list(self.plus_dm)[-n:]) / n
                self._smoothed_minus_dm = sum(list(self.minus_dm)[-n:]) / n
        else:
            self._smoothed_tr = (self._smoothed_tr * (n - 1) + tr) / n
            self._smoothed_plus_dm = (self._smoothed_plus_dm * (n - 1) + plus_dm) / n
            self._smoothed_minus_dm = (self._smoothed_minus_dm * (n - 1) + minus_dm) / n

        if self._smoothed_tr is None:
            return None  # not enough bars to seed ATR yet

        # --- ATR ---
        self.current_atr = self._smoothed_tr

        # --- +DI / -DI ---
        if self._smoothed_tr > 0:
            plus_di = (self._smoothed_plus_dm / self._smoothed_tr) * 100
            minus_di = (self._smoothed_minus_dm / self._smoothed_tr) * 100
        else:
            plus_di = minus_di = 0.0

        # --- DX and ADX ---
        di_sum = plus_di + minus_di
        dx = abs(plus_di - minus_di) / di_sum * 100 if di_sum > 0 else 0.0
        self.dx_values.append(dx)

        n_adx = self.adx_period
        if self._smoothed_adx is None:
            if len(self.dx_values) >= n_adx:
                self._smoothed_adx = sum(list(self.dx_values)[-n_adx:]) / n_adx
        else:
            self._smoothed_adx = (self._smoothed_adx * (n_adx - 1) + dx) / n_adx
        self.current_adx = self._smoothed_adx

        # --- Bollinger Bandwidth ---
        self._update_bb_width()

        # --- Classify Regime --- (ADX needs adx_period DX values first)
        if self.current_atr is not None and self.current_adx is not None:
            return self._apply_regime(self._classify())
        return None

    def _update_bb_width(self) -> None:
        """Recompute ``current_bb_width`` from the last ``bb_period`` closes."""
        if len(self.candles) < self.bb_period:
            self.current_bb_width = None
            return

        closes = [c.close for c in list(self.candles)[-self.bb_period:]]
        sma = sum(closes) / self.bb_period
        variance = sum((c - sma) ** 2 for c in closes) / self.bb_period  # population
        std = math.sqrt(variance)
        upper = sma + self.bb_std * std
        lower = sma - self.bb_std * std
        # Width relative to a non-positive mean is meaningless; leave it
        # undefined so the classifier uses its mid-range default rather than
        # mistaking it for a squeeze.
        self.current_bb_width = (upper - lower) / sma if sma > 0 else None

    def _atr_ratio(self) -> Optional[float]:
        """
        Current ATR as a multiple of ``reference_atr``.

        Returns None before ATR is available, and a neutral 1.0 when
        ``reference_atr`` is not positive (avoids dividing by zero).
        """
        if self.current_atr is None:
            return None
        if self.reference_atr <= 0:
            return 1.0
        return self.current_atr / self.reference_atr

    def _classify(self) -> MarketRegime:
        """
        Classify market regime based on indicator values.

        Decision tree (first match wins):
        1. ATR ratio >= high_vol_atr_ratio → HIGH_VOL (regardless of trend)
        2. ATR ratio <= low_vol_atr_ratio AND BB width <= squeeze → LOW_VOL
        3. ADX >= trending_adx → TRENDING (HIGH_VOL instead if the ATR ratio
           is at least 80% of the high-vol threshold)
        4. ADX <= ranging_adx → RANGING (LOW_VOL instead if ATR ratio is low)
        5. Otherwise BB width breaks the tie; in the true middle, ADX at or
           above the trending/ranging midpoint → TRENDING, else RANGING.
        """
        atr_ratio = self._atr_ratio()
        if atr_ratio is None:
            atr_ratio = 1.0  # no ATR yet: assume reference-level volatility
        adx = self.current_adx if self.current_adx is not None else 0.0
        # A genuine 0.0 width (perfectly flat closes) is the tightest possible
        # squeeze; only a missing width falls back to the mid-range default.
        bb_width = self.current_bb_width if self.current_bb_width is not None else 0.04

        # Priority 1: Extreme volatility overrides everything
        if atr_ratio >= self.high_vol_atr_ratio:
            return MarketRegime.HIGH_VOL

        # Priority 2: Very low volatility + BB squeeze
        if atr_ratio <= self.low_vol_atr_ratio and bb_width <= self.bb_squeeze_pct:
            return MarketRegime.LOW_VOL

        # Priority 3: Strong trend
        if adx >= self.trending_adx:
            # But if vol is also high, still HIGH_VOL
            if atr_ratio >= self.high_vol_atr_ratio * 0.8:
                return MarketRegime.HIGH_VOL
            return MarketRegime.TRENDING

        # Priority 4: Weak trend = ranging
        if adx <= self.ranging_adx:
            # Low vol ranging vs normal ranging
            if atr_ratio <= self.low_vol_atr_ratio:
                return MarketRegime.LOW_VOL
            return MarketRegime.RANGING

        # Priority 5: Middle ground — use BB width as tiebreaker
        if bb_width >= self.bb_expansion_pct:
            return MarketRegime.HIGH_VOL if atr_ratio > 1.2 else MarketRegime.TRENDING
        if bb_width <= self.bb_squeeze_pct:
            return MarketRegime.LOW_VOL

        # True middle — slight trend bias if ADX is above midpoint
        midpoint = (self.trending_adx + self.ranging_adx) / 2
        return MarketRegime.TRENDING if adx >= midpoint else MarketRegime.RANGING

    def _apply_regime(self, new_regime: MarketRegime) -> Optional[RegimeParams]:
        """
        Apply regime change with confirmation (debouncing).

        Requires regime_confirm_bars consecutive bars of the same classification
        before switching. This prevents whipsawing between regimes. Any bar that
        classifies as the current regime discards the pending candidate.

        Returns:
            RegimeParams for a newly confirmed regime, otherwise None.
        """
        if new_regime == self.current_regime:
            # Same regime — reset candidate
            self.candidate_regime = None
            self.candidate_count = 0
            return None

        if new_regime == self.candidate_regime:
            self.candidate_count += 1
        else:
            self.candidate_regime = new_regime
            self.candidate_count = 1

        if self.candidate_count < self.regime_confirm_bars:
            return None

        # Confirmed regime change
        old_regime = self.current_regime
        self.current_regime = new_regime
        self.candidate_regime = None
        self.candidate_count = 0
        self.total_regime_changes += 1

        params = self._current_params()
        self.regime_history.append({
            "bar": self.bars_processed,
            "from": old_regime.value if old_regime else "NONE",
            "to": new_regime.value,
            "confidence": params.confidence,
            "atr": self.current_atr,
            "adx": self.current_adx,
            "bb_width": self.current_bb_width,
        })
        return params

    @staticmethod
    def _unit_score(numerator: float, span: float) -> float:
        """
        Return ``numerator / span`` clamped to [0, 1].

        A non-positive span (thresholds configured equal or inverted) has no
        usable scale, so it degrades to a step (1.0 if numerator > 0 else
        0.0) instead of raising ZeroDivisionError.
        """
        if span <= 0:
            return 1.0 if numerator > 0 else 0.0
        return max(0.0, min(1.0, numerator / span))

    def _calc_confidence(self) -> float:
        """
        Score (0-1) how strongly the current indicators support ``current_regime``.

        Each regime contributes two component scores, each clamped to [0, 1],
        and the result is their mean. Indicators can drift away from the
        debounced regime, so a contradicting component scores 0 rather than
        going negative. Returns 0.0 before ATR/ADX exist or with no regime.
        """
        atr_ratio = self._atr_ratio()
        if atr_ratio is None or self.current_adx is None:
            return 0.0

        adx = self.current_adx
        bb_width = self.current_bb_width if self.current_bb_width is not None else 0.04
        score = self._unit_score
        regime = self.current_regime

        if regime == MarketRegime.HIGH_VOL:
            scores = [
                # ATR ratio above 1.0, relative to the high-vol threshold
                score(atr_ratio - 1.0, self.high_vol_atr_ratio - 1.0),
                # BB width relative to the expansion threshold
                score(bb_width, self.bb_expansion_pct),
            ]
        elif regime == MarketRegime.LOW_VOL:
            scores = [
                # ATR ratio below 1.0, relative to the low-vol threshold
                score(1.0 - atr_ratio, 1.0 - self.low_vol_atr_ratio),
                # How far inside the squeeze threshold the BB width is
                score(self.bb_squeeze_pct - bb_width, self.bb_squeeze_pct),
            ]
        elif regime == MarketRegime.TRENDING:
            scores = [
                # ADX position between the ranging and trending thresholds
                score(adx - self.ranging_adx, self.trending_adx - self.ranging_adx),
                # Moderate or high vol
                score(atr_ratio, 1.5),
            ]
        elif regime == MarketRegime.RANGING:
            scores = [
                # ADX position below the trending threshold
                score(self.trending_adx - adx, self.trending_adx - self.ranging_adx),
                # Not extreme vol: 1.0 at the reference ATR, 0.0 at +/-100%
                1.0 - min(1.0, abs(atr_ratio - 1.0)),
            ]
        else:
            return 0.0

        return sum(scores) / len(scores)

    def _current_params(self) -> RegimeParams:
        """Build RegimeParams for ``current_regime`` with a fresh confidence."""
        # Copy field-by-field so the shared REGIME_PRESETS entry is never mutated.
        params = RegimeParams(**vars(REGIME_PRESETS[self.current_regime]))
        params.confidence = self._calc_confidence()
        return params

    def get_regime(self) -> Optional[RegimeParams]:
        """Get current regime parameters (None until a regime is confirmed)."""
        if self.current_regime is None:
            return None
        return self._current_params()

    @staticmethod
    def _round_or_none(value: Optional[float], digits: int) -> Optional[float]:
        """Round ``value`` for display, passing None through (0.0 stays 0.0)."""
        return round(value, digits) if value is not None else None

    def get_indicators(self) -> dict:
        """Get current indicator values (rounded) for debugging/display."""
        rnd = self._round_or_none
        return {
            "atr": rnd(self.current_atr, 2),
            "atr_ratio": rnd(self._atr_ratio(), 3),
            "adx": rnd(self.current_adx, 2),
            "bb_width": rnd(self.current_bb_width, 4),
            "regime": self.current_regime.value if self.current_regime is not None else None,
            "bars": self.bars_processed,
            "regime_changes": self.total_regime_changes,
        }

    def get_regime_summary(self) -> str:
        """Human-readable regime summary for display."""
        if self.current_regime is None:
            return "⏳ Warming up... (need more candles for classification)"

        params = self.get_regime()
        indicators = self.get_indicators()
        emoji = {"HIGH_VOL": "🔴", "LOW_VOL": "🟢", "TRENDING": "🔵", "RANGING": "🟡"}
        e = emoji.get(self.current_regime.value, "⚪")

        # Labels use the same inclusive boundaries as _classify.
        adx = indicators["adx"]
        if adx is not None and adx >= self.trending_adx:
            adx_label = "trending"
        elif adx is not None and adx <= self.ranging_adx:
            adx_label = "ranging"
        else:
            adx_label = "neutral"

        bb_width = indicators["bb_width"]
        if bb_width is not None and bb_width <= self.bb_squeeze_pct:
            bb_label = "squeeze"
        elif bb_width is not None and bb_width >= self.bb_expansion_pct:
            bb_label = "expansion"
        else:
            bb_label = "normal"

        lines = [
            f"{e} Market Regime: {self.current_regime.value} ({params.confidence:.0%} confidence)",
            "",
            "Indicators:",
            f"  ATR: {indicators['atr']} (ratio: {indicators['atr_ratio']}x vs ref {self.reference_atr})",
            f"  ADX: {indicators['adx']} ({adx_label})",
            f"  BB Width: {indicators['bb_width']} ({bb_label})",
            "",
            "Parameter Adjustments:",
            f"  Grid spacing: {params.grid_scale:.1f}x",
            f"  Entry TP: {params.entry_tp_scale:.1f}x",
            f"  Recovery TP: {params.recovery_tp_scale:.1f}x",
            f"  Thread target: {params.thread_target_scale:.1f}x",
            f"  Max new threads: {params.max_new_threads}",
            f"  Sub-pair aggressive: {'YES' if params.sub_pair_aggressive else 'no'}",
            "",
            f"Regime changes: {self.total_regime_changes} over {self.bars_processed} bars",
        ]
        return "\n".join(lines)


# =============================================================================
# Standalone test — run on synthetic candles to see regime transitions
# =============================================================================

if __name__ == "__main__":
    import random

    print("🦾 Market Regime Detector — Standalone Test")
    print("=" * 60)

    # Create detector with BTCUSD M5 defaults
    detector = RegimeDetector(reference_atr=132.1)

    # Generate synthetic test candles; the fixed seed makes every run identical.
    random.seed(42)

    price = 95000.0
    # Regimes in first-seen order. A set would print in an order that varies
    # between runs: Enum members hash by name, and str hashing is randomized.
    regimes_seen = []

    # Simulate 500 candles with different market conditions.
    # Each scenario: (name, bar count, volatility range, drift direction)
    scenarios = [
        ("Quiet ranging", 100, 50, 0),       # Low vol, no trend
        ("Breakout trending", 100, 300, 1),   # High vol, strong trend up
        ("Crash", 50, 500, -1),               # Very high vol, down
        ("Recovery trend", 100, 200, 1),      # Moderate vol, up
        ("Consolidation", 150, 80, 0),        # Low vol, ranging
    ]

    bar = 0
    for name, count, vol_range, direction in scenarios:
        print(f"\n--- Simulating: {name} ({count} bars, vol={vol_range}) ---")
        for _ in range(count):
            # Body = drift + Gaussian noise, then random wicks above and below.
            # The order of the random draws is part of the seeded sequence.
            movement = direction * vol_range * 0.3 + random.gauss(0, vol_range)
            o = price
            c = price + movement
            hi = max(o, c) + random.uniform(0, vol_range * 0.5)
            lo = min(o, c) - random.uniform(0, vol_range * 0.5)
            price = c

            candle = OHLC(open=o, high=hi, low=lo, close=c)
            change = detector.update(candle)
            bar += 1

            if change is not None:
                if change.regime not in regimes_seen:
                    regimes_seen.append(change.regime)
                print(f"  Bar {bar}: {change}")

    print(f"\n{'=' * 60}")
    print(detector.get_regime_summary())
    print(f"\nRegimes seen: {[r.value for r in regimes_seen]}")
    # regime_history is capped at 100 entries by the detector.
    print(f"Total regime changes: {len(detector.regime_history)}")
    print("\nRegime history:")
    for entry in detector.regime_history:
        print(f"  Bar {entry['bar']}: {entry['from']} → {entry['to']} "
              f"(conf={entry['confidence']:.0%}, ATR={entry['atr']:.1f}, ADX={entry['adx']:.1f})")
