"""
Kelly Criterion Lot Sizing Module
==================================
Derived from Huygens (1657), Bernoulli (1713), Kelly (1956).

Instead of fixed lot sizes, scale lots by the edge measured for each
trade category. Entries have negative Kelly → shrink lots below the base lot.
Deep recovery has a large positive Kelly → scale lots up.

Usage:
    from kelly_sizing import KellySizer
    sizer = KellySizer(base_lot=0.02)
    entry_lot = sizer.get_entry_lot()
    rec_lot = sizer.get_recovery_lot(calculated_lot=0.08, recovery_level=2)
"""

from dataclasses import dataclass, field
from typing import Dict, Optional
import math


@dataclass
class KellyConfig:
    """Kelly fraction configuration per trade category."""
    # Measured from 2-month baseline backtest (4,656 trades)
    # Kelly f* = (b*p - q) / b where b=avg_win/avg_loss, p=win_rate, q=1-p
    #
    # Each category maps to a dict with these keys:
    #   full_kelly     - f* from the formula above (informational; it is not
    #                    used to derive lot_multiplier)
    #   win_rate       - p, the measured win rate for the category
    #   payoff_ratio   - b, avg_win / avg_loss
    #   lot_multiplier - hand-chosen multiplier on the base/recovery lot; the
    #                    only key KellySizer in this module reads
    #   notes          - human-readable rationale
    #
    # default_factory gives every KellyConfig instance its own table, so
    # mutating one instance's rows does not leak into other instances.
    KELLY_TABLE: Dict[str, dict] = field(default_factory=lambda: {
        'ENTRY': {
            'full_kelly': -0.063,   # Negative edge - entries are a cost
            'win_rate': 0.794,
            'payoff_ratio': 0.24,
            'lot_multiplier': 0.5,  # Use HALF the base lot - minimize the cost
            'notes': 'Entries lose money. Minimize exposure. They exist to seed threads.'
        },
        'SUB_REPEAT': {
            'full_kelly': -0.031,   # Negative edge
            'win_rate': 0.572,
            'payoff_ratio': 0.71,
            'lot_multiplier': 0.75, # Slightly reduced
            'notes': 'Sub repeats slightly negative. Reduce but keep for grinding.'
        },
        'REC_L1': {
            'full_kelly': 0.008,    # Barely positive
            'win_rate': 0.711,
            'payoff_ratio': 0.41,
            'lot_multiplier': 1.0,  # Base lot - edge too thin to scale up
            'notes': 'L1 recovery breaks even. Keep at base.'
        },
        'REC_L2': {
            'full_kelly': 0.492,    # Strong edge!
            'win_rate': 0.654,
            'payoff_ratio': 2.12,
            'lot_multiplier': 1.5,  # 50% bigger lots on L2 recovery
            'notes': 'L2 is where the money is. Scale up.'
        },
        'REC_L3': {
            'full_kelly': 0.800,    # Massive edge (small sample)
            'win_rate': 0.833,
            'payoff_ratio': 4.94,
            'lot_multiplier': 1.8,  # 80% bigger (capped due to small sample - Bernoulli!)
            'notes': 'Deep recovery has huge edge but few samples. Scale up with caution.'
        },
        'REC_DEEP': {
            'full_kelly': 0.800,    # Same statistics as REC_L3 (L3+ pooled)
            'win_rate': 0.833,
            'payoff_ratio': 4.94,
            'lot_multiplier': 2.0,  # Double lots on deepest recovery
            'notes': 'Deepest levels. Highest edge. Maximum aggression.'
        },
    })


