#!/usr/bin/env python3
"""
Parameter Optimization Engine for Trading Strategies
Premium feature with subscription tier restrictions
"""

import itertools
import os
import random
import sys
from datetime import datetime
from typing import Dict, List, Any, Optional

import numpy as np

# Add the project root to the Python path so universal_backtesting resolves
# regardless of the working directory (dirname(__file__) alone can be empty
# when the script is launched from its own directory)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from universal_backtesting import UniversalBacktestingEngine

class ParameterOptimizer:
    """Optimizes trading strategy parameters using various algorithms"""

    def __init__(self):
        self.engine = UniversalBacktestingEngine()
        self.optimization_methods = {
            'grid_search': self._grid_search,
            'random_search': self._random_search,
            'genetic_algorithm': self._genetic_algorithm
        }

    def optimize_strategy(self, market: str, strategy_name: str, base_parameters: Dict[str, Any],
                         parameter_ranges: Dict[str, List], optimization_config: Dict[str, Any],
                         data_file: Optional[str] = None, timeframe: str = '1hour') -> Dict[str, Any]:
        """
        Optimize strategy parameters

        Args:
            market: 'forex', 'crypto', 'stocks', 'nba'
            strategy_name: Name of strategy to optimize
            base_parameters: Base parameter values
            parameter_ranges: Ranges to test; a 3-item list is read as
                [min, max, step], any other length as explicit values to try
            optimization_config: Optimization settings
            data_file: Path to data file
            timeframe: Data timeframe

        Returns:
            Dict with optimization results
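
        Example (illustrative; the strategy name and data path are placeholders):
            optimizer = ParameterOptimizer()
            report = optimizer.optimize_strategy(
                market='crypto',
                strategy_name='rsi_reversal',
                base_parameters={'rsi_period': 14},
                parameter_ranges={'rsi_period': [7, 28, 7],          # [min, max, step]
                                  'oversold': [20, 25, 30, 35]},     # explicit values
                optimization_config={'method': 'grid_search',
                                     'max_evaluations': 50},
                data_file='data/BTCUSD_1h.csv'
            )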
        """
        method = optimization_config.get('method', 'grid_search')
        max_evaluations = optimization_config.get('max_evaluations', 100)
        target_metric = optimization_config.get('target_metric', 'sharpe_ratio')

        print(f"🎯 Starting parameter optimization for {strategy_name}")
        print(f"   Method: {method}")
        print(f"   Max evaluations: {max_evaluations}")
        print(f"   Target metric: {target_metric}")
        print(f"   Parameters to optimize: {list(parameter_ranges.keys())}")

        if method not in self.optimization_methods:
            return {'error': f'Unknown optimization method: {method}'}

        # Run optimization
        optimizer_func = self.optimization_methods[method]
        results = optimizer_func(
            market=market,
            strategy_name=strategy_name,
            base_parameters=base_parameters,
            parameter_ranges=parameter_ranges,
            max_evaluations=max_evaluations,
            target_metric=target_metric,
            data_file=data_file,
            timeframe=timeframe
        )

        # Format results
        return self._format_optimization_results(results, target_metric)

    def _grid_search(self, market: str, strategy_name: str, base_parameters: Dict,
                    parameter_ranges: Dict, max_evaluations: int, target_metric: str,
                    data_file: Optional[str] = None, timeframe: str = '1hour') -> List[Dict]:
        """Grid search optimization"""

        param_names = list(parameter_ranges.keys())

        # Create grid of parameter values
        param_values = []
        for param_name, param_range in parameter_ranges.items():
            if len(param_range) == 3:  # [min, max, step]
                min_val, max_val, step = param_range
                # Half-step guard keeps the endpoint inclusive without
                # floating-point overshoot past max_val
                values = np.arange(min_val, max_val + step / 2, step)
                if isinstance(min_val, int) and isinstance(max_val, int) and isinstance(step, int):
                    values = [int(v) for v in values]
                param_values.append(values)
            else:  # List of specific values
                param_values.append(param_range)
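        # Illustrative expansion: {'period': [10, 30, 5]} becomes
        # [10, 15, 20, 25, 30]; {'mode': ['fast', 'slow']} is passed through verbatim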

        # Generate all combinations
        combinations = list(itertools.product(*param_values))

        # Cap at max_evaluations by randomly subsampling the grid (three
        # parameters with 10 values each would otherwise mean 1,000 backtests)
        if len(combinations) > max_evaluations:
            indices = np.random.choice(len(combinations), max_evaluations, replace=False)
            combinations = [combinations[i] for i in indices]

        print(f"   Grid search: Testing {len(combinations)} parameter combinations")

        results = []
        for i, combo in enumerate(combinations):
            # Merge this grid point into the base parameter set
            params = base_parameters.copy()
            for param_name, value in zip(param_names, combo):
                params[param_name] = value

            # Run backtest
            try:
                backtest_result = self.engine.run_backtest(
                    market=market,
                    strategy_name=strategy_name,
                    parameters=params,
                    data_source=data_file,
                    min_trades=10,
                    timeframe=timeframe
                )

                if backtest_result.get('success'):
                    # Extract target metric
                    metric_value = self._extract_metric(backtest_result, target_metric)
                    results.append({
                        'parameters': params,
                        'metrics': backtest_result.get('results', {}),
                        'target_metric': metric_value,
                        'evaluation_number': i + 1
                    })

                # Report progress for every evaluation, successful or not
                if (i + 1) % 10 == 0:
                    print(f"   Completed {i + 1}/{len(combinations)} evaluations")

            except Exception as e:
                print(f"   Error in evaluation {i + 1}: {e}")
                continue

        return results

    def _random_search(self, market: str, strategy_name: str, base_parameters: Dict,
                      parameter_ranges: Dict, max_evaluations: int, target_metric: str,
                      data_file: Optional[str] = None, timeframe: str = '1hour') -> List[Dict]:
        """Random search optimization"""

        print(f"   Random search: Testing {max_evaluations} random parameter combinations")

        results = []
        param_names = list(parameter_ranges.keys())

        for i in range(max_evaluations):
            # Generate random parameter set
            params = base_parameters.copy()
            for param_name, param_range in parameter_ranges.items():
                if len(param_range) == 3:  # [min, max, step]
                    # Note: the step is not used here; random search samples
                    # the whole [min, max] range
                    min_val, max_val, step = param_range
                    if isinstance(min_val, float) or isinstance(max_val, float):
                        params[param_name] = random.uniform(min_val, max_val)
                    else:
                        params[param_name] = random.randint(min_val, max_val)
                else:  # List of specific values
                    params[param_name] = random.choice(param_range)

            # Run backtest
            try:
                backtest_result = self.engine.run_backtest(
                    market=market,
                    strategy_name=strategy_name,
                    parameters=params,
                    data_source=data_file,
                    min_trades=10,
                    timeframe=timeframe
                )

                if backtest_result.get('success'):
                    metric_value = self._extract_metric(backtest_result, target_metric)
                    results.append({
                        'parameters': params,
                        'metrics': backtest_result.get('results', {}),
                        'target_metric': metric_value,
                        'evaluation_number': i + 1
                    })

                # Report progress for every evaluation, successful or not
                if (i + 1) % 10 == 0:
                    print(f"   Completed {i + 1}/{max_evaluations} evaluations")

            except Exception as e:
                print(f"   Error in evaluation {i + 1}: {e}")
                continue

        return results

    def _genetic_algorithm(self, market: str, strategy_name: str, base_parameters: Dict,
                          parameter_ranges: Dict, max_evaluations: int, target_metric: str,
                          data_file: Optional[str] = None, timeframe: str = '1hour') -> List[Dict]:
        """Simple genetic algorithm optimization"""

        # Clamp the sizes so small evaluation budgets cannot produce a zero
        # population (and a division-by-zero below)
        population_size = max(2, min(20, max_evaluations // 5))
        generations = max(1, max_evaluations // population_size)

        print(f"   Genetic algorithm: {population_size} population, {generations} generations")

        # Initialize population
        population = []
        param_names = list(parameter_ranges.keys())

        for _ in range(population_size):
            individual = {}
            for param_name, param_range in parameter_ranges.items():
                if len(param_range) == 3:  # [min, max, step]
                    # As in random search, the step is not used when sampling
                    min_val, max_val, step = param_range
                    if isinstance(min_val, float) or isinstance(max_val, float):
                        individual[param_name] = random.uniform(min_val, max_val)
                    else:
                        individual[param_name] = random.randint(min_val, max_val)
                else:  # List of specific values
                    individual[param_name] = random.choice(param_range)
            population.append(individual)

        all_results = []

        for generation in range(generations):
            print(f"   Generation {generation + 1}/{generations}")

            # Evaluate population
            generation_results = []
            for individual in population:
                params = base_parameters.copy()
                params.update(individual)

                try:
                    backtest_result = self.engine.run_backtest(
                        market=market,
                        strategy_name=strategy_name,
                        parameters=params,
                        data_source=data_file,
                        min_trades=10,
                        timeframe=timeframe
                    )

                    if backtest_result.get('success'):
                        metric_value = self._extract_metric(backtest_result, target_metric)
                        result = {
                            'parameters': params,
                            'metrics': backtest_result.get('results', {}),
                            'target_metric': metric_value,
                            'generation': generation + 1
                        }
                        generation_results.append(result)
                        all_results.append(result)

                except Exception as e:
                    print(f"   Error evaluating individual in generation {generation + 1}: {e}")
                    continue

            if not generation_results:
                continue

            # Sort by fitness; _extract_metric negates max_drawdown, so higher
            # is always better
            generation_results.sort(key=lambda x: x['target_metric'], reverse=True)

            # Select top performers
            elite_size = max(2, population_size // 5)
            elites = generation_results[:elite_size]
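            # e.g. population_size=20 -> elite_size=4: the four fittest parameter
            # sets survive unchanged and also serve as parents below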

            # Create next generation
            new_population = [elite['parameters'] for elite in elites]

            # Crossover and mutation
            while len(new_population) < population_size:
                # Select parents
                parent1 = random.choice(elites)['parameters']
                parent2 = random.choice(elites)['parameters']

                # Uniform crossover: each gene comes from a randomly chosen parent
                child = {}
                for param_name in param_names:
                    if random.random() < 0.5:
                        child[param_name] = parent1[param_name]
                    else:
                        child[param_name] = parent2[param_name]

                    # Per-gene mutation: resample this parameter from its range
                    if random.random() < 0.1:  # 10% mutation rate
                        param_range = parameter_ranges[param_name]
                        if len(param_range) == 3:
                            min_val, max_val, step = param_range
                            if isinstance(min_val, float) or isinstance(max_val, float):
                                child[param_name] = random.uniform(min_val, max_val)
                            else:
                                child[param_name] = random.randint(min_val, max_val)
                        else:
                            child[param_name] = random.choice(param_range)

                new_population.append(child)

            population = new_population

        return all_results

    def _extract_metric(self, backtest_result: Dict, target_metric: str) -> float:
        """Extract target metric from backtest results"""
        results = backtest_result.get('results', {})

        metric_mapping = {
            'sharpe_ratio': lambda r: r.get('sharpe_ratio', 0),
            'total_profit': lambda r: r.get('total_profit', 0),
            'win_rate': lambda r: r.get('win_rate', 0),
            'profit_factor': lambda r: r.get('profit_factor', 1),
            'max_drawdown': lambda r: -abs(r.get('max_drawdown', 0)),  # Negative for minimization
            'total_trades': lambda r: backtest_result.get('total_trades', 0),  # top-level field, not in 'results'
            'expectancy': lambda r: r.get('expectancy', 0)
        }
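        # Sign-flip example: a 25% drawdown maps to -0.25 and a 10% drawdown to
        # -0.10, so maximizing the target metric prefers the smaller loss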

        if target_metric in metric_mapping:
            return metric_mapping[target_metric](results)

        # Default to Sharpe ratio
        return results.get('sharpe_ratio', 0)

    def _format_optimization_results(self, results: List[Dict], target_metric: str) -> Dict[str, Any]:
        """Format optimization results for display"""

        if not results:
            return {'error': 'No successful optimization results'}

        # Sort by target metric, descending (drawdown is pre-negated, so higher
        # is always better)
        sorted_results = sorted(results, key=lambda x: x['target_metric'], reverse=True)

        # Get best result
        best_result = sorted_results[0]

        # Calculate statistics
        target_values = [r['target_metric'] for r in results]
        target_values = [v for v in target_values if isinstance(v, (int, float)) and not np.isnan(v)]

        if target_values:
            best_value = max(target_values)
            worst_value = min(target_values)
            # Cast numpy scalars to plain floats so the report stays JSON-serializable
            avg_value = float(np.mean(target_values))
            std_value = float(np.std(target_values))
        else:
            best_value = worst_value = avg_value = std_value = 0

        # Get top 5 parameter sets
        top_5 = sorted_results[:5]

        return {
            'success': True,
            'total_evaluations': len(results),
            'best_result': {
                'parameters': best_result['parameters'],
                'target_metric': best_result['target_metric'],
                'metrics': best_result['metrics']
            },
            'statistics': {
                'best_value': round(best_value, 4),
                'worst_value': round(worst_value, 4),
                'average_value': round(avg_value, 4),
                'standard_deviation': round(std_value, 4),
                'target_metric': target_metric
            },
            'top_5_results': [
                {
                    'rank': i + 1,
                    'parameters': result['parameters'],
                    'target_metric': result['target_metric'],
                    'key_metrics': {
                        'win_rate': result['metrics'].get('win_rate', 0),
                        'total_profit': result['metrics'].get('total_profit', 0),
                        'sharpe_ratio': result['metrics'].get('sharpe_ratio', 0)
                    }
                }
                for i, result in enumerate(top_5)
            ],
            'all_results': sorted_results,
            'optimization_timestamp': datetime.now().isoformat()
        }
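

# Minimal usage sketch (illustrative only): 'sma_crossover' and the CSV path
# are placeholders; substitute any strategy registered with
# UniversalBacktestingEngine and a real data file.
if __name__ == '__main__':
    optimizer = ParameterOptimizer()
    report = optimizer.optimize_strategy(
        market='forex',
        strategy_name='sma_crossover',                  # hypothetical strategy name
        base_parameters={'stop_loss_pips': 20},
        parameter_ranges={'fast_period': [5, 20, 5],    # [min, max, step]
                          'slow_period': [20, 60, 10]},
        optimization_config={'method': 'random_search',
                             'max_evaluations': 30,
                             'target_metric': 'sharpe_ratio'},
        data_file='data/EURUSD_1h.csv'                  # hypothetical path
    )
    if report.get('success'):
        print(f"Best parameters: {report['best_result']['parameters']}")
        print(f"Best {report['statistics']['target_metric']}: {report['best_result']['target_metric']}")
    else:
        print(f"Optimization failed: {report.get('error')}")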
