"""MusicGen wrapper for text-to-music generation."""
import torch
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from typing import Optional, List
from pathlib import Path
import sys

sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))

from shared.config import settings
from shared.logging_config import setup_logger

logger = setup_logger("audio.musicgen")


class MusicGenerator:
    """Wrapper for MusicGen model."""
    
    def __init__(self, model_name: str = "facebook/musicgen-small"):
        """Initialize MusicGen model.
        
        Args:
            model_name: Hugging Face model ID (facebook/musicgen-small, -medium, or -large)
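
        Example (illustrative):
            >>> gen = MusicGenerator("facebook/musicgen-medium")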
        """
        self.model_name = model_name
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() and settings.use_gpu else "cpu"
        logger.info(f"Using device: {self.device}")
        
    def load_model(self):
        """Load the MusicGen model into memory."""
        if self.model is not None:
            logger.info("Model already loaded")
            return
            
        logger.info(f"Loading MusicGen model: {self.model_name}")
        try:
            self.model = MusicGen.get_pretrained(self.model_name, device=self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
    
    def unload_model(self):
        """Unload model to free GPU memory."""
        if self.model is not None:
            del self.model
            self.model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Model unloaded")
    
    def generate(
        self,
        prompts: List[str],
        duration: int = 30,
        temperature: float = 1.0,
        top_k: int = 250,
        top_p: float = 0.0,
        cfg_coef: float = 3.0
    ) -> torch.Tensor:
        """Generate music from text prompts.
        
        Args:
            prompts: List of text descriptions
            duration: Duration in seconds
            temperature: Sampling temperature (higher = more random)
            top_k: Top-k sampling parameter
            top_p: Top-p (nucleus) sampling parameter
            cfg_coef: Classifier-free guidance coefficient
            
        Returns:
            Generated audio tensor of shape (batch_size, num_channels, num_frames)
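
        Example (illustrative sketch; prompt text is a placeholder and the
        checkpoint is assumed to download successfully):
            >>> gen = MusicGenerator()
            >>> audio = gen.generate(["lo-fi hip hop beat with mellow piano"], duration=10)
            >>> audio.shape[0]  # one waveform per prompt
            1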
        """
        if self.model is None:
            self.load_model()
        
        logger.info(f"Generating music for {len(prompts)} prompts, duration={duration}s")
        
        try:
            # Set generation parameters
            self.model.set_generation_params(
                duration=duration,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                cfg_coef=cfg_coef
            )
            
            # Generate
            with torch.no_grad():
                audio = self.model.generate(prompts)
            
            logger.info(f"Generated audio shape: {audio.shape}")
            return audio
            
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise
    
    def save_audio(
        self,
        audio: torch.Tensor,
        output_path: str,
        sample_rate: Optional[int] = None,
        format: str = "wav"
    ) -> str:
        """Save generated audio to file.
        
        Args:
            audio: Audio tensor from generate(), shape (batch_size, num_channels, num_frames)
            output_path: Output file path (without extension)
            sample_rate: Sample rate (uses model's if None)
            format: Audio format (wav, mp3)
            
        Returns:
            Path to saved file
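
        Example (illustrative sketch; the output path is a placeholder):
            >>> gen = MusicGenerator()
            >>> audio = gen.generate(["ambient pad"], duration=5)
            >>> gen.save_audio(audio, "/tmp/ambient_pad")  # -> '/tmp/ambient_pad.wav'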
        """
        if sample_rate is None:
            # Fall back to the model's native rate (32 kHz default if not loaded)
            sample_rate = self.sample_rate
        
        logger.info(f"Saving audio to: {output_path}.{format}")
        
        try:
            # audiocraft's audio_write expects a CPU tensor of shape (channels, time),
            # so drop the batch dimension and move the audio off the GPU
            audio_cpu = audio.squeeze(0).cpu()
            
            # Save using audiocraft's audio_write with loudness normalization
            output_file = audio_write(
                output_path,
                audio_cpu,
                sample_rate,
                format=format,
                strategy="loudness"  # Normalize to a target loudness
            )
            
            logger.info(f"Audio saved successfully: {output_file}")
            return str(output_file)
            
        except Exception as e:
            logger.error(f"Failed to save audio: {e}")
            raise
    
    @property
    def sample_rate(self) -> int:
        """Get model sample rate."""
        if self.model is None:
            return 32000  # Default for MusicGen
        return self.model.sample_rate


# Singleton instance
_music_generator = None


def get_music_generator() -> MusicGenerator:
    """Get singleton MusicGenerator instance.
    
    Returns:
        MusicGenerator instance
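
    Example (illustrative; assumes settings.musicgen_model names a valid checkpoint):
        >>> gen = get_music_generator()
        >>> gen is get_music_generator()  # repeated calls return the same instance
        True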
    """
    global _music_generator
    if _music_generator is None:
        _music_generator = MusicGenerator(model_name=settings.musicgen_model)
    return _music_generator
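

# Minimal smoke-test sketch (illustrative only): downloads the checkpoint on first
# run and writes a short clip to the working directory. The prompt text and output
# stem below are placeholders, not part of the service API.
if __name__ == "__main__":
    generator = get_music_generator()
    generator.load_model()
    clip = generator.generate(["upbeat acoustic guitar"], duration=5)
    saved_path = generator.save_audio(clip, "musicgen_demo")
    logger.info(f"Demo clip written to {saved_path}")
    generator.unload_model()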
