"""
Risk management and lot sizing calculations.
"""

from typing import Optional

from config import RISK_PERCENT, MAX_LOT_SIZE, MIN_LOT_SIZE, LOT_STEP, AGGRESSION_LEVEL
from normalization import get_pair_config, normalize_lot_size


class RiskManager:
    """
    Manages risk calculations and lot sizing.

    Calculates appropriate position sizes based on account balance,
    risk parameters, and symbol constraints.

    The position helpers accept a list whose items are either objects
    exposing ``lots`` / ``price`` attributes or mappings with ``'lots'`` /
    ``'price'`` keys; a key missing from a mapping is treated as 0.
    """

    def __init__(
        self,
        risk_percent: float = RISK_PERCENT,
        aggression_level: float = AGGRESSION_LEVEL
    ):
        """
        Initialize risk manager.

        Args:
            risk_percent: Percentage of balance to risk per trade
                (e.g. ``1.0`` means 1% of the balance).
            aggression_level: Percentage increase in lot size per averaging
                level; the increase compounds level over level.
        """
        self.risk_percent = risk_percent
        self.aggression_level = aggression_level

    @staticmethod
    def _position_field(position, name: str) -> float:
        """
        Read field ``name`` from a position object or mapping.

        Attribute access is tried first; otherwise the position is treated
        as a mapping and a missing key yields 0.
        """
        if hasattr(position, name):
            return getattr(position, name)
        return position.get(name, 0)

    def _risk_amount(self, balance: float) -> float:
        """Return the amount of ``balance`` corresponding to ``risk_percent``."""
        return balance * (self.risk_percent / 100)

    def calculate_base_lot_size(
        self,
        balance: float,
        price: float,
        symbol: str,
        risk_per_pip: Optional[float] = None
    ) -> float:
        """
        Calculate base lot size based on balance and risk.

        The lot size is ``balance * risk_percent / 100 / price``, normalized
        to the symbol's constraints via ``normalize_lot_size``.

        Args:
            balance: Account balance
            price: Current price (must be non-zero)
            symbol: Trading pair symbol
            risk_per_pip: Currently unused; accepted for interface
                compatibility (intended for stop-loss based sizing).

        Returns:
            Calculated (normalized) lot size

        Raises:
            ZeroDivisionError: If ``price`` is 0.
        """
        # Simplified crypto-spot sizing: allocate risk_percent of the balance
        # and convert it to base-asset units at the current price.
        lot_size = self._risk_amount(balance) / price
        return normalize_lot_size(lot_size, symbol)

    def calculate_lot_size(
        self,
        balance: float,
        symbol: str,
        price: float,
        tp_distance: Optional[float] = None,
        use_tp_for_sizing: bool = False
    ) -> float:
        """
        Calculate lot size for a new position.

        Args:
            balance: Account balance
            symbol: Trading pair symbol
            price: Current price
            tp_distance: Take profit distance (optional). Only a positive
                distance is used for TP-based sizing.
            use_tp_for_sizing: If True and ``tp_distance`` is positive,
                size so that hitting the TP gains ``risk_percent`` of the
                balance; otherwise fall back to standard sizing.

        Returns:
            Calculated (normalized) lot size
        """
        if use_tp_for_sizing and tp_distance is not None and tp_distance > 0:
            # lots * tp_distance == risk_amount, i.e. hitting the TP earns
            # risk_percent of the balance (before normalization).
            lot_size = self._risk_amount(balance) / tp_distance
            return normalize_lot_size(lot_size, symbol)

        # Standard percentage-based sizing; the result is already normalized,
        # so it is not normalized a second time (avoids compounding rounding).
        return self.calculate_base_lot_size(balance, price, symbol)

    def calculate_averaging_lot_size(
        self,
        base_lot: float,
        averaging_level: int,
        symbol: str
    ) -> float:
        """
        Calculate lot size for an averaging order.

        Formula: LOT_SIZE(n) = base_lot * (1 + self.aggression_level/100) ** n

        Args:
            base_lot: Base lot size (from main order)
            averaging_level: Averaging level (0 returns the normalized base
                lot; 1, 2, 3, ... are successive averaging orders)
            symbol: Trading pair symbol

        Returns:
            Averaging order lot size, normalized to symbol constraints
        """
        multiplier = (1 + self.aggression_level / 100) ** averaging_level
        return normalize_lot_size(base_lot * multiplier, symbol)

    def calculate_total_exposure(
        self,
        positions: list,
        current_price: float
    ) -> float:
        """
        Calculate total exposure (position value) across positions.

        Args:
            positions: List of positions with 'lots' attribute or key
            current_price: Current market price

        Returns:
            Total exposure value (sum of lots * current_price)
        """
        total_lots = sum(self._position_field(p, 'lots') for p in positions)
        return total_lots * current_price

    def calculate_unrealized_pnl(
        self,
        positions: list,
        current_price: float
    ) -> float:
        """
        Calculate unrealized P&L across positions.

        Each position contributes ``lots * (current_price - entry_price)``,
        i.e. positions are treated as long.

        Args:
            positions: List of positions with 'lots' and 'price' attributes
                or keys
            current_price: Current market price

        Returns:
            Total unrealized P&L
        """
        total_pnl = 0.0
        for pos in positions:
            lots = self._position_field(pos, 'lots')
            entry = self._position_field(pos, 'price')
            total_pnl += lots * (current_price - entry)
        return total_pnl

    def calculate_equity(
        self,
        balance: float,
        positions: list,
        current_price: float
    ) -> float:
        """
        Calculate current equity (balance + unrealized P&L).

        Args:
            balance: Account balance
            positions: List of open positions
            current_price: Current market price

        Returns:
            Current equity
        """
        return balance + self.calculate_unrealized_pnl(positions, current_price)

    def can_open_position(
        self,
        balance: float,
        required_margin: float,
        current_margin_used: float = 0.0,
        max_margin_percent: float = 100.0
    ) -> bool:
        """
        Check if a new position can be opened based on margin.

        Args:
            balance: Account balance
            required_margin: Margin required for new position
            current_margin_used: Currently used margin
            max_margin_percent: Maximum margin usage percentage

        Returns:
            True if used margin plus the new requirement does not exceed
            ``max_margin_percent`` of the balance
        """
        max_margin = balance * (max_margin_percent / 100)
        return (current_margin_used + required_margin) <= max_margin

    def get_lot_size_sequence(
        self,
        base_lot: float,
        max_levels: int,
        symbol: str
    ) -> list:
        """
        Get the lot size sequence for all averaging levels.

        Args:
            base_lot: Base lot size
            max_levels: Maximum number of averaging levels
            symbol: Trading pair symbol

        Returns:
            List of ``max_levels + 1`` lot sizes; index 0 is the normalized
            base lot and index ``n`` is averaging level ``n``
        """
        return [
            self.calculate_averaging_lot_size(base_lot, level, symbol)
            for level in range(max_levels + 1)
        ]

    def calculate_breakeven_price(
        self,
        positions: list
    ) -> float:
        """
        Calculate the breakeven price for a set of positions.

        Args:
            positions: List of positions with 'lots' and 'price' attributes
                or keys

        Returns:
            Lot-weighted average entry price (breakeven), or 0.0 when the
            total lot size is zero (including an empty list)
        """
        total_value = 0.0
        total_lots = 0.0

        for pos in positions:
            lots = self._position_field(pos, 'lots')
            price = self._position_field(pos, 'price')
            total_value += lots * price
            total_lots += lots

        if total_lots == 0:
            return 0.0

        return total_value / total_lots

    def calculate_dynamic_tp(
        self,
        positions: list,
        base_tp_distance: float
    ) -> float:
        """
        Calculate dynamic TP based on weighted average entry.

        Args:
            positions: List of open positions
            base_tp_distance: Base TP distance added to the breakeven price

        Returns:
            Dynamic TP price (breakeven + base_tp_distance)
        """
        # NOTE(review): with no positions (or zero total lots) the breakeven
        # is 0.0, so this returns base_tp_distance itself as a "price" —
        # confirm callers never invoke this without open positions.
        return self.calculate_breakeven_price(positions) + base_tp_distance
