"""
Data loader for historical OHLCV and tick data from CSV and JSON files.
"""

import json
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Iterator, Generator
import re

import pandas as pd
import numpy as np

from config import DATE_FORMAT


@dataclass
class OHLCV:
    """Single OHLCV candle.

    Attributes:
        timestamp: Time associated with the candle, as produced by the loader.
        open: Opening price.
        high: Highest traded price.
        low: Lowest traded price.
        close: Closing price.
        volume: Traded volume (0.0 when the source had no volume column).
    """
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float

    @property
    def bid(self) -> float:
        """Simulated bid price: the candle close, used as-is (no offset applied)."""
        return self.close

    @property
    def ask(self) -> float:
        """Simulated ask price: close plus a fixed spread of 0.01% (1 bp) of close."""
        spread = self.close * 0.0001
        return self.close + spread


@dataclass
class TickData:
    """A single bid/ask quote observation used for live-trading simulation."""
    timestamp: datetime      # time of the quote
    bid: float               # bid price
    ask: float               # ask price
    bid_volume: float = 0.0  # size at the bid; 0.0 when the source has none
    ask_volume: float = 0.0  # size at the ask; 0.0 when the source has none

    @property
    def mid(self) -> float:
        """Return the midpoint of the quote, i.e. the average of bid and ask."""
        total = self.bid + self.ask
        return total / 2

    @property
    def spread(self) -> float:
        """Return the quoted spread, computed as ask minus bid."""
        width = self.ask - self.bid
        return width


