"""
Fibonacci averaging engine.

Implements Fibonacci-based averaging with unlimited levels using
golden ratio extension, and TP cascade logic for position management.
"""

from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional, Tuple
from functools import lru_cache

from config import (
    AGGRESSION_LEVEL,
    FIRST_DISTANCE_MULTIPLIER,
    TP_TOLERANCE_PIPS,
    REENTRY_ENABLED
)
from normalization import normalize_price, pips_to_price, price_to_pips
from thread_manager import ThreadData, ThreadManager, OrderInfo
from order_engine import OrderEngine, Position
from risk_manager import RiskManager
from utils import get_logger, log_trade


# Golden ratio φ = (1 + √5) / 2 ≈ 1.6180339887, the limiting ratio of
# consecutive Fibonacci numbers (F(n+1) / F(n) → φ).
GOLDEN_RATIO = 1.618033988749895


@lru_cache(maxsize=100)
def fibonacci(n: int) -> int:
    """
    Calculate the Fibonacci number at position n.

    Sequence convention: F(0) = 0, F(1) = F(2) = 1, F(3) = 2, ...
    Non-positive positions return 0. Results are memoized via
    ``lru_cache(maxsize=100)``.

    Exact integer iteration is used for every n. Python ints are
    arbitrary-precision, so the result is exact for any level. A
    floating-point Binet approximation (phi**n / sqrt(5)) stops being exact
    beyond roughly n = 70 and raises OverflowError around n = 1475.

    Args:
        n: Fibonacci sequence position (0-indexed)

    Returns:
        Fibonacci number F(n) as an int
    """
    if n <= 0:
        return 0
    # Walk the pair (F(k-1), F(k)) forward from k = 1 up to k = n.
    prev, curr = 0, 1
    for _ in range(n - 1):
        prev, curr = curr, prev + curr
    return curr


def get_fibonacci_sequence(length: int) -> List[int]:
    """
    Build the first ``length`` Fibonacci terms, starting at position 1.

    Args:
        length: How many terms to produce; zero or a negative value
            yields an empty list

    Returns:
        ``[fibonacci(1), fibonacci(2), ..., fibonacci(length)]``
    """
    positions = range(1, length + 1)
    return list(map(fibonacci, positions))


@dataclass
class AveragingLevel:
    """
    Information about a single averaging level.

    Instances are built by AveragingEngine.calculate_all_averaging_levels
    (a precomputed ladder) and AveragingEngine.check_averaging_needed (the
    level that just triggered). Both pass prices through normalize_price.
    """
    level: int  # Averaging level number (1 = first averaging order)
    distance_multiplier: int  # Fibonacci multiplier, fibonacci(level)
    price_distance: float  # Distance from the main entry price, in price units
    trigger_price: float  # Price at which this level triggers (normalized)
    lot_size: float  # Lot size for this level, as returned by the risk manager
    tp_price: float  # Take-profit price for this level (normalized)


