"""
Generate comprehensive trade report from backtest results.
"""

import re
import json
from datetime import datetime
from dataclasses import dataclass
from typing import List, Dict, Optional
from settings import (
    SYMBOL, INITIAL_BALANCE, RISK_PERCENT, BASE_MULTIPLIER,
    AGGRESSION_LEVEL, MAX_SPREAD, REENTRY_ENABLED, COMMISSION_PERCENT,
    CUSTOM_LEVELS, CUSTOM_MULTIPLIERS
)


@dataclass
class Trade:
    """Individual trade record.

    NOTE(review): nothing in this file constructs a ``Trade`` --
    ``parse_thread_logs`` returns plain dicts for trade events instead. It may
    be used by other modules; verify before removing or changing it.
    """
    thread_id: int  # presumably the N of the "Thread #N" the trade belongs to -- confirm
    timestamp: datetime
    event_type: str  # presumably an event name such as OPEN / AVG / TP_HIT -- confirm
    level: int  # assumed to be the averaging level of the position -- confirm
    entry_price: float
    exit_price: float
    lots: float
    profit: float
    is_winner: bool


def _money(text: str) -> float:
    """Convert a money/price token such as ``$1,234.50`` (``$`` and commas optional) to float."""
    return float(text.replace('$', '').replace(',', ''))


# (summary key, pattern, converter) for the run-level totals in the log.
# A total that is absent from the log falls back to ``converter('0')``.
_SUMMARY_FIELDS = (
    ('total_threads', re.compile(r'Total Threads:\s+(\d+)'), int),
    ('profitable_threads', re.compile(r'Profitable:\s+(\d+)'), int),
    ('total_pl', re.compile(r'Total P/L:\s+\$([+-]?[\d,]+\.?\d*)'), _money),
    ('spread_filtered', re.compile(r'Spread Filtered:\s+(\d+)'), int),
    ('reentries_abandoned', re.compile(r'Reentries Abandoned:\s*(\d+)'), int),
)

# One per-thread summary line, e.g.
# "Thread #3: Entry $2,000.00 | MaxLvl 2 | Trades 5 | Wins 4 | P/L $+12.34".
_THREAD_SUMMARY_RE = re.compile(
    r'Thread #(\d+): Entry \$([\d,]+\.?\d*) \| MaxLvl (\d+) \| Trades (\d+) '
    r'\| Wins (\d+) \| P/L \$([+-]?[\d,]+\.?\d*)'
)
# Any line containing "Thread #N:" switches the thread that following events belong to.
_THREAD_HEADER_RE = re.compile(r'Thread #(\d+):')
# Event lines start with a YYYY-MM-DD date.
_EVENT_LINE_RE = re.compile(r'^\d{4}-\d{2}-\d{2}')
# Event types that become entries of the returned ``trades`` list.
_TRADE_EVENTS = frozenset({'OPEN', 'AVG', 'TP_HIT', 'TP_REENTRY', 'REENTRY', 'ABANDON'})
# Event types whose occurrences are counted over the whole log (in this order).
_COUNTED_EVENTS = (
    'OPEN', 'AVG', 'TP_HIT', 'TP_REENTRY',
    'REENTRY', 'ABANDON', 'THREAD_COMPLETE', 'THREAD_ABANDONED',
)


def _parse_summary(content: str) -> Dict:
    """Extract the run-level totals; any total missing from the log defaults to zero."""
    summary = {}
    for key, pattern, convert in _SUMMARY_FIELDS:
        match = pattern.search(content)
        summary[key] = convert(match.group(1) if match else '0')
    return summary


def _parse_thread_summaries(content: str) -> List[Dict]:
    """Return one dict per thread summary line, keeping only the first line seen per thread id."""
    by_id: Dict[int, Dict] = {}
    for match in _THREAD_SUMMARY_RE.finditer(content):
        thread_id = int(match.group(1))
        if thread_id in by_id:
            continue
        by_id[thread_id] = {
            'id': thread_id,
            'entry_price': _money(match.group(2)),
            'max_level': int(match.group(3)),
            'total_trades': int(match.group(4)),
            'wins': int(match.group(5)),
            'pnl': _money(match.group(6)),
        }
    return list(by_id.values())


