"""
PolyEdge — Risk & Position Sizing

Kelly Criterion-based sizing with bankroll management.
"""

import math
from typing import Dict, Any, Optional
from config import (
    KELLY_FRACTION, BASE_BET_PCT, MAX_BET_PCT,
    MIN_BET_USD, MAX_BET_USD,
    MAX_DAILY_LOSS_PCT, MAX_CONSECUTIVE_LOSSES,
    COOLDOWN_BARS, MAX_OPEN_POSITIONS, BANKROLL_FLOOR_PCT,
    TAKER_FEE_BPS, MAKER_FEE_BPS, USE_LIMIT_ORDERS,
)


def calc_taker_fee(
    price: float,
    size_usd: float,
    *,
    fee_bps: Optional[float] = None,
) -> float:
    """
    Calculate Polymarket taker fee.

    Fee = baseRate * min(price, 1 - price) * size

    Args:
        price: Fill price of the outcome share, expected in [0, 1].
        size_usd: Order size the fee is scaled by.
        fee_bps: Fee rate in basis points. Defaults to TAKER_FEE_BPS;
            override it to apply a different rate with the same formula.

    Returns:
        The fee, in the same units as ``size_usd``. It is never negative:
        the price factor is floored at 0 for out-of-range prices.
    """
    rate_bps = TAKER_FEE_BPS if fee_bps is None else fee_bps
    base_rate = rate_bps / 10000
    # min(price, 1 - price) turns negative for price > 1 (or < 0). Clamp it
    # so a bad price can never produce a negative fee (a phantom rebate).
    price_factor = max(0.0, min(price, 1 - price))
    return base_rate * price_factor * size_usd


def kelly_bet_size(
    edge: float,
    win_price: float,
    bankroll: float,
) -> float:
    """
    Calculate optimal bet size using the Kelly Criterion.

    For a binary bet:
    - Win probability: p = win_price + edge (our estimate of the true prob),
      clamped to [0.01, 0.99]
    - Win payout odds: b = (1 - win_price) / win_price
    - Kelly: f* = (p * b - q) / b
      where b = payout odds, p = win prob, q = 1 - p

    We use fractional Kelly (KELLY_FRACTION) for safety. The USD amount is
    then clamped to [MIN_BET_USD, MAX_BET_USD] and finally capped at
    MAX_BET_PCT of the bankroll.

    Args:
        edge: Our edge (fair_value - market_price for buy_yes)
        win_price: Price we buy at (market YES price for buy_yes). It must
            be strictly between 0 and 1 for a bet to be placed.
        bankroll: Current bankroll

    Returns:
        Bet size in USD, rounded to cents. Returns 0.0 when there is no
        positive Kelly edge, when win_price is outside (0, 1), or when the
        bankroll is not positive. Because the percentage cap is applied
        last, the result can be below MIN_BET_USD when
        bankroll * MAX_BET_PCT < MIN_BET_USD.
    """
    # Nothing to size against an empty or negative bankroll. Without this
    # guard, a negative bankroll produced a negative bet: the MIN_BET_USD
    # floor was applied first, then the percentage cap went below zero.
    if bankroll <= 0:
        return 0.0

    # A price of exactly 0 would divide by zero below. Prices at or beyond
    # the [0, 1] bounds give non-positive odds, which means no bet.
    if not 0 < win_price < 1:
        return 0.0

    # True win probability (our estimate), kept away from 0 and 1
    p = max(0.01, min(0.99, win_price + edge))
    q = 1 - p

    # Payout odds (binary: win = $1/share, cost = win_price/share)
    b = (1 - win_price) / win_price  # e.g., buy at 0.55 → odds = 0.818:1

    # Kelly formula
    kelly_f = (p * b - q) / b
    if kelly_f <= 0:
        return 0.0  # No edge → no bet

    # Fractional Kelly, converted to USD
    bet = kelly_f * KELLY_FRACTION * bankroll

    # Apply limits: absolute USD band first, then the bankroll-percentage cap
    bet = max(MIN_BET_USD, min(bet, MAX_BET_USD))
    bet = min(bet, bankroll * MAX_BET_PCT)

    return round(bet, 2)