class AveragingEngine:
    """
    Manages Fibonacci-based averaging logic.

    Calculates averaging distances, lot sizes and trigger/TP prices. It also
    places averaging and re-entry orders through an ``OrderEngine`` and
    records them on the owning thread through ``ThreadManager``.

    Direction convention for every ``is_buy`` parameter:
    - ``True`` means a buy thread: averaging below the main entry, TP above
      the fill price.
    - ``False`` means a sell thread: averaging above the main entry, TP below.
    """

    def __init__(
        self,
        thread_manager: ThreadManager,
        risk_manager: RiskManager,
        symbol: str,
        aggression_level: float = AGGRESSION_LEVEL,
        first_distance_mult: float = FIRST_DISTANCE_MULTIPLIER
    ):
        """
        Initialize averaging engine.

        Args:
            thread_manager: ThreadManager used to record averaging orders
                and pending re-entries on threads
            risk_manager: RiskManager that computes averaging lot sizes
            symbol: Trading pair symbol. Used for price normalization, pip
                conversion, lot sizing and trade logging.
            aggression_level: Lot size increase percentage per level. It is
                stored on the instance only. Lot sizing is delegated to
                ``risk_manager``, which this class does not pass the value to.
            first_distance_mult: Multiplier applied to every averaging
                distance (see ``calculate_averaging_distance``)
        """
        self.thread_manager = thread_manager
        self.risk_manager = risk_manager
        self.symbol = symbol
        self.aggression_level = aggression_level
        self.first_distance_mult = first_distance_mult
        self.logger = get_logger()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _offset_adverse(price: float, distance: float, is_buy: bool) -> float:
        """Move ``price`` by ``distance`` against the position (down for buys)."""
        return price - distance if is_buy else price + distance

    @staticmethod
    def _offset_profit(price: float, distance: float, is_buy: bool) -> float:
        """Move ``price`` by ``distance`` in the profit direction (up for buys)."""
        return price + distance if is_buy else price - distance

    @staticmethod
    def _place_order(
        order_engine: OrderEngine,
        is_buy: bool,
        lots: float,
        tp_price: float,
        magic,
        comment: str
    ):
        """Place a buy or sell order via ``order_engine``; return its result."""
        place = (
            order_engine.place_buy_order if is_buy
            else order_engine.place_sell_order
        )
        return place(lots=lots, tp_price=tp_price, magic=magic, comment=comment)

    # ------------------------------------------------------------------
    # Distance / price / lot calculations
    # ------------------------------------------------------------------

    def calculate_averaging_distance(
        self,
        level: int,
        base_distance: float
    ) -> float:
        """
        Calculate the distance for an averaging level.

        Formula:
            AVG_DISTANCE(n) = base_distance * fibonacci(n) * first_distance_mult

        The result is measured from the main entry price, not as a step from
        the previous level (see ``get_averaging_trigger_price``).

        NOTE(review): ``first_distance_mult`` scales every level, not only
        level 1. Also, fibonacci(1) == fibonacci(2) == 1, so levels 1 and 2
        resolve to the same distance and therefore the same trigger price.
        Confirm both behaviors are intended.

        Args:
            level: Averaging level (1, 2, 3, ...)
            base_distance: Base entry distance

        Returns:
            Distance in price units
        """
        return base_distance * fibonacci(level) * self.first_distance_mult

    def calculate_averaging_lot(
        self,
        base_lot: float,
        level: int
    ) -> float:
        """
        Calculate lot size for an averaging level.

        Delegates to ``RiskManager.calculate_averaging_lot_size``. The
        intended formula is presumably
        ``base_lot * (1 + aggression/100) ** level``, matching
        ``calculate_worst_case_lots``. TODO: confirm in RiskManager.

        Args:
            base_lot: Base lot size from main order
            level: Averaging level

        Returns:
            Lot size for this level
        """
        return self.risk_manager.calculate_averaging_lot_size(
            base_lot, level, self.symbol
        )

    def get_averaging_trigger_price(
        self,
        main_entry_price: float,
        level: int,
        base_distance: float,
        is_buy: bool = True
    ) -> float:
        """
        Calculate the (un-normalized) price that triggers an averaging order.

        Buy averaging triggers below the main entry; sell averaging
        triggers above it.

        Args:
            main_entry_price: Main order entry price
            level: Averaging level
            base_distance: Base entry distance
            is_buy: True for buy position

        Returns:
            Trigger price for this averaging level
        """
        distance = self.calculate_averaging_distance(level, base_distance)
        return self._offset_adverse(main_entry_price, distance, is_buy)

    def get_cumulative_distance(
        self,
        level: int,
        base_distance: float
    ) -> float:
        """
        Get cumulative distance from main entry to averaging level.

        Averaging distances are already measured from the main entry, so
        this is identical to ``calculate_averaging_distance``.

        Args:
            level: Averaging level
            base_distance: Base entry distance

        Returns:
            Total distance from main entry
        """
        return self.calculate_averaging_distance(level, base_distance)

    def calculate_all_averaging_levels(
        self,
        main_entry_price: float,
        base_lot: float,
        base_distance: float,
        tp_distance: float,
        max_levels: int = 20,
        is_buy: bool = True
    ) -> List[AveragingLevel]:
        """
        Calculate all averaging levels from 1 up to max_levels.

        Args:
            main_entry_price: Main order entry price
            base_lot: Base lot size
            base_distance: Base entry distance
            tp_distance: Take profit distance, applied from each level's
                trigger price
            max_levels: Maximum number of levels to calculate
            is_buy: True for buy position

        Returns:
            List of AveragingLevel objects. Trigger and TP prices are
            normalized for ``self.symbol``.
        """
        levels: List[AveragingLevel] = []

        for level in range(1, max_levels + 1):
            price_distance = self.calculate_averaging_distance(level, base_distance)
            trigger_price = self._offset_adverse(main_entry_price, price_distance, is_buy)
            # The TP is anchored on the raw trigger price, before normalization.
            tp_price = self._offset_profit(trigger_price, tp_distance, is_buy)

            levels.append(AveragingLevel(
                level=level,
                distance_multiplier=fibonacci(level),
                price_distance=price_distance,
                trigger_price=normalize_price(trigger_price, self.symbol),
                lot_size=self.calculate_averaging_lot(base_lot, level),
                tp_price=normalize_price(tp_price, self.symbol)
            ))

        return levels

    # ------------------------------------------------------------------
    # Signal checks
    # ------------------------------------------------------------------

    def check_averaging_needed(
        self,
        thread: ThreadData,
        current_price: float,
        is_buy: bool = True
    ) -> Optional[AveragingLevel]:
        """
        Check if the next averaging order should be triggered.

        Only the level directly after the thread's highest averaging level
        is evaluated. At most one level is returned per call.

        Args:
            thread: Current thread data
            current_price: Current market price
            is_buy: True for buy position

        Returns:
            AveragingLevel if averaging should trigger, None otherwise.
            Its TP is measured from ``current_price``, not from the trigger.
        """
        next_level = thread.get_highest_averaging_level() + 1

        price_distance = self.calculate_averaging_distance(
            next_level, thread.entry_distance
        )
        trigger_price = self._offset_adverse(
            thread.main_order_price, price_distance, is_buy
        )

        # Price must have reached, or moved through, the trigger.
        if is_buy:
            triggered = current_price <= trigger_price
        else:
            triggered = current_price >= trigger_price
        if not triggered:
            return None

        lot_size = self.calculate_averaging_lot(thread.main_order_lots, next_level)
        tp_price = self._offset_profit(current_price, thread.tp_distance, is_buy)

        return AveragingLevel(
            level=next_level,
            distance_multiplier=fibonacci(next_level),
            price_distance=price_distance,
            trigger_price=normalize_price(trigger_price, self.symbol),
            lot_size=lot_size,
            tp_price=normalize_price(tp_price, self.symbol)
        )

    def check_tp_hit(
        self,
        order: OrderInfo,
        current_price: float,
        is_buy: bool = True
    ) -> bool:
        """
        Check if take profit has been hit for an order.

        Args:
            order: Order to check
            current_price: Current market price
            is_buy: True for buy position

        Returns:
            True if TP is hit; False when the order has no TP set
        """
        if order.tp_price is None:
            return False
        if is_buy:
            return current_price >= order.tp_price
        return current_price <= order.tp_price

    def check_price_returned_to_entry(
        self,
        entry_price: float,
        current_price: float,
        tolerance_pips: float = TP_TOLERANCE_PIPS
    ) -> bool:
        """
        Check if price has returned to entry level (for re-entry).

        The check is symmetric: price may be on either side of entry.

        Args:
            entry_price: Original entry price
            current_price: Current market price
            tolerance_pips: Tolerance in pips

        Returns:
            True if price is within tolerance of entry
        """
        tolerance = pips_to_price(tolerance_pips, self.symbol)
        return abs(current_price - entry_price) <= tolerance

    # ------------------------------------------------------------------
    # TP cascade / re-entry
    # ------------------------------------------------------------------

    def handle_tp_cascade(
        self,
        thread: ThreadData,
        closed_level: int,
        current_price: float,
        current_time: datetime,
        order_engine: OrderEngine,
        is_buy: bool = True
    ) -> Tuple[bool, Optional[int]]:
        """
        Handle TP cascade logic after an averaging TP is hit.

        CloseRecentAvg behavior:
        1. On avg TP hit → mark order closed
        2. Set pending re-entry at original entry price
        3. If price returns → reopen same averaging order
        4. If that order also TPs → cascade to (n-1)
        5. Continue until thread resolved

        This method performs step 2 only. Step 3 is
        ``process_pending_reentry``. Steps 4-5 presumably rely on the
        caller invoking this method again. TODO: confirm against the
        caller.

        ``current_price``, ``current_time``, ``order_engine`` and ``is_buy``
        are accepted for interface compatibility but are not used here.

        Args:
            thread: Thread data
            closed_level: Level that just hit TP (0 = main order)
            current_price: Current market price (unused)
            current_time: Current timestamp (unused)
            order_engine: Order engine for execution (unused)
            is_buy: True for buy position (unused)

        Returns:
            Tuple of (cascade_complete, next_pending_level)
        """
        if closed_level == 0:
            # Main order TP hit - thread is complete.
            return (True, None)

        if not REENTRY_ENABLED:
            # No re-entry, cascade is complete.
            return (True, None)

        # The first closed averaging order recorded at this level supplies
        # the re-entry price (that level's original entry price).
        closed_order = next(
            (
                order for order in thread.get_averaging_orders()
                if order.averaging_level == closed_level and not order.is_open
            ),
            None,
        )
        if not closed_order:
            return (True, None)

        self.thread_manager.set_pending_reentry(
            thread.magic_number,
            closed_order.price,
            closed_level
        )
        return (False, closed_level)

    def process_pending_reentry(
        self,
        thread: ThreadData,
        current_price: float,
        current_time: datetime,
        order_engine: OrderEngine,
        is_buy: bool = True
    ) -> Optional[OrderInfo]:
        """
        Process pending re-entry if price returns to entry.

        On success the order is recorded on the thread, the pending
        re-entry is cleared and a "REENTRY" trade is logged. If the order
        engine returns a falsy result, nothing is recorded and the pending
        re-entry stays set.

        Args:
            thread: Thread data
            current_price: Current market price (recorded as fill price)
            current_time: Current timestamp
            order_engine: Order engine for execution
            is_buy: True for buy position

        Returns:
            New OrderInfo if re-entry executed, None otherwise
        """
        if not thread.pending_reentry:
            return None

        if not self.check_price_returned_to_entry(
            thread.reentry_price, current_price
        ):
            return None

        level = thread.reentry_level
        lot_size = self.calculate_averaging_lot(thread.main_order_lots, level)

        # Normalize the TP like check_averaging_needed and
        # calculate_all_averaging_levels do, so re-entry orders carry the
        # same kind of price as every other TP this engine produces.
        tp_price = normalize_price(
            self._offset_profit(current_price, thread.tp_distance, is_buy),
            self.symbol
        )

        order = self._place_order(
            order_engine,
            is_buy,
            lots=lot_size,
            tp_price=tp_price,
            magic=thread.magic_number,
            comment=f"AVG_REENTRY_L{level}"
        )
        if not order:
            return None

        order_info = self.thread_manager.add_averaging_order(
            thread.magic_number,
            current_price,
            lot_size,
            tp_price,
            current_time,
            level,
            order.order_id
        )

        self.thread_manager.clear_pending_reentry(thread.magic_number)

        log_trade(
            "REENTRY",
            self.symbol,
            current_price,
            lot_size,
            thread.magic_number,
            tp=tp_price
        )

        return order_info

    def execute_averaging(
        self,
        thread: ThreadData,
        averaging_level: AveragingLevel,
        current_price: float,
        current_time: datetime,
        order_engine: OrderEngine,
        is_buy: bool = True
    ) -> Optional[OrderInfo]:
        """
        Execute an averaging order.

        On success the order is recorded on the thread and an "AVERAGING"
        trade is logged.

        Args:
            thread: Thread data
            averaging_level: Level to execute; its lot_size and tp_price
                are used as-is
            current_price: Current execution price (recorded as fill price)
            current_time: Current timestamp
            order_engine: Order engine for execution
            is_buy: True for buy position

        Returns:
            OrderInfo if successful, None otherwise
        """
        order = self._place_order(
            order_engine,
            is_buy,
            lots=averaging_level.lot_size,
            tp_price=averaging_level.tp_price,
            magic=thread.magic_number,
            comment=f"AVG_L{averaging_level.level}"
        )
        if not order:
            return None

        order_info = self.thread_manager.add_averaging_order(
            thread.magic_number,
            current_price,
            averaging_level.lot_size,
            averaging_level.tp_price,
            current_time,
            averaging_level.level,
            order.order_id
        )

        log_trade(
            "AVERAGING",
            self.symbol,
            current_price,
            averaging_level.lot_size,
            thread.magic_number,
            tp=averaging_level.tp_price
        )

        return order_info

    def get_thread_summary(self, thread: ThreadData) -> dict:
        """
        Get summary of a thread's averaging state.

        Args:
            thread: Thread data

        Returns:
            Dictionary with thread summary. ``reentry_level`` is None
            unless a re-entry is pending.
        """
        open_orders = thread.get_open_orders()
        open_avg_orders = thread.get_open_averaging_orders()

        return {
            "magic": thread.magic_number,
            "main_entry": thread.main_order_price,
            "total_lots": thread.get_total_lots(),
            "weighted_avg_price": thread.get_weighted_average_price(),
            "open_orders": len(open_orders),
            "averaging_levels": len(open_avg_orders),
            "highest_level": thread.get_highest_averaging_level(),
            "max_level_reached": thread.max_averaging_level,
            "total_profit": thread.total_profit,
            "pending_reentry": thread.pending_reentry,
            "reentry_level": thread.reentry_level if thread.pending_reentry else None
        }