class TickDataLoader:
    """
    Load and process tick data from CSV files.

    The CSV (optionally limited via ``max_rows`` / ``skip_rows``) is read fully
    into a pandas DataFrame; ``TickData`` objects are then built on demand by
    the iteration and indexing helpers, one row at a time or in chunks.
    """

    # Native tick timestamp layout 'YYYYMMDD HH:MM:SS:<ms>', e.g.
    # '20250622 00:00:00:299'. Compiled once because it runs on every row.
    _TIMESTAMP_RE = re.compile(r'(\d{8})\s+(\d{2}):(\d{2}):(\d{2}):(\d+)')

    def __init__(self, symbol: str = "BTCUSD"):
        """
        Initialize tick data loader.

        Args:
            symbol: Trading pair symbol
        """
        self.symbol = symbol
        self.file_path: Optional[str] = None
        self.total_rows: int = 0
        self._df: Optional[pd.DataFrame] = None
        # Column mapping; overwritten by load_csv(). Defaults mirror load_csv's
        # so accessors never hit a missing attribute.
        self._timestamp_col = "Timestamp"
        self._bid_col = "Bid price"
        self._ask_col = "Ask price"
        self._bid_vol_col = "Bid volume"
        self._ask_vol_col = "Ask volume"

    def load_csv(
        self,
        file_path: str,
        max_rows: Optional[int] = None,
        skip_rows: int = 0,
        timestamp_col: str = "Timestamp",
        bid_col: str = "Bid price",
        ask_col: str = "Ask price",
        bid_vol_col: str = "Bid volume",
        ask_vol_col: str = "Ask volume"
    ) -> "TickDataLoader":
        """
        Load tick data from CSV file.

        Timestamps are parsed into a ``parsed_timestamp`` column. Values that
        cannot be parsed become ``NaT`` and a warning with their count is
        printed.

        Args:
            file_path: Path to CSV file
            max_rows: Maximum rows to load (None for all)
            skip_rows: Number of data rows to skip from start (header is kept)
            timestamp_col: Timestamp column name
            bid_col: Bid price column name
            ask_col: Ask price column name
            bid_vol_col: Bid volume column name
            ask_vol_col: Ask volume column name

        Returns:
            Self for chaining
        """
        self.file_path = file_path

        read_params = {'filepath_or_buffer': file_path}
        if skip_rows > 0:
            # Skip data rows 1..skip_rows while keeping row 0 (the header).
            read_params['skiprows'] = range(1, skip_rows + 1)
        if max_rows is not None:
            read_params['nrows'] = max_rows

        print(f"Loading tick data from {file_path}...")
        df = pd.read_csv(**read_params)

        # Store column mapping
        self._timestamp_col = timestamp_col
        self._bid_col = bid_col
        self._ask_col = ask_col
        self._bid_vol_col = bid_vol_col
        self._ask_vol_col = ask_vol_col

        df['parsed_timestamp'] = df[timestamp_col].apply(self._parse_timestamp)
        unparsed = int(df['parsed_timestamp'].isna().sum())
        if unparsed:
            print(f"Warning: {unparsed:,} tick timestamps could not be parsed (set to NaT)")

        self._df = df
        self.total_rows = len(df)

        print(f"Loaded {self.total_rows:,} ticks")

        return self

    def _parse_timestamp(self, ts_str: str) -> datetime:
        """Parse a tick timestamp.

        The native format is 'YYYYMMDD HH:MM:SS:<ms>' where the last field is
        an integer number of milliseconds. Anything else is handed to
        ``pd.to_datetime``.

        Returns:
            A ``datetime``, or ``pd.NaT`` (itself a ``datetime`` instance) when
            the value cannot be parsed. Unparseable rows are marked explicitly
            instead of being replaced with the current wall-clock time, which
            would silently corrupt the series.
        """
        try:
            match = self._TIMESTAMP_RE.match(str(ts_str))
            if match:
                date_part, hour, minute, second, ms = match.groups()
                return datetime(
                    int(date_part[:4]),
                    int(date_part[4:6]),
                    int(date_part[6:8]),
                    int(hour),
                    int(minute),
                    int(second),
                    int(ms) * 1000,  # milliseconds -> microseconds
                )
            # Fallback to pandas parsing for any other layout.
            return pd.to_datetime(ts_str).to_pydatetime()
        except (ValueError, TypeError, OverflowError, AttributeError):
            # ValueError: bad calendar fields / unparseable string;
            # AttributeError: pd.to_datetime(None) returns None.
            return pd.NaT

    def _tick_at(self, idx: int) -> "TickData":
        """Build a TickData from the loaded DataFrame row at position ``idx``."""
        row = self._df.iloc[idx]
        columns = self._df.columns
        return TickData(
            timestamp=row['parsed_timestamp'],
            bid=float(row[self._bid_col]),
            ask=float(row[self._ask_col]),
            # Volume columns are optional in the source CSV.
            bid_volume=float(row[self._bid_vol_col]) if self._bid_vol_col in columns else 0.0,
            ask_volume=float(row[self._ask_vol_col]) if self._ask_vol_col in columns else 0.0
        )

    def iterate_ticks(
        self,
        start_idx: int = 0,
        end_idx: Optional[int] = None
    ) -> Generator["TickData", None, None]:
        """
        Iterate over ticks.

        Args:
            start_idx: Starting index
            end_idx: Ending index, exclusive (None for all); clamped to the data length

        Yields:
            TickData objects
        """
        if self._df is None:
            return

        n = len(self._df)
        stop = n if end_idx is None else min(end_idx, n)
        for idx in range(start_idx, stop):
            yield self._tick_at(idx)

    def iterate_ticks_chunked(
        self,
        chunk_size: int = 10000
    ) -> Generator[List["TickData"], None, None]:
        """
        Iterate over ticks in chunks.

        Only the TickData objects of the current chunk are materialized; the
        underlying DataFrame is already fully in memory.

        Args:
            chunk_size: Number of ticks per chunk

        Yields:
            Lists of TickData objects (the last chunk may be shorter)
        """
        if self._df is None:
            return

        n = len(self._df)
        for start in range(0, n, chunk_size):
            end = min(start + chunk_size, n)
            yield [self._tick_at(idx) for idx in range(start, end)]

    def get_tick(self, index: int) -> Optional["TickData"]:
        """Get single tick by index, or None if no data is loaded or index is too large."""
        if self._df is None or index >= len(self._df):
            return None

        return self._tick_at(index)

    def get_price_range(self) -> tuple:
        """Get min and max prices in the data."""
        if self._df is None:
            return (0.0, 0.0)

        min_price = min(self._df[self._bid_col].min(), self._df[self._ask_col].min())
        max_price = max(self._df[self._bid_col].max(), self._df[self._ask_col].max())
        return (float(min_price), float(max_price))

    def get_date_range(self) -> tuple:
        """Get start and end dates of the data (first and last row, in file order)."""
        if self._df is None or len(self._df) == 0:
            return (None, None)

        return (
            self._df['parsed_timestamp'].iloc[0],
            self._df['parsed_timestamp'].iloc[-1]
        )

    def get_spread_stats(self) -> dict:
        """Get spread (ask - bid) statistics."""
        if self._df is None:
            return {}

        spreads = self._df[self._ask_col] - self._df[self._bid_col]
        return {
            "min_spread": float(spreads.min()),
            "max_spread": float(spreads.max()),
            "avg_spread": float(spreads.mean()),
            "median_spread": float(spreads.median())
        }

    def __len__(self) -> int:
        """Get number of ticks."""
        return self.total_rows

    def __getitem__(self, index: int) -> "TickData":
        """Get tick by index; raises IndexError when out of range."""
        tick = self.get_tick(index)
        if tick is None:
            raise IndexError(f"Index {index} out of range")
        return tick