def _parse_event_line(parts: List[str], thread_id: int) -> Optional[Dict]:
    """Parse one whitespace-split event line into a trade-event dict.

    Expected layout: ``DATE TIME EVENT LEVEL PRICE LOTS TP_PRICE ... BALANCE``
    with at least 8 tokens. Profit is the latest signed ``$`` token (containing
    ``+`` or ``-``, not the final token) seen while scanning from token 7,
    stopping at the first other ``$`` token once a non-zero profit is found.
    Balance is the last ``$`` token on the line.

    Returns None when the line is too short, is not a trade event, or has a
    malformed numeric field.
    """
    if len(parts) < 8 or parts[2] not in _TRADE_EVENTS:
        return None
    last_index = len(parts) - 1
    try:
        level = int(parts[3])
        price = _money(parts[4])
        lots = float(parts[5])
        tp_price = _money(parts[6])

        profit = 0.0
        for i, token in enumerate(parts[7:], 7):
            if not token.startswith('$'):
                continue
            if ('+' in token or '-' in token) and i < last_index:
                profit = _money(token.replace('+', ''))
            elif profit != 0:
                break

        balance = next((_money(tok) for tok in reversed(parts) if tok.startswith('$')), 0.0)
    except (ValueError, IndexError):
        return None

    return {
        'thread_id': thread_id,
        'timestamp': parts[0] + ' ' + parts[1],
        'event_type': parts[2],
        'level': level,
        'price': price,
        'lots': lots,
        'tp_price': tp_price,
        'profit': profit,
        'balance': balance,
    }


def _count_events(content: str) -> Dict[str, int]:
    """Count whitespace-delimited occurrences of each tracked event name.

    Lookarounds are used instead of consuming a whitespace character on each
    side, so a name at the very start or end of the log, or repeated with a
    single separator, is still counted. A name embedded in a longer token
    (``REENTRY`` inside ``TP_REENTRY``) is not.
    """
    return {
        name: len(re.findall(rf'(?<!\S){re.escape(name)}(?!\S)', content))
        for name in _COUNTED_EVENTS
    }


def parse_thread_logs(filename: str = "thread_logs.txt") -> Dict:
    """Parse a thread log file and extract all trade data.

    Args:
        filename: Path of the text log to read.

    Returns:
        A dict with keys:
          - ``summary``: run totals (``total_threads``, ``profitable_threads``,
            ``total_pl``, ``spread_filtered``, ``reentries_abandoned``).
          - ``threads``: per-thread summary dicts, deduplicated by thread id.
          - ``trades``: one dict per parsed trade-event line, in log order.
          - ``events``: occurrence count per tracked event name.

    Raises:
        OSError: If the log file cannot be opened (e.g. it does not exist).
    """
    with open(filename, 'r') as f:
        content = f.read()

    trades = []
    current_thread_id = 0  # events before any "Thread #N:" line are attributed to 0
    for line in content.split('\n'):
        header = _THREAD_HEADER_RE.search(line)
        if header:
            current_thread_id = int(header.group(1))
            continue
        if not _EVENT_LINE_RE.match(line.strip()):
            continue
        event = _parse_event_line(line.split(), current_thread_id)
        if event is not None:
            trades.append(event)

    return {
        'summary': _parse_summary(content),
        'threads': _parse_thread_summaries(content),
        'trades': trades,
        'events': _count_events(content),
    }


def load_backtest_results(json_file: str = "backtest_results.json") -> Optional[Dict]:
    """Load backtest results from a JSON file.

    Args:
        json_file: Path of the JSON results file.

    Returns:
        The decoded JSON content (expected to be a dict of metrics), or None
        when the file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
        OSError: For I/O errors other than a missing file (e.g. permissions).
    """
    # EAFP: opening directly avoids the exists()/open() race in which the file
    # disappears between the check and the read.
    try:
        f = open(json_file, 'r')
    except FileNotFoundError:
        return None
    with f:
        return json.load(f)


