"""
Logging configuration for the RAG system.

This module provides structured logging capabilities with different log levels,
JSON formatting for production, and configurable outputs.
"""

import json
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Union

from airagagent.config import BASE_DIR


class JSONFormatter(logging.Formatter):
    """JSON formatter for structured logging.

    Each record becomes a single-line JSON object with the keys ``timestamp``
    (local time, ISO-8601, no UTC offset), ``level``, ``logger``, ``message``,
    ``module``, ``function`` and ``line``. ``exception`` and ``stack`` are added
    when the record carries exception / stack information, and any mapping
    attached as ``record.extra_fields`` (e.g. via
    ``extra={"extra_fields": {...}}``) is merged in at the top level, where it
    may overwrite the keys above.
    """

    def format(self, record: logging.LogRecord) -> str:
        log_entry: Dict[str, Any] = {
            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Add exception info if present
        if record.exc_info:
            log_entry["exception"] = self.formatException(record.exc_info)

        # Add the call stack requested with stack_info=True, as logging.Formatter does
        if record.stack_info:
            log_entry["stack"] = self.formatStack(record.stack_info)

        # Add extra fields if present; a None/empty value is simply skipped
        extra_fields = getattr(record, "extra_fields", None)
        if extra_fields:
            log_entry.update(extra_fields)

        # default=str: a caller-supplied value json cannot encode (Path, set,
        # custom object, ...) is stringified instead of raising TypeError and
        # losing the whole record.
        return json.dumps(log_entry, ensure_ascii=False, default=str)


class ConsoleFormatter(logging.Formatter):
    """Human-readable formatter: ``YYYY-mm-dd HH:MM:SS [LEVEL] [module] message``.

    The timestamp is local time. An exception traceback and then any
    ``stack_info`` (from ``stack_info=True`` logging calls) are appended on the
    following lines, in the same order ``logging.Formatter`` uses.
    """

    def format(self, record: logging.LogRecord) -> str:
        timestamp = datetime.fromtimestamp(record.created).strftime('%Y-%m-%d %H:%M:%S')
        level = f"[{record.levelname}]"
        module = f"[{record.module}]"
        message = record.getMessage()

        formatted = f"{timestamp} {level} {module} {message}"

        # Add exception info if present
        if record.exc_info:
            formatted += f"\n{self.formatException(record.exc_info)}"

        # Add the call stack requested with stack_info=True instead of dropping it
        if record.stack_info:
            formatted += f"\n{self.formatStack(record.stack_info)}"

        return formatted


def _resolve_level(level: Union[str, int]) -> int:
    """Map a level name (case-insensitive) or number to a numeric level; unknown names give INFO."""
    if isinstance(level, int) and not isinstance(level, bool):
        return level
    candidate = getattr(logging, str(level).upper(), None)
    # getattr can hit non-level module attributes ("root", "basicConfig",
    # "raiseExceptions", ...); only plain ints are real levels.
    if isinstance(candidate, int) and not isinstance(candidate, bool):
        return candidate
    return logging.INFO


def setup_logging(
    level: Union[str, int] = "INFO",
    log_to_file: bool = True,
    log_to_console: bool = True,
    log_dir: Optional[Path] = None,
    json_format: bool = False
) -> logging.Logger:
    """
    Set up logging configuration for the RAG system.

    Every handler currently attached to the root logger is removed and closed,
    then the requested console/file handlers are attached.

    Args:
        level: Logging level name (DEBUG, INFO, WARNING, ERROR, CRITICAL;
            case-insensitive) or a numeric level such as ``logging.DEBUG``.
            Unrecognised names fall back to INFO.
        log_to_file: Whether to log to file. When enabled, two files are
            written in ``log_dir``: ``rag_system_YYYYMMDD.log`` at ``level``
            and ``rag_system_errors_YYYYMMDD.log`` for WARNING and above.
        log_to_console: Whether to log to stdout
        log_dir: Directory for log files (defaults to BASE_DIR/logs). Only
            resolved and created when ``log_to_file`` is True.
        json_format: Whether to use JSON formatting for the general log file;
            the error log file always uses the human-readable format.

    Returns:
        Root logger configured with the specified settings
    """
    numeric_level = _resolve_level(level)

    # Get root logger
    logger = logging.getLogger()
    logger.setLevel(numeric_level)

    # Remove and close existing handlers. Removing without closing leaks the
    # open file of every FileHandler each time logging is reconfigured.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    # Console handler
    if log_to_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(numeric_level)
        console_handler.setFormatter(ConsoleFormatter())
        logger.addHandler(console_handler)

    # File handlers
    if log_to_file:
        log_dir = Path(log_dir) if log_dir is not None else BASE_DIR / "logs"
        log_dir.mkdir(parents=True, exist_ok=True)

        # Compute the date once so both files always share the same suffix,
        # even if this runs across midnight.
        date_suffix = datetime.now().strftime('%Y%m%d')

        # General log file
        file_handler = logging.FileHandler(
            log_dir / f"rag_system_{date_suffix}.log", encoding='utf-8'
        )
        file_handler.setLevel(numeric_level)
        file_handler.setFormatter(JSONFormatter() if json_format else ConsoleFormatter())
        logger.addHandler(file_handler)

        # Error log file (WARNING and above)
        error_handler = logging.FileHandler(
            log_dir / f"rag_system_errors_{date_suffix}.log", encoding='utf-8'
        )
        error_handler.setLevel(logging.WARNING)
        error_handler.setFormatter(ConsoleFormatter())
        logger.addHandler(error_handler)

    return logger


def get_logger(name: str) -> logging.Logger:
    """
    Fetch the named logger from the logging registry.

    This is a thin wrapper over ``logging.getLogger``: asking for the same
    ``name`` twice yields the very same logger object.

    Args:
        name: Dotted logger name, conventionally the caller's ``__name__``

    Returns:
        The ``logging.Logger`` registered under ``name``
    """
    return logging.getLogger(name=name)


def log_performance(logger: logging.Logger, operation: str, duration: float, **extra_fields) -> None:
    """Log performance metrics at INFO level.

    Args:
        logger: Logger to emit the record on
        operation: Name of the timed operation
        duration: Elapsed time in seconds (rendered with 3 decimals)
        **extra_fields: Additional structured fields; together with
            ``operation`` and ``duration`` they are attached to the record as
            ``extra_fields`` (emitted as top-level keys by JSONFormatter).
    """
    logger.info(
        # Lazy %-args: the message is only rendered if INFO is enabled.
        "Performance: %s completed in %.3fs",
        operation,
        duration,
        extra={"extra_fields": {"operation": operation, "duration": duration, **extra_fields}},
        # Attribute module/function/line to our caller, not to this helper.
        stacklevel=2,
    )


def log_error_with_context(
    logger: logging.Logger,
    error: Exception,
    context: Optional[Dict[str, Any]],
    message: Optional[str] = None
) -> None:
    """Log an error with additional context information.

    Emits an ERROR record carrying ``error`` as its exception info (so
    formatters render the traceback) and an ``extra_fields`` mapping with
    ``error_type``, ``error_message`` and every key of ``context``; context
    keys are applied last and therefore win on a name clash.

    Args:
        logger: Logger to emit the record on
        error: The exception being reported
        context: Extra key/value pairs to attach; ``None`` is treated as empty
        message: Log message; defaults to ``"Error occurred: <ExceptionClass>"``
    """
    if message is None:
        message = f"Error occurred: {error.__class__.__name__}"

    logger.error(
        message,
        exc_info=error,
        extra={
            "extra_fields": {
                "error_type": error.__class__.__name__,
                "error_message": str(error),
                **(context or {}),
            }
        },
        # Attribute module/function/line to our caller, not to this helper.
        stacklevel=2,
    )


# Configure the root logger as a side effect of importing this module:
# INFO level, human-readable output to stdout plus dated log files in the
# default log directory (BASE_DIR/logs). A later setup_logging() call
# replaces these handlers.
_DEFAULT_LOGGING_OPTIONS = dict(
    level="INFO",
    log_to_file=True,
    log_to_console=True,
    json_format=False,
)
default_logger = setup_logging(**_DEFAULT_LOGGING_OPTIONS)

