"""
Thread tracking and management for trading positions.

A "thread" represents a trading sequence that may include a main order
and multiple averaging orders that form a position management unit.
"""

from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional
from enum import Enum

from config import BASE_MAGIC
from utils import get_logger, log_thread_event


class ThreadStatus(Enum):
    """Thread lifecycle status."""
    # Default status of a newly created ThreadData.
    ACTIVE = "active"
    # Set by ThreadManager.set_pending_reentry while waiting to reopen a level.
    PENDING_REENTRY = "pending_reentry"
    # Set by ThreadManager.close_thread; such threads are removed by cleanup_closed_threads.
    CLOSED = "closed"


@dataclass
class OrderInfo:
    """
    Information about an individual order within a thread.

    Instances start open (``is_open=True``); the close fields are filled in
    when the order is closed.
    """
    # External order ID, or a synthetic ID derived from the thread's magic number.
    order_id: int
    # Entry price of the order.
    price: float
    # Order size in lots.
    lots: float
    # Take-profit price for this order.
    tp_price: float
    # Time the order was opened.
    open_time: datetime
    # False for the thread's main order, True for averaging orders.
    is_averaging: bool = False
    # 0 for the main order; 1, 2, 3, ... for averaging orders.
    averaging_level: int = 0
    # Cleared when the order is closed.
    is_open: bool = True
    # Price at which the order was closed (None while open).
    close_price: Optional[float] = None
    # Time the order was closed (None while open).
    close_time: Optional[datetime] = None
    # Realized profit; ThreadManager computes it buy-side as lots * (close - entry).
    profit: float = 0.0


@dataclass
class ThreadData:
    """
    State container for a single trading thread.

    A thread groups one main (non-averaging) order with every averaging
    order opened to manage the same position, together with re-entry
    bookkeeping and running statistics.
    """
    magic_number: int
    symbol: str
    main_order_price: float
    main_order_time: datetime
    main_order_lots: float
    tp_distance: float
    entry_distance: float

    # Lifecycle
    status: ThreadStatus = ThreadStatus.ACTIVE
    averaging_count: int = 0

    # Every order attached to the thread (open or closed), in insertion order
    orders: List[OrderInfo] = field(default_factory=list)

    # Averaging history
    averaging_prices: List[float] = field(default_factory=list)
    averaging_lot_sizes: List[float] = field(default_factory=list)
    last_averaging_price: float = 0.0

    # Re-entry bookkeeping
    pending_reentry: bool = False
    reentry_price: float = 0.0
    reentry_level: int = 0

    # Running statistics
    total_profit: float = 0.0
    total_orders: int = 0
    max_averaging_level: int = 0

    def get_main_order(self) -> Optional[OrderInfo]:
        """Return the first order not flagged as averaging, or None if absent."""
        return next(
            (candidate for candidate in self.orders if not candidate.is_averaging),
            None,
        )

    def get_open_orders(self) -> List[OrderInfo]:
        """Return every order still marked open, preserving insertion order."""
        return list(filter(lambda entry: entry.is_open, self.orders))

    def get_averaging_orders(self) -> List[OrderInfo]:
        """Return every averaging order (open or closed), in insertion order."""
        return list(filter(lambda entry: entry.is_averaging, self.orders))

    def get_open_averaging_orders(self) -> List[OrderInfo]:
        """Return open averaging orders ordered by ascending averaging level."""
        pending = [
            entry for entry in self.orders
            if entry.is_averaging and entry.is_open
        ]
        # list.sort is stable, matching sorted(): equal levels keep insertion order
        pending.sort(key=lambda entry: entry.averaging_level)
        return pending

    def get_highest_averaging_level(self) -> int:
        """Return the highest level among open averaging orders (0 if none)."""
        return max(
            (entry.averaging_level for entry in self.get_open_averaging_orders()),
            default=0,
        )

    def get_total_lots(self) -> float:
        """Return the combined lot size of all open orders."""
        return sum(entry.lots for entry in self.get_open_orders())

    def get_weighted_average_price(self) -> float:
        """Return the lot-weighted average entry price of open orders (0.0 if none)."""
        live = self.get_open_orders()
        if not live:
            return 0.0
        notional = sum(entry.price * entry.lots for entry in live)
        volume = sum(entry.lots for entry in live)
        if volume > 0:
            return notional / volume
        return 0.0