class KellySizer:
    """Dynamic lot sizing based on Kelly criterion per trade category."""
    
    def __init__(self, base_lot: float = 0.02, min_lot: float = 0.001, 
                 max_lot: float = 15.0, lot_step: float = 0.001,
                 kelly_enabled: bool = True):
        self.base_lot = base_lot
        self.min_lot = min_lot
        self.max_lot = max_lot
        self.lot_step = lot_step
        self.kelly_enabled = kelly_enabled
        self.config = KellyConfig()
        self.trade_counts = {}  # Track trades per category for Bernoulli confidence
        
    def _categorize_trade(self, comment: str, recovery_level: int = 0) -> str:
        """Map trade comment/level to Kelly category."""
        if comment and 'ENTRY' in comment.upper():
            return 'ENTRY'
        elif comment and 'SUB_RPT' in comment.upper():
            return 'SUB_REPEAT'
        elif recovery_level >= 4:
            return 'REC_DEEP'
        elif recovery_level == 3:
            return 'REC_L3'
        elif recovery_level == 2:
            return 'REC_L2'
        elif recovery_level == 1:
            return 'REC_L1'
        else:
            return 'ENTRY'
    
    def _bernoulli_confidence(self, category: str) -> float:
        """
        Bernoulli confidence factor: reduce sizing when we don't have
        enough observations to trust the Kelly estimate.
        
        Law of Large Numbers: confidence ~ sqrt(n) / n
        We need ~100 trades for reasonable confidence, 400+ for high confidence.
        
        Returns: multiplier between 0.5 and 1.0
        """
        n = self.trade_counts.get(category, 0)
        if n >= 400:
            return 1.0      # High confidence - use full Kelly multiplier
        elif n >= 100:
            return 0.8       # Moderate confidence
        elif n >= 30:
            return 0.65      # Low confidence - reduce
        else:
            return 0.5       # Very low confidence - half Kelly multiplier
    
    def get_lot_multiplier(self, comment: str = '', recovery_level: int = 0) -> float:
        """
        Get Kelly-optimal lot multiplier for this trade type.
        
        Returns a multiplier to apply to the base lot or recovery lot calculation.
        """
        if not self.kelly_enabled:
            return 1.0
            
        category = self._categorize_trade(comment, recovery_level)
        
        # Track observation count (Bernoulli)
        self.trade_counts[category] = self.trade_counts.get(category, 0) + 1
        
        # Get Kelly-derived multiplier
        kelly_data = self.config.KELLY_TABLE.get(category, 
                     self.config.KELLY_TABLE.get('ENTRY'))
        mult = kelly_data['lot_multiplier']
        
        # Apply Bernoulli confidence scaling
        confidence = self._bernoulli_confidence(category)
        
        # Blend: start at 1.0 (no change), move toward Kelly mult as confidence grows
        adjusted_mult = 1.0 + (mult - 1.0) * confidence
        
        return max(0.5, min(adjusted_mult, 2.5))  # Clamp to safe range
    
    def get_entry_lot(self) -> float:
        """Get lot size for initial entry trades (Kelly says: minimize)."""
        mult = self.get_lot_multiplier(comment='ENTRY', recovery_level=0)
        lot = self.base_lot * mult
        return self._normalize_lot(lot)
    
    def get_recovery_lot(self, calculated_lot: float, recovery_level: int) -> float:
        """
        Apply Kelly multiplier to the standard recovery lot calculation.
        
        The recovery lot is already calculated by the engine as:
            lot = (accumulated_loss + profit_target) / tp_distance
            
        Kelly says: scale this up for L2+ where edge is proven.
        """
        mult = self.get_lot_multiplier(comment='RECOVERY', recovery_level=recovery_level)
        lot = calculated_lot * mult
        return self._normalize_lot(lot)
    
    def _normalize_lot(self, lot: float) -> float:
        """Round to lot step and clamp to limits."""
        lot = max(self.min_lot, min(lot, self.max_lot))
        lot = round(lot / self.lot_step) * self.lot_step
        return round(lot, 6)
    
    def get_stats(self) -> dict:
        """Return current Kelly sizing statistics."""
        return {
            'kelly_enabled': self.kelly_enabled,
            'base_lot': self.base_lot,
            'trade_counts': dict(self.trade_counts),
            'confidence_levels': {
                cat: self._bernoulli_confidence(cat) 
                for cat in self.trade_counts
            }
        }


# Quick test: warm up each category, then print one sized lot per category.
if __name__ == '__main__':
    demo = KellySizer(base_lot=0.02)

    print("Kelly Lot Sizing Test")
    print("=" * 50)

    # Warm-up trade sequence: every call below counts as one observation for
    # its category, which raises that category's Bernoulli confidence.
    for _ in range(50):
        demo.get_entry_lot()

    warmup_schedule = (
        # (engine-calculated lot, recovery level, repetitions)
        (0.05, 1, 20),
        (0.08, 2, 10),
        (0.12, 3, 3),
    )
    for calc_lot, level, reps in warmup_schedule:
        for _ in range(reps):
            demo.get_recovery_lot(calc_lot, recovery_level=level)

    # Each sizing call printed here also counts as one more observation.
    print(f"\nEntry lot (base 0.02): {demo.get_entry_lot()}")
    print(f"L1 recovery (calc 0.05): {demo.get_recovery_lot(0.05, 1)}")
    print(f"L2 recovery (calc 0.08): {demo.get_recovery_lot(0.08, 2)}")
    print(f"L3 recovery (calc 0.12): {demo.get_recovery_lot(0.12, 3)}")
    print(f"L5 deep recovery (calc 0.20): {demo.get_recovery_lot(0.20, 5)}")
    print(f"\nStats: {demo.get_stats()}")