class RiskManager:
    """
    Manages bankroll, tracks P&L, enforces risk limits.

    Lifecycle hooks (usage inferred from the method set; confirm with the
    caller):
    - ``tick()`` once per bar
    - ``size_trade()`` to size a new signal
    - ``record_trade()`` when a trade completes
    - ``new_day()`` at day rollover
    """

    def __init__(self, initial_bankroll: float):
        self.initial_bankroll = initial_bankroll
        self.bankroll = initial_bankroll
        self.peak_bankroll = initial_bankroll  # high-water mark for drawdown
        self.daily_start = initial_bankroll    # bankroll at start of the day

        # Tracking
        self.total_trades = 0
        self.wins = 0
        self.losses = 0
        self.consecutive_losses = 0
        self.cooldown_remaining = 0            # bars left before trading resumes
        # NOTE(review): never changed inside this class; presumably the
        # caller increments/decrements it as positions open and close.
        self.open_positions = 0
        self.total_pnl = 0                     # net of fees
        self.total_fees = 0
        self.max_drawdown_pct = 0              # percent (0-100) below peak

        # Daily tracking
        self.daily_pnl = 0                     # net of fees, since new_day()
        self.daily_trades = 0

        # History
        self.equity_curve: list = [initial_bankroll]  # bankroll after each trade
        self.trade_log: list = []                     # raw result dicts

    def can_trade(self) -> tuple:
        """
        Check if we're allowed to trade.

        Side effect: when the consecutive-loss limit is reached, this arms a
        cooldown of COOLDOWN_BARS bars and resets the loss streak. Trading
        then resumes once ``tick()`` has counted the cooldown down to zero.

        Returns (allowed: bool, reason: str)
        """
        # Cooldown active
        if self.cooldown_remaining > 0:
            return False, f"Cooldown: {self.cooldown_remaining} bars remaining"

        # Max consecutive losses
        if self.consecutive_losses >= MAX_CONSECUTIVE_LOSSES:
            self.cooldown_remaining = COOLDOWN_BARS
            # Reset the streak. Otherwise this branch re-arms the cooldown
            # every time it expires, and since no trade (and so no win that
            # would clear the streak) can happen, trading halts forever.
            self.consecutive_losses = 0
            return False, f"Hit {MAX_CONSECUTIVE_LOSSES} consecutive losses, cooling down"

        # Daily loss limit (fraction of the day's starting bankroll; the
        # denominator is floored at 1 to avoid dividing by ~0)
        daily_loss_pct = -self.daily_pnl / max(self.daily_start, 1)
        if daily_loss_pct >= MAX_DAILY_LOSS_PCT:
            return False, f"Daily loss limit hit ({daily_loss_pct:.1%})"

        # Bankroll floor, relative to the starting bankroll
        if self.bankroll < self.initial_bankroll * BANKROLL_FLOOR_PCT:
            return False, f"Bankroll below {BANKROLL_FLOOR_PCT:.0%} floor"

        # Max open positions
        if self.open_positions >= MAX_OPEN_POSITIONS:
            return False, f"Max open positions ({MAX_OPEN_POSITIONS})"

        # Bankroll too low for min bet
        if self.bankroll < MIN_BET_USD:
            return False, "Bankroll depleted"

        return True, "OK"

    def size_trade(self, signal: Dict[str, Any]) -> float:
        """
        Size a trade based on signal edge and Kelly criterion.

        Args:
            signal: Signal dict from SignalEngine. Reads the keys "edge",
                "direction" ("buy_yes", or anything else treated as buy_no)
                and "market_price" (YES price), plus the optional key
                "strength".

        Returns:
            Bet size in USD (0 if shouldn't trade). Strong signals get a
            1.5x bonus that still respects both MAX_BET_USD and
            MAX_BET_PCT of the bankroll.
        """
        allowed, _reason = self.can_trade()
        if not allowed:
            return 0.0

        edge = signal["edge"]

        if signal["direction"] == "buy_yes":
            buy_price = signal["market_price"]
        else:  # buy_no
            buy_price = 1 - signal["market_price"]  # NO price

        bet = kelly_bet_size(edge, buy_price, self.bankroll)

        # Strong signal bonus (up to 1.5x). Re-apply BOTH caps; previously
        # only the percentage cap was applied, so the bonus could push the
        # bet past the absolute MAX_BET_USD limit.
        if bet > 0 and signal.get("strength") == "strong":
            bet = min(bet * 1.5, MAX_BET_USD, self.bankroll * MAX_BET_PCT)
            bet = round(bet, 2)

        return bet

    def record_trade(self, result: Dict[str, Any]):
        """
        Record a completed trade.

        Updates bankroll, cumulative/daily P&L (net of fees), win/loss
        counters, peak bankroll, max drawdown and the equity curve, and
        appends ``result`` to the trade log.

        Args:
            result: {
                'direction': str,
                'bet_size': float,
                'buy_price': float,
                'won': bool,
                'pnl': float,
                'fee': float,
            }
            Only 'pnl', 'fee' and 'won' are read here.
        """
        pnl = result['pnl']
        fee = result['fee']
        net_pnl = pnl - fee

        self.bankroll += net_pnl
        self.total_pnl += net_pnl
        self.total_fees += fee
        self.daily_pnl += net_pnl
        self.total_trades += 1
        self.daily_trades += 1

        if result['won']:
            self.wins += 1
            self.consecutive_losses = 0
        else:
            self.losses += 1
            self.consecutive_losses += 1

        # Update peak / drawdown. Skip the ratio when the peak is not
        # positive (e.g. a zero starting bankroll) to avoid dividing by zero.
        self.peak_bankroll = max(self.peak_bankroll, self.bankroll)
        if self.peak_bankroll > 0:
            dd_pct = (self.peak_bankroll - self.bankroll) / self.peak_bankroll * 100
            self.max_drawdown_pct = max(self.max_drawdown_pct, dd_pct)

        self.equity_curve.append(self.bankroll)
        self.trade_log.append(result)

    def tick(self):
        """Called each bar to decrement cooldown."""
        if self.cooldown_remaining > 0:
            self.cooldown_remaining -= 1

    def new_day(self):
        """Reset daily counters and anchor the day's start to the current bankroll."""
        self.daily_start = self.bankroll
        self.daily_pnl = 0
        self.daily_trades = 0

    def get_stats(self) -> Dict[str, Any]:
        """
        Get current performance stats.

        - ``win_rate``, ``return_pct`` and ``max_drawdown_pct`` are percentages.
        - ``avg_win``, ``avg_loss`` and ``profit_factor`` use P&L net of fees.
        - ``profit_factor`` is inf when no losing trade has been recorded.
        - ``return_pct`` is 0.0 when the initial bankroll is zero.
        """
        wr = self.wins / max(self.total_trades, 1) * 100

        wins = [t['pnl'] - t['fee'] for t in self.trade_log if t['won']]
        losses = [t['pnl'] - t['fee'] for t in self.trade_log if not t['won']]

        avg_win = sum(wins) / len(wins) if wins else 0
        avg_loss = sum(losses) / len(losses) if losses else 0

        # Floor the denominator so near-zero total losses can't blow up the ratio.
        pf = sum(wins) / max(abs(sum(losses)), 0.01) if losses else float('inf')

        # Guard against a zero starting bankroll (would divide by zero).
        if self.initial_bankroll:
            return_pct = (self.bankroll / self.initial_bankroll - 1) * 100
        else:
            return_pct = 0.0

        return {
            'bankroll': self.bankroll,
            'total_pnl': self.total_pnl,
            'return_pct': return_pct,
            'total_trades': self.total_trades,
            'wins': self.wins,
            'losses': self.losses,
            'win_rate': wr,
            'avg_win': avg_win,
            'avg_loss': avg_loss,
            'profit_factor': pf,
            'max_drawdown_pct': self.max_drawdown_pct,
            'total_fees': self.total_fees,
            'consecutive_losses': self.consecutive_losses,
        }