class ThreadManager:
    """
    Manages trading threads.

    Handles creation, tracking, updating, and cleanup of trading threads,
    which are stored in ``self.threads`` keyed by magic number.

    Status transitions performed here:
        ACTIVE <-> PENDING_REENTRY        (set_pending_reentry / clear_pending_reentry)
        ACTIVE | PENDING_REENTRY -> CLOSED (close_thread)
    The re-entry methods never move a thread back out of CLOSED.
    """

    def __init__(self):
        """Initialize an empty thread manager."""
        self.threads: Dict[int, ThreadData] = {}
        self._order_counter = 0
        self._thread_index = 0
        self.logger = get_logger()

    def generate_magic_number(self) -> int:
        """
        Generate a unique magic number for a new thread.

        Magic number format: BASE_MAGIC + (order_counter << 8) + thread_index

        The counter increases on every call, and thread_index stays in
        [0, 255], so the low byte never overlaps the counter bits.

        Returns:
            Unique magic number
        """
        self._order_counter += 1
        self._thread_index = (self._thread_index + 1) % 256

        magic = BASE_MAGIC + (self._order_counter << 8) + self._thread_index
        return magic

    def create_thread(
        self,
        symbol: str,
        entry_price: float,
        entry_time: datetime,
        lot_size: float,
        tp_distance: float,
        entry_distance: float,
        order_id: int = 0
    ) -> ThreadData:
        """
        Create and initialize a new trading thread.

        Args:
            symbol: Trading pair symbol
            entry_price: Main order entry price
            entry_time: Entry timestamp
            lot_size: Main order lot size
            tp_distance: Take profit distance in price units
            entry_distance: Entry/averaging base distance
            order_id: External ID of the main order; 0 (falsy) means "none",
                in which case the thread's magic number is used.

        Returns:
            New ThreadData instance
        """
        magic = self.generate_magic_number()

        thread = ThreadData(
            magic_number=magic,
            symbol=symbol,
            main_order_price=entry_price,
            main_order_time=entry_time,
            main_order_lots=lot_size,
            tp_distance=tp_distance,
            entry_distance=entry_distance,
            total_orders=1
        )

        # Create main order info
        main_order = OrderInfo(
            order_id=order_id if order_id else magic,
            price=entry_price,
            lots=lot_size,
            tp_price=entry_price + tp_distance,  # Buy order TP is above entry
            open_time=entry_time,
            is_averaging=False,
            averaging_level=0
        )
        thread.orders.append(main_order)

        self.threads[magic] = thread

        log_thread_event(
            "CREATED",
            magic,
            f"Symbol: {symbol} | Entry: {entry_price:.8f} | Lots: {lot_size:.6f}"
        )

        return thread

    def get_thread(self, magic: int) -> Optional[ThreadData]:
        """
        Get thread by magic number.

        Args:
            magic: Thread magic number

        Returns:
            ThreadData if found, None otherwise
        """
        return self.threads.get(magic)

    def get_active_threads(self) -> List[ThreadData]:
        """
        Get all threads whose status is ACTIVE.

        Returns:
            List of active ThreadData instances (PENDING_REENTRY excluded)
        """
        return [
            t for t in self.threads.values()
            if t.status == ThreadStatus.ACTIVE
        ]

    def get_threads_by_symbol(self, symbol: str) -> List[ThreadData]:
        """
        Get all non-closed threads for a specific symbol.

        Args:
            symbol: Trading pair symbol

        Returns:
            List of ACTIVE or PENDING_REENTRY threads for the symbol
        """
        return [
            t for t in self.threads.values()
            if t.symbol == symbol and t.status != ThreadStatus.CLOSED
        ]

    def add_averaging_order(
        self,
        magic: int,
        price: float,
        lots: float,
        tp_price: float,
        order_time: datetime,
        level: int,
        order_id: int = 0
    ) -> Optional[OrderInfo]:
        """
        Add an averaging order to a thread.

        Args:
            magic: Thread magic number
            price: Averaging order entry price
            lots: Order lot size
            tp_price: Take profit price
            order_time: Order timestamp
            level: Averaging level (1, 2, 3, ...)
            order_id: External order ID; 0 (falsy) means "none", in which
                case a synthetic ID ``magic * 1000 + level`` is used.

        Returns:
            OrderInfo if successful, None if thread not found
        """
        thread = self.get_thread(magic)
        if thread is None:
            return None

        order = OrderInfo(
            order_id=order_id if order_id else magic * 1000 + level,
            price=price,
            lots=lots,
            tp_price=tp_price,
            open_time=order_time,
            is_averaging=True,
            averaging_level=level
        )

        thread.orders.append(order)
        thread.averaging_count += 1
        thread.averaging_prices.append(price)
        thread.averaging_lot_sizes.append(lots)
        thread.last_averaging_price = price
        thread.total_orders += 1
        thread.max_averaging_level = max(thread.max_averaging_level, level)

        log_thread_event(
            "AVERAGING",
            magic,
            f"Level: {level} | Price: {price:.8f} | Lots: {lots:.6f}"
        )

        return order

    def _close_order_info(
        self,
        magic: int,
        thread: ThreadData,
        order: OrderInfo,
        close_price: float,
        close_time: datetime
    ) -> float:
        """Close exactly *order*, book its profit on *thread*, log, and return the profit."""
        order.is_open = False
        order.close_price = close_price
        order.close_time = close_time
        # Calculate profit (buy order: close - open)
        order.profit = order.lots * (close_price - order.price)
        thread.total_profit += order.profit

        log_thread_event(
            "ORDER_CLOSED",
            magic,
            f"OrderID: {order.order_id} | Profit: {order.profit:.2f}"
        )

        return order.profit

    def close_order(
        self,
        magic: int,
        order_id: int,
        close_price: float,
        close_time: datetime
    ) -> Optional[float]:
        """
        Close the first open order with the given ID within a thread.

        Args:
            magic: Thread magic number
            order_id: Order ID to close
            close_price: Closing price
            close_time: Close timestamp

        Returns:
            Profit from the order, or None if the thread or an open order
            with that ID is not found
        """
        thread = self.get_thread(magic)
        if thread is None:
            return None

        target = next(
            (o for o in thread.orders if o.order_id == order_id and o.is_open),
            None
        )
        if target is None:
            return None

        return self._close_order_info(magic, thread, target, close_price, close_time)

    def close_order_by_level(
        self,
        magic: int,
        level: int,
        close_price: float,
        close_time: datetime
    ) -> Optional[float]:
        """
        Close the first open order at the given averaging level.

        The matched order object itself is closed. It is not looked up
        again by ID, so another open order sharing the same order_id is
        never closed by mistake.

        Args:
            magic: Thread magic number
            level: Averaging level to close (0 = main order)
            close_price: Closing price
            close_time: Close timestamp

        Returns:
            Profit from the order, or None if not found
        """
        thread = self.get_thread(magic)
        if thread is None:
            return None

        target = next(
            (o for o in thread.orders if o.averaging_level == level and o.is_open),
            None
        )
        if target is None:
            return None

        return self._close_order_info(magic, thread, target, close_price, close_time)

    def set_pending_reentry(
        self,
        magic: int,
        reentry_price: float,
        reentry_level: int
    ) -> bool:
        """
        Set thread to pending re-entry state.

        Closed threads are left untouched, so they are not resurrected and
        are still removed by cleanup_closed_threads.

        Args:
            magic: Thread magic number
            reentry_price: Price at which to re-enter
            reentry_level: Averaging level to reopen

        Returns:
            True if successful; False if the thread is missing or CLOSED
        """
        thread = self.get_thread(magic)
        if thread is None or thread.status == ThreadStatus.CLOSED:
            return False

        thread.pending_reentry = True
        thread.reentry_price = reentry_price
        thread.reentry_level = reentry_level
        thread.status = ThreadStatus.PENDING_REENTRY

        log_thread_event(
            "PENDING_REENTRY",
            magic,
            f"Level: {reentry_level} | Price: {reentry_price:.8f}"
        )

        return True

    def clear_pending_reentry(self, magic: int) -> bool:
        """
        Clear pending re-entry state and mark the thread ACTIVE again.

        reentry_price and reentry_level keep their last values. A CLOSED
        thread is left untouched.

        Args:
            magic: Thread magic number

        Returns:
            True if successful; False if the thread is missing or CLOSED
        """
        thread = self.get_thread(magic)
        if thread is None or thread.status == ThreadStatus.CLOSED:
            return False

        thread.pending_reentry = False
        thread.status = ThreadStatus.ACTIVE

        return True

    def update_thread(self, thread: ThreadData) -> None:
        """
        Store (or replace) a thread under its own magic number.

        Args:
            thread: ThreadData to update
        """
        self.threads[thread.magic_number] = thread

    def close_thread(self, magic: int) -> Optional[float]:
        """
        Close a thread entirely.

        Only the thread status changes. Its orders are not marked closed
        here; presumably callers close them first via close_order /
        close_order_by_level (TODO confirm).

        Args:
            magic: Thread magic number

        Returns:
            Total profit from the thread, or None if the thread is not found
        """
        thread = self.get_thread(magic)
        if thread is None:
            return None

        thread.status = ThreadStatus.CLOSED

        log_thread_event(
            "CLOSED",
            magic,
            f"Total Profit: {thread.total_profit:.2f} | "
            f"Max Avg Level: {thread.max_averaging_level}"
        )

        return thread.total_profit

    def cleanup_closed_threads(self) -> int:
        """
        Remove closed threads from manager.

        Returns:
            Number of threads removed
        """
        closed = [
            m for m, t in self.threads.items()
            if t.status == ThreadStatus.CLOSED
        ]

        for magic in closed:
            del self.threads[magic]

        return len(closed)

    def get_statistics(self) -> dict:
        """
        Get overall thread statistics.

        ``active_threads`` counts only ACTIVE threads. ``open_positions``
        counts open orders in every non-closed thread, because a
        PENDING_REENTRY thread still holds its remaining open orders.

        Returns:
            Dictionary of statistics
        """
        active = self.get_active_threads()
        live = [
            t for t in self.threads.values()
            if t.status != ThreadStatus.CLOSED
        ]

        total_profit = sum(t.total_profit for t in self.threads.values())
        total_orders = sum(t.total_orders for t in self.threads.values())
        max_avg = max(
            (t.max_averaging_level for t in self.threads.values()),
            default=0
        )

        return {
            "active_threads": len(active),
            "total_threads": len(self.threads),
            "total_profit": total_profit,
            "total_orders": total_orders,
            "max_averaging_level": max_avg,
            "open_positions": sum(
                len(t.get_open_orders()) for t in live
            )
        }
