#!/usr/bin/env python3
"""
Visualization Service for Backtesting AI
Generates charts and graphs for backtest results using Matplotlib, Seaborn, and Plotly
"""

import os
import sys
import json
import base64
import io
from typing import Dict, List, Any, Optional
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.utils import PlotlyJSONEncoder

# Switch Matplotlib to the non-interactive "Agg" backend before any figure is
# created. Charts in this module are only rendered into in-memory PNG buffers
# (never shown on screen), so no GUI/display is needed — presumably because the
# script runs headless when invoked from the Next.js API (see main()).
plt.switch_backend('Agg')

class BacktestVisualizer:
    """Generate charts for backtest results with Matplotlib or Plotly.

    Every public ``generate_*`` method takes a list of trade dicts and returns a
    JSON-serialisable dict. Matplotlib output is a base64-encoded PNG under
    ``"image_base64"``; Plotly output is a JSON string under ``"chart_data"``.
    Any ``chart_type`` other than ``"plotly"`` falls back to Matplotlib.
    Invalid input is reported as ``{"error": "<message>"}`` instead of raising.
    """

    # Colour palette shared by all charts.
    PRIMARY_COLOR = '#667eea'
    PROFIT_COLOR = '#27ae60'
    LOSS_COLOR = '#e74c3c'
    NEUTRAL_COLOR = '#95a5a6'
    # Periods per year used to annualise the per-trade Sharpe ratio and volatility.
    ANNUALIZATION_PERIODS = 252

    def __init__(self):
        # Set up plotting style
        plt.style.use('default')
        sns.set_palette("husl")

        # Create output directory. Note: the methods of this class return charts
        # in memory; nothing here writes into this directory.
        self.output_dir = "data/charts"
        os.makedirs(self.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def generate_profit_chart(self, trades: List[Dict[str, Any]], market: str = "general",
                              chart_type: str = "matplotlib") -> Dict[str, Any]:
        """Chart cumulative profit and per-trade profit over time.

        Args:
            trades: Trade dicts; each needs ``profit`` and ``date``. Non-numeric
                profits count as 0 and rows whose date cannot be parsed are dropped.
            market: Market name used in the chart title.
            chart_type: ``"plotly"`` for Plotly JSON, anything else for Matplotlib.

        Returns:
            Chart dict including ``total_return``, ``total_trades`` and ``win_rate``
            (0 when trades carry no ``outcome``), or ``{"error": ...}``.
        """
        if not trades:
            return {"error": "No trade data provided"}

        df, error = self._prepare_trades(trades, ('profit', 'date'))
        if error is not None:
            return error

        try:
            df['date'] = pd.to_datetime(df['date'], errors='coerce')
        except (TypeError, ValueError) as e:
            return {"error": f"Invalid date format in trade data: {e}"}
        # Unparseable dates were coerced to NaT; drop those rows.
        df = df.dropna(subset=['date'])

        if df.empty:
            return {"error": "No valid trades with dates and profits"}

        # Stable sort keeps same-day trades in input order so the cumulative
        # curve is deterministic.
        df = df.sort_values('date', kind='mergesort')
        df['cumulative_profit'] = df['profit'].cumsum()

        if chart_type == "plotly":
            return self._generate_plotly_profit_chart(df, market)
        return self._generate_matplotlib_profit_chart(df, market)

    def generate_win_loss_chart(self, trades: List[Dict[str, Any]], market: str = "general",
                                chart_type: str = "matplotlib") -> Dict[str, Any]:
        """Chart the win/loss split and the profit distribution.

        Args:
            trades: Trade dicts; each needs ``outcome`` and ``profit``. Missing
                outcomes are treated as ``'neutral'``, non-numeric profits as 0.
            market: Market name used in the chart title.
            chart_type: ``"plotly"`` for Plotly JSON, anything else for Matplotlib.

        Returns:
            Chart dict including ``win_rate``, ``total_trades``, ``avg_profit`` and
            ``median_profit``, or ``{"error": ...}``.
        """
        if not trades:
            return {"error": "No trade data provided"}

        # Both renderers need 'profit' (histogram / summary stats), so it is
        # required here like in every other generator.
        df, error = self._prepare_trades(trades, ('outcome', 'profit'))
        if error is not None:
            return error

        if chart_type == "plotly":
            return self._generate_plotly_win_loss_chart(df, market)
        return self._generate_matplotlib_win_loss_chart(df, market)

    def generate_market_comparison_chart(self, results: Dict[str, List[Dict[str, Any]]],
                                         chart_type: str = "matplotlib") -> Dict[str, Any]:
        """Compare win rate and total return across markets.

        Args:
            results: Mapping of market name -> trade dicts. Markets with no trades
                or lacking ``profit``/``outcome`` columns are skipped.
            chart_type: ``"plotly"`` for Plotly JSON, anything else for Matplotlib.

        Returns:
            Chart dict with per-market stats under ``"markets"``, or
            ``{"error": ...}`` when no market is usable.
        """
        if not results:
            return {"error": "No results data provided"}

        market_data = {}
        for market, trades in results.items():
            summary = self._summarize_market(trades)
            if summary is not None:
                market_data[market] = summary

        if not market_data:
            return {"error": "No market has trade data with 'profit' and 'outcome' columns"}

        if chart_type == "plotly":
            return self._generate_plotly_comparison_chart(market_data)
        return self._generate_matplotlib_comparison_chart(market_data)

    def generate_risk_metrics_chart(self, trades: List[Dict[str, Any]], market: str = "general",
                                    chart_type: str = "matplotlib") -> Dict[str, Any]:
        """Compute risk metrics and chart drawdown (plus more panels in Matplotlib).

        Args:
            trades: Trade dicts; each needs ``profit`` and ``outcome``. Trades are
                evaluated in the order given (NOTE(review): not re-sorted by date).
            market: Market name used in the chart title.
            chart_type: ``"plotly"`` for Plotly JSON, anything else for Matplotlib.

        Returns:
            Chart dict with ``"risk_metrics"`` (see ``_compute_risk_metrics``), or
            ``{"error": ...}``.
        """
        if not trades:
            return {"error": "No trade data provided"}

        df, error = self._prepare_trades(trades, ('profit', 'outcome'))
        if error is not None:
            return error

        if df.empty:
            return {"error": "No valid profit data for risk analysis"}

        risk_data = self._compute_risk_metrics(df)

        if chart_type == "plotly":
            return self._generate_plotly_risk_chart(risk_data, df, market)
        return self._generate_matplotlib_risk_chart(risk_data, df, market)

    # ------------------------------------------------------------------
    # Data helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _prepare_trades(trades: List[Dict[str, Any]], required: tuple) -> tuple:
        """Build a cleaned DataFrame from ``trades``.

        Columns in ``required`` are checked in order. ``profit`` (if present) is
        coerced to numeric with unparseable/missing values set to 0, and missing
        ``outcome`` values become ``'neutral'``.

        Returns:
            ``(df, None)`` on success, or ``(None, {"error": ...})`` naming the
            first missing required column.
        """
        df = pd.DataFrame(trades)
        for column in required:
            if column not in df.columns:
                return None, {"error": f"Trade data missing '{column}' column"}
        if 'profit' in df.columns:
            df['profit'] = pd.to_numeric(df['profit'], errors='coerce').fillna(0)
        if 'outcome' in df.columns:
            df['outcome'] = df['outcome'].fillna('neutral')
        return df, None

    @staticmethod
    def _win_rate(df: pd.DataFrame) -> float:
        """Percentage of rows whose outcome is ``'win'``; 0.0 without outcome data."""
        if 'outcome' not in df.columns or df.empty:
            return 0.0
        return float((df['outcome'] == 'win').mean() * 100)

    @staticmethod
    def _safe_float(value: Any) -> float:
        """Convert ``value`` to float, mapping NaN, +/-inf and unconvertible values to 0.0.

        Keeps results JSON-compatible: ``json.dumps`` would otherwise emit the
        non-standard tokens ``NaN``/``Infinity``.
        """
        try:
            number = float(value)
        except (TypeError, ValueError):
            return 0.0
        if pd.isna(number) or number in (float('inf'), float('-inf')):
            return 0.0
        return number

    @staticmethod
    def _drawdown(profit: pd.Series) -> pd.Series:
        """Amount by which cumulative profit sits below its running peak (always <= 0)."""
        cumulative = profit.cumsum()
        return cumulative - cumulative.cummax()

    @classmethod
    def _summarize_market(cls, trades: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Return total return / win rate / trade count for one market, or None if unusable."""
        if not trades:
            return None
        df, error = cls._prepare_trades(trades, ('profit', 'outcome'))
        if error is not None:
            return None
        return {
            'total_return': cls._safe_float(df['profit'].sum()),
            'win_rate': cls._win_rate(df),
            'total_trades': len(df),
        }

    @classmethod
    def _compute_risk_metrics(cls, df: pd.DataFrame) -> Dict[str, Any]:
        """Summary risk statistics for a frame prepared by ``_prepare_trades``.

        Each trade's profit is treated as one period's return. Sharpe ratio
        (0% risk-free rate) and volatility are annualised with
        sqrt(ANNUALIZATION_PERIODS); both are 0 when the sample standard deviation
        is undefined (single trade) or zero. Averages over an empty win/loss
        subset are 0.
        """
        returns = df['profit']
        annualizer = cls.ANNUALIZATION_PERIODS ** 0.5
        returns_std = returns.std()
        returns_mean = returns.mean()

        if pd.notna(returns_std) and returns_std > 0 and pd.notna(returns_mean):
            sharpe_ratio = returns_mean / returns_std * annualizer
        else:
            sharpe_ratio = 0.0
        volatility = returns_std * annualizer if pd.notna(returns_std) else 0.0

        win_profits = returns[df['outcome'] == 'win']
        loss_profits = returns[df['outcome'] == 'loss']

        return {
            'max_drawdown': cls._safe_float(cls._drawdown(returns).min()),
            'sharpe_ratio': cls._safe_float(sharpe_ratio),
            'volatility': cls._safe_float(volatility),
            'win_rate': cls._win_rate(df),
            'total_trades': len(df),
            'avg_win': cls._safe_float(win_profits.mean()),
            'avg_loss': cls._safe_float(loss_profits.mean()),
        }

    def _profit_summary(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Stats returned alongside both profit-chart formats."""
        return {
            "total_return": self._safe_float(df['cumulative_profit'].iloc[-1]) if len(df) > 0 else 0.0,
            "total_trades": len(df),
            "win_rate": self._win_rate(df),
        }

    def _win_loss_summary(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Stats returned alongside both win/loss-chart formats."""
        return {
            "win_rate": self._win_rate(df),
            "total_trades": len(df),
            "avg_profit": self._safe_float(df['profit'].mean()),
            "median_profit": self._safe_float(df['profit'].median()),
        }

    def _outcome_color(self, outcome: Any) -> str:
        """Green for wins, red for losses, grey for anything else (e.g. 'neutral')."""
        if outcome == 'win':
            return self.PROFIT_COLOR
        if outcome == 'loss':
            return self.LOSS_COLOR
        return self.NEUTRAL_COLOR

    @staticmethod
    def _fig_to_base64(fig) -> str:
        """Lay out ``fig``, render it as a 100-dpi PNG and return it base64-encoded.

        The caller remains responsible for closing the figure.
        """
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    # ------------------------------------------------------------------
    # Renderers
    # ------------------------------------------------------------------

    def _generate_matplotlib_profit_chart(self, df: pd.DataFrame, market: str) -> Dict[str, Any]:
        """Render cumulative profit (top) and per-trade P/L bars (bottom) as a PNG."""
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
        try:
            dates = df['date']
            cumulative = df['cumulative_profit']

            # Cumulative profit, shaded blue above zero and red below.
            ax1.plot(dates, cumulative, linewidth=2, color=self.PRIMARY_COLOR)
            ax1.fill_between(dates, cumulative, 0, where=(cumulative >= 0),
                             color=self.PRIMARY_COLOR, alpha=0.3)
            ax1.fill_between(dates, cumulative, 0, where=(cumulative < 0),
                             color=self.LOSS_COLOR, alpha=0.3)
            ax1.set_title(f'{market.title()} Backtest - Cumulative Profit Over Time',
                          fontsize=14, fontweight='bold')
            ax1.set_ylabel('Cumulative Profit ($)')
            ax1.grid(True, alpha=0.3)

            # Per-trade profit bars
            colors = [self.PROFIT_COLOR if x >= 0 else self.LOSS_COLOR for x in df['profit']]
            ax2.bar(dates, df['profit'], color=colors, alpha=0.7)
            ax2.set_title('Daily Profit/Loss', fontsize=12)
            ax2.set_ylabel('Profit ($)')
            ax2.set_xlabel('Date')
            ax2.grid(True, alpha=0.3)

            image_base64 = self._fig_to_base64(fig)
        finally:
            plt.close(fig)  # release the figure even if drawing failed

        return {
            "chart_type": "profit_over_time",
            "market": market,
            "format": "matplotlib",
            "image_base64": image_base64,
            **self._profit_summary(df),
        }

    def _generate_plotly_profit_chart(self, df: pd.DataFrame, market: str) -> Dict[str, Any]:
        """Render cumulative profit (left axis) and per-trade P/L bars (right axis) as Plotly JSON."""
        fig = go.Figure()

        fig.add_trace(go.Scatter(
            x=df['date'],
            y=df['cumulative_profit'],
            mode='lines',
            name='Cumulative Profit',
            line=dict(color=self.PRIMARY_COLOR, width=3),
            fill='tozeroy'
        ))

        fig.add_trace(go.Bar(
            x=df['date'],
            y=df['profit'],
            name='Daily P/L',
            marker_color=[self.PROFIT_COLOR if x >= 0 else self.LOSS_COLOR for x in df['profit']],
            opacity=0.7,
            yaxis='y2'
        ))

        fig.update_layout(
            title=f'{market.title()} Backtest - Profit Analysis',
            xaxis_title='Date',
            yaxis_title='Cumulative Profit ($)',
            yaxis2=dict(
                title='Daily Profit ($)',
                overlaying='y',
                side='right'
            ),
            showlegend=True,
            template='plotly_white'
        )

        return {
            "chart_type": "profit_over_time",
            "market": market,
            "format": "plotly",
            "chart_data": json.dumps(fig, cls=PlotlyJSONEncoder),
            **self._profit_summary(df),
        }

    def _generate_matplotlib_win_loss_chart(self, df: pd.DataFrame, market: str) -> Dict[str, Any]:
        """Render an outcome pie chart and a profit histogram as a PNG."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        try:
            # Outcome distribution. str() guards against non-string outcome values.
            outcome_counts = df['outcome'].value_counts()
            ax1.pie(outcome_counts.values,
                    labels=[str(outcome).title() for outcome in outcome_counts.index],
                    autopct='%1.1f%%',
                    colors=[self._outcome_color(outcome) for outcome in outcome_counts.index],
                    startangle=90)
            ax1.set_title('Win/Loss Distribution', fontsize=14, fontweight='bold')

            # Profit distribution histogram with the mean marked.
            mean_profit = df['profit'].mean()
            ax2.hist(df['profit'], bins=20, alpha=0.7, color=self.PRIMARY_COLOR, edgecolor='black')
            ax2.axvline(mean_profit, color='red', linestyle='--', linewidth=2,
                        label=f'Mean: ${mean_profit:.2f}')
            ax2.set_title('Profit Distribution', fontsize=14, fontweight='bold')
            ax2.set_xlabel('Profit ($)')
            ax2.set_ylabel('Frequency')
            ax2.legend()
            ax2.grid(True, alpha=0.3)

            image_base64 = self._fig_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "win_loss_analysis",
            "market": market,
            "format": "matplotlib",
            "image_base64": image_base64,
            **self._win_loss_summary(df),
        }

    def _generate_plotly_win_loss_chart(self, df: pd.DataFrame, market: str) -> Dict[str, Any]:
        """Render an outcome pie chart as Plotly JSON."""
        outcome_counts = df['outcome'].value_counts()

        fig = go.Figure()
        fig.add_trace(go.Pie(
            labels=[str(outcome).title() for outcome in outcome_counts.index],
            values=outcome_counts.values,
            marker_colors=[self._outcome_color(outcome) for outcome in outcome_counts.index],
            title="Win/Loss Distribution"
        ))
        fig.update_layout(
            title=f'{market.title()} Backtest - Win/Loss Analysis',
            template='plotly_white'
        )

        return {
            "chart_type": "win_loss_analysis",
            "market": market,
            "format": "plotly",
            "chart_data": json.dumps(fig, cls=PlotlyJSONEncoder),
            **self._win_loss_summary(df),
        }

    def _generate_matplotlib_comparison_chart(self, market_data: Dict[str, Dict]) -> Dict[str, Any]:
        """Render side-by-side win-rate and total-return bar charts as a PNG."""
        markets = list(market_data)
        win_rates = [market_data[m]['win_rate'] for m in markets]
        total_returns = [market_data[m]['total_return'] for m in markets]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
        try:
            # Win rate comparison
            bars1 = ax1.bar(markets, win_rates, color=self.PRIMARY_COLOR, alpha=0.7)
            ax1.set_title('Win Rate Comparison (%)', fontsize=14, fontweight='bold')
            ax1.set_ylabel('Win Rate (%)')
            ax1.tick_params(axis='x', rotation=45)
            for bar, rate in zip(bars1, win_rates):
                ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                         f'{rate:.1f}%', ha='center', va='bottom', fontweight='bold')

            # Total return comparison
            colors = [self.PROFIT_COLOR if r >= 0 else self.LOSS_COLOR for r in total_returns]
            bars2 = ax2.bar(markets, total_returns, color=colors, alpha=0.7)
            ax2.set_title('Total Return Comparison ($)', fontsize=14, fontweight='bold')
            ax2.set_ylabel('Total Return ($)')
            ax2.tick_params(axis='x', rotation=45)
            for bar, ret in zip(bars2, total_returns):
                height = bar.get_height()
                # Place the label 1% beyond the bar's tip (above positive bars,
                # below negative ones); zero bars get a small fixed offset.
                label_y = height * 1.01 if height != 0 else 0.01
                ax2.text(bar.get_x() + bar.get_width() / 2, label_y, f'${ret:.2f}',
                         ha='center', va='bottom' if ret >= 0 else 'top', fontweight='bold')

            image_base64 = self._fig_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "market_comparison",
            "format": "matplotlib",
            "image_base64": image_base64,
            "markets": market_data
        }

    def _generate_plotly_comparison_chart(self, market_data: Dict[str, Dict]) -> Dict[str, Any]:
        """Render win rate (left axis) and total return (right axis) bars as Plotly JSON."""
        markets = list(market_data)
        win_rates = [market_data[m]['win_rate'] for m in markets]
        total_returns = [market_data[m]['total_return'] for m in markets]

        fig = go.Figure()
        fig.add_trace(go.Bar(
            x=markets,
            y=win_rates,
            name='Win Rate (%)',
            marker_color=self.PRIMARY_COLOR,
            opacity=0.7
        ))
        fig.add_trace(go.Bar(
            x=markets,
            y=total_returns,
            name='Total Return ($)',
            marker_color=self.PROFIT_COLOR,
            opacity=0.7,
            yaxis='y2'
        ))
        fig.update_layout(
            title='Market Performance Comparison',
            xaxis_title='Market',
            yaxis_title='Win Rate (%)',
            yaxis2=dict(
                title='Total Return ($)',
                overlaying='y',
                side='right'
            ),
            showlegend=True,
            template='plotly_white'
        )

        return {
            "chart_type": "market_comparison",
            "format": "plotly",
            "chart_data": json.dumps(fig, cls=PlotlyJSONEncoder),
            "markets": market_data
        }

    def _plot_monthly_returns(self, ax, df: pd.DataFrame) -> None:
        """Draw profit summed per calendar month on ``ax``, or a placeholder if no dates parse."""
        dates = None
        if 'date' in df.columns:
            try:
                dates = pd.to_datetime(df['date'], errors='coerce')
            except (TypeError, ValueError):
                dates = None  # e.g. mixed time zones: fall back to the placeholder

        if dates is None or dates.isna().all():
            ax.text(0.5, 0.5, 'No date data available', ha='center', va='center')
            ax.set_title('Monthly Returns', fontsize=14, fontweight='bold')
            ax.axis('off')
            return

        valid = dates.notna()
        monthly_returns = df.loc[valid, 'profit'].groupby(dates[valid].dt.to_period('M')).sum()
        monthly_returns.plot(kind='bar', ax=ax, color=self.PRIMARY_COLOR, alpha=0.7)
        ax.set_title('Monthly Returns', fontsize=14, fontweight='bold')
        ax.set_ylabel('Return ($)')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)

    def _plot_profit_factor(self, ax, df: pd.DataFrame) -> None:
        """Draw gross profit, gross loss and profit factor bars on ``ax``.

        Profit factor (gross profit / gross loss) is undefined when there are no
        losing trades; it is then drawn as an empty bar labelled "N/A" rather
        than an infinite bar.
        """
        profit = df['profit']
        gross_profit = float(profit[profit > 0].sum())
        gross_loss = float(abs(profit[profit < 0].sum()))
        profit_factor = gross_profit / gross_loss if gross_loss > 0 else None

        categories = ['Gross Profit', 'Gross Loss', 'Profit Factor']
        heights = [gross_profit, gross_loss, profit_factor if profit_factor is not None else 0.0]
        labels = [
            f'${gross_profit:.2f}',
            f'${gross_loss:.2f}',
            f'{profit_factor:.2f}' if profit_factor is not None else 'N/A',
        ]

        bars = ax.bar(categories, heights,
                      color=[self.PROFIT_COLOR, self.LOSS_COLOR, self.PRIMARY_COLOR], alpha=0.7)
        ax.set_title('Profit Analysis', fontsize=14, fontweight='bold')
        ax.set_ylabel('Amount ($)')
        ax.grid(True, alpha=0.3)

        for bar, label in zip(bars, labels):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, label,
                    ha='center', va='bottom', fontweight='bold')

    def _generate_matplotlib_risk_chart(self, risk_data: Dict, df: pd.DataFrame, market: str) -> Dict[str, Any]:
        """Render a 2x2 risk dashboard as a PNG.

        Panels: metrics summary, drawdown per trade, monthly returns (when dates
        are available) and gross profit / gross loss / profit factor.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))
        try:
            # Risk metrics summary text
            metrics_text = "\n".join([
                f"Max Drawdown: ${risk_data['max_drawdown']:.2f}",
                f"Sharpe Ratio: {risk_data['sharpe_ratio']:.2f}",
                f"Volatility: {risk_data['volatility']:.2f}",
                f"Win Rate: {risk_data['win_rate']:.1f}%",
                f"Total Trades: {risk_data['total_trades']}",
                f"Avg Win: ${risk_data['avg_win']:.2f}",
                f"Avg Loss: ${risk_data['avg_loss']:.2f}",
            ])
            ax1.text(0.1, 0.5, metrics_text, fontsize=12, verticalalignment='center',
                     bbox=dict(boxstyle="round,pad=0.3", facecolor='#f8f9fa', alpha=0.8))
            ax1.set_title('Risk Metrics Summary', fontsize=14, fontweight='bold')
            ax1.set_xlim(0, 1)
            ax1.set_ylim(0, 1)
            ax1.axis('off')

            # Drawdown per trade, in the order the trades were supplied.
            drawdown = self._drawdown(df['profit'])
            trade_numbers = range(len(drawdown))
            ax2.fill_between(trade_numbers, drawdown, 0, color=self.LOSS_COLOR, alpha=0.3)
            ax2.plot(trade_numbers, drawdown.to_numpy(), color=self.LOSS_COLOR, linewidth=2)
            ax2.set_title('Drawdown Over Time', fontsize=14, fontweight='bold')
            ax2.set_xlabel('Trade Number')
            ax2.set_ylabel('Drawdown ($)')
            ax2.grid(True, alpha=0.3)

            self._plot_monthly_returns(ax3, df)
            self._plot_profit_factor(ax4, df)

            image_base64 = self._fig_to_base64(fig)
        finally:
            plt.close(fig)

        return {
            "chart_type": "risk_analysis",
            "market": market,
            "format": "matplotlib",
            "image_base64": image_base64,
            "risk_metrics": risk_data
        }

    def _generate_plotly_risk_chart(self, risk_data: Dict, df: pd.DataFrame, market: str) -> Dict[str, Any]:
        """Render the per-trade drawdown curve as Plotly JSON."""
        drawdown = self._drawdown(df['profit'])

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=list(range(len(drawdown))),
            y=drawdown,
            mode='lines',
            name='Drawdown',
            line=dict(color=self.LOSS_COLOR, width=2),
            fill='tozeroy',
            fillcolor='rgba(231, 76, 60, 0.3)'
        ))
        fig.update_layout(
            title=f'{market.title()} Backtest - Risk Analysis',
            xaxis_title='Trade Number',
            yaxis_title='Drawdown ($)',
            showlegend=True,
            template='plotly_white'
        )

        return {
            "chart_type": "risk_analysis",
            "market": market,
            "format": "plotly",
            "chart_data": json.dumps(fig, cls=PlotlyJSONEncoder),
            "risk_metrics": risk_data
        }

def _run_command(command: str, raw_params: Optional[str]) -> int:
    """Run one chart command and print its result as a single JSON object.

    Args:
        command: One of ``profit_chart``, ``win_loss_chart``, ``comparison_chart``
            or ``risk_chart``.
        raw_params: JSON object text with the command's parameters, or None/empty
            for no parameters.

    Returns:
        Process exit status. 1 if anything raised (invalid JSON, missing required
        parameter, rendering failure); 0 otherwise, including for an unknown
        command, which is reported as an ``{"error": ...}`` object.
    """
    try:
        # Parsing happens inside the try so the caller always receives JSON,
        # never a Python traceback.
        params = json.loads(raw_params) if raw_params else {}
        visualizer = BacktestVisualizer()

        dispatch = {
            "profit_chart": lambda: visualizer.generate_profit_chart(
                params['trades'],
                params.get('market', 'general'),
                params.get('chart_type', 'matplotlib'),
            ),
            "win_loss_chart": lambda: visualizer.generate_win_loss_chart(
                params['trades'],
                params.get('market', 'general'),
                params.get('chart_type', 'matplotlib'),
            ),
            "comparison_chart": lambda: visualizer.generate_market_comparison_chart(
                params['results'],
                params.get('chart_type', 'matplotlib'),
            ),
            "risk_chart": lambda: visualizer.generate_risk_metrics_chart(
                params['trades'],
                params.get('market', 'general'),
                params.get('chart_type', 'matplotlib'),
            ),
        }

        handler = dispatch.get(command)
        result = handler() if handler else {"error": f"Unknown command: {command}"}

        # Output JSON for API consumption
        print(json.dumps(result))
        return 0
    except Exception as e:  # top-level boundary: report every failure as JSON
        print(json.dumps({"error": str(e)}))
        return 1


def _run_smoke_test() -> None:
    """Render each chart type from random sample trades and print a short report."""
    import random
    from datetime import datetime, timedelta

    visualizer = BacktestVisualizer()

    # 50 sample trades, two days apart, ending roughly today.
    start = datetime.now() - timedelta(days=100)
    sample_trades = [
        {
            'id': f'trade_{i}',
            'date': (start + timedelta(days=2 * i)).strftime('%Y-%m-%d'),
            'outcome': 'win' if random.random() > 0.4 else 'loss',
            'profit': (random.random() - 0.4) * 200,  # Random profit/loss
            'label': f'Trade {i}',
            'context': {},
        }
        for i in range(50)
    ]
    comparison_data = {
        'crypto': sample_trades,
        'stocks': sample_trades[:20],  # Shorter list
        'forex': sample_trades[10:30],  # Middle section
    }

    checks = [
        ("Profit Chart (Matplotlib)",
         lambda: visualizer.generate_profit_chart(sample_trades, 'crypto', 'matplotlib')),
        ("Win/Loss Chart (Plotly)",
         lambda: visualizer.generate_win_loss_chart(sample_trades, 'crypto', 'plotly')),
        ("Comparison Chart",
         lambda: visualizer.generate_market_comparison_chart(comparison_data, 'matplotlib')),
        ("Risk Chart (Matplotlib)",
         lambda: visualizer.generate_risk_metrics_chart(sample_trades, 'crypto', 'matplotlib')),
    ]

    print("🧪 Testing Visualization Service")
    print("=" * 50)

    for title, build in checks:
        print(f"\n📊 Testing {title}:")
        result = build()
        if 'error' in result:
            print(f"❌ {result['error']}")
            continue
        payload = result.get('image_base64') or result.get('chart_data', '')
        print(f"✅ Generated chart: {len(payload)} chars")

    print("\n✅ Visualization service test completed!")


def main():
    """Entry point.

    With arguments (``<command> [<json-params>]``) it runs in API mode (called from
    Next.js): one command is executed, its result printed as JSON, and the process
    exits with status 1 on failure. Without arguments it runs a local smoke test.
    """
    if len(sys.argv) > 1:
        status = _run_command(sys.argv[1], sys.argv[2] if len(sys.argv) > 2 else None)
        if status:
            sys.exit(status)
    else:
        _run_smoke_test()

if __name__ == "__main__":
    main()