def _write_banner(f, title: str) -> None:
    """Write a section banner: a 100-char '=' rule, *title* (already padded and
    newline-terminated), another rule, then one blank line."""
    f.write("=" * 100 + "\n")
    f.write(title)
    f.write("=" * 100 + "\n\n")


def generate_report(log_file: str = "thread_logs.txt", output_file: str = "trade_report.txt",
                    results_file: str = "backtest_results.json"):
    """Generate a comprehensive plain-text trade report.

    Headline performance figures come from the backtest results JSON when it
    exists and is non-empty; otherwise they are approximated from the
    per-thread P/L lines of the log. Thread, re-entry, event, distribution and
    trade-log sections always come from the log.

    Args:
        log_file: Thread log to parse.
        output_file: Path the report is written to (overwritten).
        results_file: Backtest results JSON (defaults to ``backtest_results.json``).

    Returns:
        ``output_file``.

    Raises:
        OSError: If the log cannot be read or the report cannot be written.
        KeyError: If the results JSON lacks one of the expected metric keys.
    """
    data = parse_thread_logs(log_file)
    summary = data['summary']
    threads = data['threads']
    trades = data['trades']
    events = data['events']
    thread_profits = [t['pnl'] for t in threads]

    # Prefer the backtest's own aggregate results (most accurate) when available.
    backtest_results = load_backtest_results(results_file)

    if backtest_results:
        total_trades = backtest_results['total_trades']
        winning_trades = backtest_results['winning_trades']
        losing_trades = backtest_results['losing_trades']
        gross_profit = backtest_results['gross_profit']
        gross_loss = backtest_results['gross_loss']
        net_profit = backtest_results['net_profit']
        profit_factor = backtest_results['profit_factor']
        win_rate = backtest_results['win_rate']
        max_drawdown = backtest_results['max_drawdown']
        max_drawdown_pct = backtest_results['max_drawdown_pct']
        total_commission = backtest_results['total_commission']
        initial_balance = backtest_results['initial_balance']
        final_balance = backtest_results['final_balance']
        final_equity = backtest_results['final_equity']
        total_return_pct = backtest_results['total_return_pct']
    else:
        # Fallback from the log. Trade counts are per trade but gross figures are
        # per thread, so the averages derived below are approximate.
        total_trades = sum(t['total_trades'] for t in threads)
        winning_trades = sum(t['wins'] for t in threads)
        losing_trades = total_trades - winning_trades

        net_profit = sum(thread_profits)
        # Gross loss is a positive magnitude; the profit factor and the
        # avg_loss > 0 check for the Win/Loss ratio rely on that sign convention.
        gross_profit = sum(p for p in thread_profits if p > 0)
        gross_loss = -sum(p for p in thread_profits if p < 0)
        if gross_loss > 0:
            profit_factor = gross_profit / gross_loss
        else:
            # No losing threads: infinite if anything was won, otherwise 0.
            profit_factor = float('inf') if gross_profit > 0 else 0.0

        win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0
        max_drawdown = 0
        max_drawdown_pct = 0
        total_commission = 0
        initial_balance = INITIAL_BALANCE
        final_balance = initial_balance + net_profit
        final_equity = final_balance
        total_return_pct = net_profit / initial_balance * 100 if initial_balance > 0 else 0

    avg_win = gross_profit / winning_trades if winning_trades > 0 else 0
    avg_loss = gross_loss / losing_trades if losing_trades > 0 else 0

    # Largest single-thread win/loss. Only positive P/L counts as a win and only
    # negative P/L as a loss, so an all-winning run reports a largest loss of 0
    # instead of its smallest win.
    largest_win = max((p for p in thread_profits if p > 0), default=0)
    largest_loss = min((p for p in thread_profits if p < 0), default=0)

    # Thread statistics (from the parsed per-thread summary lines)
    if threads:
        avg_thread_pnl = sum(thread_profits) / len(threads)
        max_level_reached = max(t['max_level'] for t in threads)
        avg_max_level = sum(t['max_level'] for t in threads) / len(threads)
        thread_win_rate = sum(1 for t in threads if t['pnl'] > 0) / len(threads) * 100
    else:
        avg_thread_pnl = max_level_reached = avg_max_level = thread_win_rate = 0

    # Re-entry analysis
    total_reentries = events['REENTRY']
    successful_reentries = events['TP_REENTRY']
    reentry_success_rate = successful_reentries / total_reentries * 100 if total_reentries > 0 else 0

    # Number of threads per maximum level reached
    level_dist = {}
    for t in threads:
        level_dist[t['max_level']] = level_dist.get(t['max_level'], 0) + 1

    with open(output_file, 'w') as f:
        f.write("=" * 100 + "\n")
        f.write("                         COMPREHENSIVE TRADE REPORT\n")
        f.write("=" * 100 + "\n")
        f.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("=" * 100 + "\n\n")

        _write_banner(f, "                              BACKTEST SETTINGS\n")
        f.write(f"{'Symbol:':<30} {SYMBOL}\n")
        f.write(f"{'Initial Balance:':<30} ${INITIAL_BALANCE:,.2f}\n")
        f.write(f"{'Risk Percent:':<30} {RISK_PERCENT}%\n")
        f.write(f"{'Base Multiplier:':<30} {BASE_MULTIPLIER}\n")
        f.write(f"{'Aggression Level:':<30} {AGGRESSION_LEVEL}\n")
        f.write(f"{'Max Spread:':<30} ${MAX_SPREAD}\n")
        f.write(f"{'Re-entry Enabled:':<30} {REENTRY_ENABLED}\n")
        f.write(f"{'Commission:':<30} {COMMISSION_PERCENT}%\n")
        f.write(f"\n{'Custom Levels:':<30} {CUSTOM_LEVELS}\n")
        f.write(f"{'Custom Multipliers:':<30} {CUSTOM_MULTIPLIERS}\n")
        f.write("\n")

        _write_banner(f, "                            PERFORMANCE SUMMARY\n")
        f.write(f"{'Initial Balance:':<30} ${initial_balance:,.2f}\n")
        f.write(f"{'Final Balance:':<30} ${final_balance:,.2f}\n")
        f.write(f"{'Final Equity:':<30} ${final_equity:,.2f}\n")
        f.write(f"{'Net Profit:':<30} ${net_profit:+,.2f}\n")
        f.write(f"{'Total Return:':<30} {total_return_pct:+.2f}%\n")
        f.write("\n")
        f.write(f"{'Gross Profit:':<30} ${gross_profit:,.2f}\n")
        f.write(f"{'Gross Loss:':<30} ${gross_loss:,.2f}\n")
        f.write(f"{'Profit Factor:':<30} {profit_factor:.2f}\n")
        f.write("\n")
        f.write(f"{'Max Drawdown:':<30} ${max_drawdown:,.2f} ({max_drawdown_pct:.2f}%)\n")
        f.write(f"{'Total Commission:':<30} ${total_commission:,.2f}\n")
        f.write("\n")

        _write_banner(f, "                             TRADE STATISTICS\n")
        f.write(f"{'Total Trades:':<30} {total_trades}\n")
        f.write(f"{'Winning Trades:':<30} {winning_trades} ({win_rate:.1f}%)\n")
        f.write(f"{'Losing Trades:':<30} {losing_trades}\n")
        f.write(f"{'Average Win:':<30} ${avg_win:,.2f}\n")
        f.write(f"{'Average Loss:':<30} ${avg_loss:,.2f}\n")
        f.write(f"{'Largest Win:':<30} ${largest_win:,.2f}\n")
        f.write(f"{'Largest Loss:':<30} ${largest_loss:,.2f}\n")
        if avg_loss > 0:
            f.write(f"{'Win/Loss Ratio:':<30} {avg_win/avg_loss:.2f}\n")
        f.write("\n")

        _write_banner(f, "                             THREAD STATISTICS\n")
        f.write(f"{'Total Threads:':<30} {summary['total_threads']}\n")
        f.write(f"{'Profitable Threads:':<30} {summary['profitable_threads']} ({thread_win_rate:.1f}%)\n")
        f.write(f"{'Average Thread P/L:':<30} ${avg_thread_pnl:+,.2f}\n")
        f.write(f"{'Max Averaging Level:':<30} {max_level_reached}\n")
        f.write(f"{'Avg Max Level:':<30} {avg_max_level:.1f}\n")
        f.write("\n")

        _write_banner(f, "                             RE-ENTRY ANALYSIS\n")
        f.write(f"{'Re-entries Opened:':<30} {total_reentries}\n")
        f.write(f"{'Successful (TP Hit):':<30} {successful_reentries} ({reentry_success_rate:.1f}%)\n")
        f.write(f"{'Abandoned:':<30} {summary['reentries_abandoned']}\n")
        f.write(f"{'Spread Filtered:':<30} {summary['spread_filtered']}\n")
        f.write("\n")

        _write_banner(f, "                              EVENT BREAKDOWN\n")
        f.write(f"{'Event Type':<25} {'Count':>10}\n")
        f.write("-" * 40 + "\n")
        for event_type, count in sorted(events.items()):
            f.write(f"{event_type:<25} {count:>10}\n")
        f.write("\n")

        _write_banner(f, "                         MAX LEVEL DISTRIBUTION\n")
        f.write(f"{'Level':<10} {'Threads':>10} {'Percent':>10} {'Bar'}\n")
        f.write("-" * 60 + "\n")
        # Percentages are relative to the parsed threads so they agree with the
        # Threads counts in the same rows.
        parsed_threads = len(threads)
        for level in sorted(level_dist):
            count = level_dist[level]
            pct = count / parsed_threads * 100 if parsed_threads > 0 else 0
            bar = "#" * int(pct / 2)
            f.write(f"{level:<10} {count:>10} {pct:>9.1f}% {bar}\n")
        f.write("\n")

        _write_banner(f, "                              THREAD SUMMARY\n")
        f.write(f"{'#':<5} {'Entry Price':>15} {'Max Lvl':>10} {'Trades':>10} {'Wins':>10} {'Win%':>10} {'P/L':>15}\n")
        f.write("-" * 80 + "\n")
        for t in threads:
            win_pct = t['wins'] / t['total_trades'] * 100 if t['total_trades'] > 0 else 0
            f.write(f"{t['id']:<5} ${t['entry_price']:>13,.2f} {t['max_level']:>10} {t['total_trades']:>10} {t['wins']:>10} {win_pct:>9.1f}% ${t['pnl']:>+13,.2f}\n")
        f.write("\n")

        _write_banner(f, "                            DETAILED TRADE LOG\n")
        f.write(f"{'Time':<20} {'Thread':>7} {'Event':<12} {'Lvl':>4} {'Price':>12} {'Lots':>10} {'Profit':>12} {'Balance':>12}\n")
        f.write("-" * 100 + "\n")
        for t in trades:
            # Events without a realised profit get an empty Profit column.
            profit_str = f"${t['profit']:+,.2f}" if t['profit'] != 0 else ""
            f.write(f"{t['timestamp']:<20} {t['thread_id']:>7} {t['event_type']:<12} {t['level']:>4} ${t['price']:>10,.2f} {t['lots']:>10.4f} {profit_str:>12} ${t['balance']:>10,.2f}\n")

        f.write("\n" + "=" * 100 + "\n")
        f.write("                              END OF REPORT\n")
        f.write("=" * 100 + "\n")

    print(f"Trade report generated: {output_file}")
    return output_file


if __name__ == "__main__":
    # Script entry point: build the report with the default input/output paths.
    generate_report()