class DataLoader:
    """Load and process historical OHLCV data."""

    def __init__(self, symbol: str = "BTCUSDT"):
        """
        Initialize data loader.

        Args:
            symbol: Trading pair symbol
        """
        self.symbol = symbol
        self.data: List["OHLCV"] = []
        self._df: Optional[pd.DataFrame] = None

    @staticmethod
    def _epoch_to_datetime(values: pd.Series) -> pd.Series:
        """Convert numeric Unix epoch values to datetimes.

        Values whose maximum exceeds 1e12 are treated as milliseconds,
        otherwise as seconds.
        """
        numeric = values.astype(float)
        unit = 'ms' if numeric.max() > 1e12 else 's'
        return pd.to_datetime(numeric, unit=unit)

    @staticmethod
    def _parse_timestamp_column(
        values: pd.Series,
        date_format: Optional[str] = None
    ) -> pd.Series:
        """Parse a timestamp column into datetimes.

        Args:
            values: Raw timestamp column
            date_format: Explicit strftime-style format (auto-detect if None)

        Returns:
            Series of parsed datetimes
        """
        if date_format:
            return pd.to_datetime(values, format=date_format)
        if pd.api.types.is_numeric_dtype(values):
            # pd.to_datetime() does NOT raise on bare numbers: it reads them as
            # nanoseconds since the epoch (1700000000000 -> 1970-01-01 00:28:20),
            # so numeric epoch columns must be converted explicitly.
            return DataLoader._epoch_to_datetime(values)
        try:
            return pd.to_datetime(values)
        except (ValueError, TypeError):
            # e.g. epoch values stored as strings
            return DataLoader._epoch_to_datetime(values)

    def load_csv(
        self,
        file_path: str,
        timestamp_col: str = "timestamp",
        open_col: str = "open",
        high_col: str = "high",
        low_col: str = "low",
        close_col: str = "close",
        volume_col: str = "volume",
        date_format: Optional[str] = None
    ) -> "DataLoader":
        """
        Load OHLCV data from CSV file, sorted by timestamp.

        Args:
            file_path: Path to CSV file
            timestamp_col: Name of timestamp column
            open_col: Name of open price column
            high_col: Name of high price column
            low_col: Name of low price column
            close_col: Name of close price column
            volume_col: Name of volume column (optional in the file; 0.0 if absent)
            date_format: Date format string (auto-detect if None; numeric
                columns are treated as epoch seconds or milliseconds)

        Returns:
            Self for chaining
        """
        df = pd.read_csv(file_path)

        df[timestamp_col] = self._parse_timestamp_column(df[timestamp_col], date_format)
        df = df.sort_values(timestamp_col).reset_index(drop=True)

        self._df = df
        self.data = self._dataframe_to_ohlcv(
            df, timestamp_col, open_col, high_col, low_col, close_col, volume_col
        )

        return self

    def load_json(
        self,
        file_path: str,
        data_key: Optional[str] = None
    ) -> "DataLoader":
        """Load OHLCV data from a JSON file, sorted by timestamp.

        The payload is either a list of records, or a dict holding that list
        under ``data_key`` (or under ``"data"`` when no key is given). Records
        are either dicts with ``timestamp/open/high/low/close[/volume]`` keys
        or arrays whose leading items are in that order; extra trailing array
        items are ignored.

        Args:
            file_path: Path to JSON file
            data_key: Key holding the record list inside a top-level dict

        Returns:
            Self for chaining
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        if data_key:
            raw_data = raw_data[data_key]
        elif isinstance(raw_data, dict) and "data" in raw_data:
            raw_data = raw_data["data"]

        if not raw_data:
            # Reset both views so to_dataframe() cannot return stale data
            # from a previous load.
            self._df = None
            self.data = []
            return self

        columns = ["timestamp", "open", "high", "low", "close", "volume"]
        first_item = raw_data[0]

        if isinstance(first_item, list):
            # Keep only the leading OHLCV fields; rows may carry extra items,
            # or omit volume.
            width = min(len(first_item), len(columns))
            df = pd.DataFrame(
                [row[:width] for row in raw_data],
                columns=columns[:width]
            )
        else:
            df = pd.DataFrame(raw_data)

        ts_col = "timestamp"
        df[ts_col] = self._parse_timestamp_column(df[ts_col])
        df = df.sort_values(ts_col).reset_index(drop=True)

        self._df = df
        self.data = self._dataframe_to_ohlcv(
            df, "timestamp", "open", "high", "low", "close", "volume"
        )

        return self

    def _dataframe_to_ohlcv(
        self,
        df: pd.DataFrame,
        timestamp_col: str,
        open_col: str,
        high_col: str,
        low_col: str,
        close_col: str,
        volume_col: str
    ) -> List["OHLCV"]:
        """Convert DataFrame to list of OHLCV objects (volume 0.0 if column absent)."""
        volumes = (
            df[volume_col].tolist() if volume_col in df.columns else [0.0] * len(df)
        )
        rows = zip(
            df[timestamp_col].tolist(),
            df[open_col].tolist(),
            df[high_col].tolist(),
            df[low_col].tolist(),
            df[close_col].tolist(),
            volumes,
        )
        return [
            OHLCV(
                # pandas Timestamps -> plain datetime; other values pass through.
                timestamp=ts.to_pydatetime() if hasattr(ts, 'to_pydatetime') else ts,
                open=float(op),
                high=float(hi),
                low=float(lo),
                close=float(cl),
                volume=float(vol)
            )
            for ts, op, hi, lo, cl, vol in rows
        ]

    def to_dataframe(self) -> pd.DataFrame:
        """Get data as pandas DataFrame (a copy of the loaded frame when available)."""
        if self._df is not None:
            return self._df.copy()

        if not self.data:
            return pd.DataFrame()

        return pd.DataFrame([
            {
                "timestamp": c.timestamp,
                "open": c.open,
                "high": c.high,
                "low": c.low,
                "close": c.close,
                "volume": c.volume
            }
            for c in self.data
        ])

    def iterate_candles(self) -> Iterator["OHLCV"]:
        """Iterate over candles one by one."""
        yield from self.data

    def iterate_ticks(self, ticks_per_candle: int = 4) -> Iterator["TickData"]:
        """Generate simulated tick data from candles.

        Every tick of a candle carries the candle's timestamp; ask is bid plus
        a 0.01% spread.
        """
        for candle in self.data:
            for price in self._interpolate_candle(candle, ticks_per_candle):
                spread = price * 0.0001
                yield TickData(
                    timestamp=candle.timestamp,
                    bid=price,
                    ask=price + spread
                )

    def _interpolate_candle(self, candle: "OHLCV", num_ticks: int) -> List[float]:
        """Interpolate a price path of ``num_ticks`` points within a candle.

        For ``num_ticks <= 4`` the result is a prefix of [open, high, low, close]
        (empty for ``num_ticks <= 0``). Otherwise the path starts at open, hits
        low then high (bullish) or high then low (bearish) exactly, moves
        linearly between those anchors, and ends at close.
        """
        if num_ticks <= 4:
            return [candle.open, candle.high, candle.low, candle.close][:max(num_ticks, 0)]

        prices = [candle.open]
        is_bullish = candle.close >= candle.open
        mid_points = num_ticks - 2
        first_idx = mid_points // 3
        second_idx = 2 * mid_points // 3

        if is_bullish:
            first_target, second_target = candle.low, candle.high
        else:
            first_target, second_target = candle.high, candle.low

        for i in range(1, num_ticks - 1):
            if i <= first_idx:
                ratio = i / first_idx
                prices.append(candle.open + ratio * (first_target - candle.open))
            elif i <= second_idx:
                ratio = (i - first_idx) / (second_idx - first_idx)
                prices.append(first_target + ratio * (second_target - first_target))
            else:
                # Denominator counts the final appended close as a step, so the
                # last interior point stops short of close instead of
                # duplicating it.
                ratio = (i - second_idx) / (mid_points + 1 - second_idx)
                prices.append(second_target + ratio * (candle.close - second_target))

        prices.append(candle.close)
        return prices

    def get_price_range(self) -> tuple:
        """Get min low and max high across the data."""
        if not self.data:
            return (0.0, 0.0)

        min_price = min(c.low for c in self.data)
        max_price = max(c.high for c in self.data)
        return (min_price, max_price)

    def get_date_range(self) -> tuple:
        """Get start and end dates of the data."""
        if not self.data:
            return (None, None)

        return (self.data[0].timestamp, self.data[-1].timestamp)

    def __len__(self) -> int:
        """Get number of candles."""
        return len(self.data)

    def __getitem__(self, index: int) -> "OHLCV":
        """Get candle by index."""
        return self.data[index]


def create_sample_data(
    symbol: str = "BTCUSDT",
    num_candles: int = 1000,
    start_price: float = 50000.0,
    volatility: float = 0.02,
    seed: Optional[int] = 42
) -> DataLoader:
    """
    Create sample OHLCV data for testing (random-walk hourly candles from 2024-01-01).

    Args:
        symbol: Trading pair symbol
        num_candles: Number of candles to generate
        start_price: Starting price
        volatility: Price volatility factor (std-dev of per-candle return)
        seed: RNG seed for reproducible output; None for non-deterministic data.
            A private RandomState is used, so numpy's global RNG is untouched.
            With the default seed the output is identical to the former
            ``np.random.seed(42)`` behaviour.

    Returns:
        DataLoader with generated data
    """
    # Legacy RandomState yields the same stream as the global np.random after seed().
    rng = np.random.RandomState(seed)

    loader = DataLoader(symbol)
    data = []
    current_price = start_price
    current_time = datetime(2024, 1, 1)

    for _ in range(num_candles):
        change = rng.normal(0, volatility)
        open_price = current_price
        close_price = open_price * (1 + change)

        max_move = abs(change) + volatility * rng.random_sample()
        if close_price > open_price:
            high_price = close_price * (1 + max_move * rng.random_sample())
            low_price = open_price * (1 - max_move * rng.random_sample())
        else:
            high_price = open_price * (1 + max_move * rng.random_sample())
            low_price = close_price * (1 - max_move * rng.random_sample())

        # Guarantee a valid candle: high >= max(open, close), low <= min(open, close).
        high_price = max(high_price, open_price, close_price)
        low_price = min(low_price, open_price, close_price)

        volume = rng.random_sample() * 1000 + 100

        candle = OHLCV(
            timestamp=current_time,
            open=open_price,
            high=high_price,
            low=low_price,
            close=close_price,
            volume=volume
        )
        data.append(candle)

        current_price = close_price
        current_time = current_time + pd.Timedelta(hours=1)

    loader.data = data
    return loader