def calculate_worst_case_lots(
    base_lot: float,
    max_levels: int,
    aggression_level: float = AGGRESSION_LEVEL
) -> float:
    """
    Total lot exposure if the main order and every averaging level are open.

    Level n carries ``base_lot * (1 + aggression_level / 100) ** n`` lots;
    the main order carries ``base_lot``.

    Args:
        base_lot: Lot size of the main order
        max_levels: Number of averaging levels assumed filled
        aggression_level: Per-level lot growth, in percent

    Returns:
        Sum of main-order lots and all averaging-level lots
    """
    growth = 1 + aggression_level / 100
    exposure = base_lot  # main order
    level = 1
    # Accumulate level by level (left to right) rather than via sum(), so the
    # float result matches a plain running total exactly.
    while level <= max_levels:
        exposure += base_lot * growth ** level
        level += 1
    return exposure


def calculate_worst_case_drawdown(
    base_lot: float,
    base_distance: float,
    max_levels: int,
    aggression_level: float = AGGRESSION_LEVEL
) -> float:
    """
    Calculate maximum theoretical drawdown at worst case.

    Model used here:
    - Averaging steps are ``base_distance * fibonacci(level)`` apart and
      accumulate. Level k therefore sits at the prefix sum of those steps
      from the main entry.
    - The deepest level sits ``max_distance`` (the full sum) away.
    - Every position is valued at that deepest price: the main order
      (``base_lot``) and each averaging level
      (``base_lot * (1 + aggression_level / 100) ** level`` lots).

    NOTE(review): this cumulative spacing differs from
    AveragingEngine.get_averaging_trigger_price. That method places level n
    at ``base_distance * fibonacci(n) * first_distance_mult`` from the main
    entry: non-cumulative, and with an extra multiplier. Confirm which model
    is intended before using this figure for risk limits.

    Args:
        base_lot: Base lot size
        base_distance: Base entry distance
        max_levels: Maximum averaging levels
        aggression_level: Aggression percentage

    Returns:
        Sum over all positions of ``lots * adverse price distance``. This is
        in lot x price units; presumably it must be scaled by contract size
        or pip value to become currency. TODO: confirm.
    """
    lot_growth = 1 + aggression_level / 100

    # Offset of each averaging level from the main entry, built as running
    # prefix sums in one pass (no per-level re-summation of the prefix).
    level_offsets: List[float] = []
    running_offset = 0.0
    for level in range(1, max_levels + 1):
        running_offset += base_distance * fibonacci(level)
        level_offsets.append(running_offset)
    max_distance = running_offset

    # The main order is underwater by the full distance to the deepest level.
    total_drawdown = 0.0
    total_drawdown += base_lot * max_distance

    # Each averaging order is underwater by the distance from its own level
    # down (or up, for sells) to the deepest level.
    for level, offset in enumerate(level_offsets, start=1):
        level_lots = base_lot * lot_growth ** level
        total_drawdown += level_lots * (max_distance - offset)

    return total_drawdown
