"""Types for setting agent response formats."""

from __future__ import annotations

import uuid
from dataclasses import dataclass, is_dataclass
from types import UnionType
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    Literal,
    TypeVar,
    Union,
    get_args,
    get_origin,
)

from langchain_core.tools import BaseTool, StructuredTool
from pydantic import BaseModel, TypeAdapter
from typing_extensions import Self, is_typeddict

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable

    from langchain_core.messages import AIMessage

# Supported schema types: Pydantic models, dataclasses, TypedDict, JSON schema dicts
# Type variable for the user-supplied schema; it parameterizes the generic spec,
# strategy, and binding classes in this module.
SchemaT = TypeVar("SchemaT")

# Closed set of schema classifications; `_SchemaSpec` assigns one of these and
# `_parse_with_schema` branches on it.
SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]


class StructuredOutputError(Exception):
    """Root of the structured output exception hierarchy.

    Subclasses record the offending model response in `ai_message`.
    """

    # Declared only; each subclass assigns it in its own `__init__`.
    ai_message: AIMessage


class MultipleStructuredOutputsError(StructuredOutputError):
    """Signals that the model emitted several structured output tool calls at once."""

    def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None:
        """Build the error from the offending tool calls.

        Args:
            tool_names: Names of the structured output tools the model called.
            ai_message: The AI message carrying the multiple tool calls.
        """
        joined_names = ", ".join(tool_names)
        message = (
            "Model incorrectly returned multiple structured responses "
            f"({joined_names}) when only one is expected."
        )
        super().__init__(message)

        # Keep the raw inputs so handlers can inspect them.
        self.tool_names = tool_names
        self.ai_message = ai_message


class StructuredOutputValidationError(StructuredOutputError):
    """Signals that a structured output tool call did not match its schema."""

    def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None:
        """Build the error from the failed tool call.

        Args:
            tool_name: Name of the structured output tool whose arguments failed.
            source: The underlying exception raised while parsing.
            ai_message: The AI message carrying the invalid structured output.
        """
        message = f"Failed to parse structured output for tool '{tool_name}': {source}."
        super().__init__(message)

        # Keep the raw inputs so handlers can inspect them.
        self.tool_name = tool_name
        self.source = source
        self.ai_message = ai_message


def _parse_with_schema(
    schema: type[SchemaT] | dict, schema_kind: SchemaKind, data: dict[str, Any]
) -> Any:
    """Parse data using for any supported schema type.

    Args:
        schema: The schema type (Pydantic model, `dataclass`, or `TypedDict`)
        schema_kind: One of `"pydantic"`, `"dataclass"`, `"typeddict"`, or
            `"json_schema"`
        data: The data to parse

    Returns:
        The parsed instance according to the schema type

    Raises:
        ValueError: If parsing fails
    """
    if schema_kind == "json_schema":
        return data
    try:
        adapter: TypeAdapter[SchemaT] = TypeAdapter(schema)
        return adapter.validate_python(data)
    except Exception as e:
        schema_name = getattr(schema, "__name__", str(schema))
        msg = f"Failed to parse data to {schema_name}: {e}"
        raise ValueError(msg) from e


@dataclass(init=False)
class _SchemaSpec(Generic[SchemaT]):
    """Describes a structured output schema.

    On construction the schema is classified into a `SchemaKind`, and its name,
    description and JSON schema are derived from it.
    """

    schema: type[SchemaT]
    """The schema for the response, can be a Pydantic model, `dataclass`, `TypedDict`,
    or JSON schema dict."""

    name: str
    """Name of the schema, used for tool calling.

    If not provided, the name is the JSON schema's `title` or the class's
    `__name__`, falling back to `response_format_` followed by four random
    hexadecimal characters.
    """

    description: str
    """Custom description of the schema.

    If not provided, the JSON schema's `description` or the class's docstring is
    used, defaulting to an empty string.
    """

    schema_kind: SchemaKind
    """The kind of schema."""

    json_schema: dict[str, Any]
    """JSON schema associated with the schema."""

    strict: bool = False
    """Whether to enforce strict validation of the schema."""

    def __init__(
        self,
        schema: type[SchemaT],
        *,
        name: str | None = None,
        description: str | None = None,
        strict: bool = False,
    ) -> None:
        """Initialize SchemaSpec with schema and optional parameters.

        Args:
            schema: Pydantic model class, `dataclass` class, `TypedDict` class, or
                JSON schema dict.
            name: Explicit schema name; derived from the schema when empty.
            description: Explicit description; derived from the schema when empty.
            strict: Stored as-is on the `strict` attribute.

        Raises:
            ValueError: If `schema` is none of the supported schema types.
        """
        self.schema = schema

        if name:
            self.name = name
        elif isinstance(schema, dict):
            self.name = str(schema.get("title", f"response_format_{str(uuid.uuid4())[:4]}"))
        else:
            self.name = str(getattr(schema, "__name__", f"response_format_{str(uuid.uuid4())[:4]}"))

        self.description = description or (
            schema.get("description", "")
            if isinstance(schema, dict)
            else getattr(schema, "__doc__", None) or ""
        )

        self.strict = strict

        if isinstance(schema, dict):
            # A JSON schema dict already is its own JSON schema.
            self.schema_kind = "json_schema"
            self.json_schema = schema
        elif isinstance(schema, type) and issubclass(schema, BaseModel):
            self.schema_kind = "pydantic"
            self.json_schema = schema.model_json_schema()
        elif isinstance(schema, type) and is_dataclass(schema):
            # `is_dataclass` is also true for dataclass *instances*; only a class
            # can act as a schema, so instances fall through to the error below.
            self.schema_kind = "dataclass"
            self.json_schema = TypeAdapter(schema).json_schema()
        elif is_typeddict(schema):
            self.schema_kind = "typeddict"
            self.json_schema = TypeAdapter(schema).json_schema()
        else:
            msg = (
                f"Unsupported schema type: {type(schema)}. "
                "Supported types: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts."
            )
            raise ValueError(msg)


@dataclass(init=False)
class ToolStrategy(Generic[SchemaT]):
    """Use a tool calling strategy for model responses."""

    schema: type[SchemaT]
    """Schema for the tool calls, exactly as passed to the constructor."""

    schema_specs: list[_SchemaSpec[SchemaT]]
    """One schema spec per leaf variant found in `schema`."""

    tool_message_content: str | None
    """The content of the tool message to be returned when the model calls
    an artificial structured output tool."""

    handle_errors: (
        bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str]
    )
    """Error handling strategy for structured output via `ToolStrategy`.

    - `True`: Catch all errors with default error template
    - `str`: Catch all errors with this custom message
    - `type[Exception]`: Only catch this exception type with default message
    - `tuple[type[Exception], ...]`: Only catch these exception types with default
        message
    - `Callable[[Exception], str]`: Custom function that returns error message
    - `False`: No retry, let exceptions propagate
    """

    def __init__(
        self,
        schema: type[SchemaT],
        *,
        tool_message_content: str | None = None,
        handle_errors: (
            bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str]
        ) = True,
    ) -> None:
        """Initialize `ToolStrategy`.

        Stores the arguments and expands `schema` into one `_SchemaSpec` per leaf
        variant, flattening `Union` types and JSON Schema `oneOf` lists.
        """

        def _leaves(candidate: Any) -> Iterable[Any]:
            """Yield the leaf schemas of `candidate` in declaration order."""
            if get_origin(candidate) in (UnionType, Union):
                children = get_args(candidate)
            elif isinstance(candidate, dict) and "oneOf" in candidate:
                children = candidate["oneOf"]
            else:
                # Neither a union nor a `oneOf` container: this is a leaf.
                yield candidate
                return

            for child in children:
                yield from _leaves(child)

        self.schema = schema
        self.tool_message_content = tool_message_content
        self.handle_errors = handle_errors
        self.schema_specs = [_SchemaSpec(variant) for variant in _leaves(schema)]


@dataclass(init=False)
class ProviderStrategy(Generic[SchemaT]):
    """Use the model provider's native structured output method."""

    schema: type[SchemaT]
    """Schema for native mode."""

    schema_spec: _SchemaSpec[SchemaT]
    """Schema spec for native mode."""

    def __init__(
        self,
        schema: type[SchemaT],
        *,
        strict: bool = False,
    ) -> None:
        """Initialize ProviderStrategy with schema.

        Args:
            schema: Schema the provider should enforce on the response.
            strict: Forwarded to the schema spec; when true, `to_model_kwargs`
                adds `"strict": True` to the `json_schema` payload.
        """
        self.schema = schema
        self.schema_spec = _SchemaSpec(schema, strict=strict)

    def to_model_kwargs(self) -> dict[str, Any]:
        """Convert to kwargs to bind to a model to force structured output.

        Returns:
            A dict with a single `response_format` entry of type `json_schema`
            carrying the schema name and JSON schema, plus `strict` when enabled.
        """
        # OpenAI:
        # - see https://platform.openai.com/docs/guides/structured-outputs
        json_schema: dict[str, Any] = {
            "name": self.schema_spec.name,
            "schema": self.schema_spec.json_schema,
        }
        if self.schema_spec.strict:
            # Only emitted when enabled, so the default payload stays unchanged.
            json_schema["strict"] = True

        response_format: dict[str, Any] = {
            "type": "json_schema",
            "json_schema": json_schema,
        }
        return {"response_format": response_format}


@dataclass
class OutputToolBinding(Generic[SchemaT]):
    """Metadata tying a structured output schema to the tool that represents it.

    Holds the original schema, its kind, and the tool built from the schema's
    JSON schema, so tool call arguments can be parsed back into the schema.
    """

    schema: type[SchemaT]
    """The original schema provided for structured output
    (Pydantic model, dataclass, TypedDict, or JSON schema dict)."""

    schema_kind: SchemaKind
    """Classification of the schema type for proper response construction."""

    tool: BaseTool
    """LangChain tool instance created from the schema for model binding."""

    @classmethod
    def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
        """Build an `OutputToolBinding` from a `SchemaSpec`.

        Args:
            schema_spec: The `SchemaSpec` to convert

        Returns:
            An `OutputToolBinding` whose tool is a `StructuredTool` using the
            spec's JSON schema, name and description
        """
        output_tool = StructuredTool(
            args_schema=schema_spec.json_schema,
            name=schema_spec.name,
            description=schema_spec.description,
        )
        return cls(
            schema=schema_spec.schema,
            schema_kind=schema_spec.schema_kind,
            tool=output_tool,
        )

    def parse(self, tool_args: dict[str, Any]) -> SchemaT:
        """Turn tool call arguments into an instance of the schema.

        Args:
            tool_args: The arguments from the tool call

        Returns:
            The parsed response according to the schema type

        Raises:
            ValueError: If parsing fails
        """
        parsed = _parse_with_schema(self.schema, self.schema_kind, tool_args)
        return parsed


@dataclass
class ProviderStrategyBinding(Generic[SchemaT]):
    """Metadata for parsing structured output produced natively by the provider.

    Holds the original schema and its kind, and knows how to turn the JSON text
    of an `AIMessage` into the schema.
    """

    schema: type[SchemaT]
    """The original schema provided for structured output
    (Pydantic model, `dataclass`, `TypedDict`, or JSON schema dict)."""

    schema_kind: SchemaKind
    """Classification of the schema type for proper response construction."""

    @classmethod
    def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
        """Build a `ProviderStrategyBinding` from a `SchemaSpec`.

        Args:
            schema_spec: The `SchemaSpec` to convert

        Returns:
            A `ProviderStrategyBinding` carrying the spec's schema and schema kind
        """
        return cls(schema=schema_spec.schema, schema_kind=schema_spec.schema_kind)

    def parse(self, response: AIMessage) -> SchemaT:
        """Parse `AIMessage` content according to the schema.

        Args:
            response: The `AIMessage` containing the structured output

        Returns:
            The parsed response according to the schema

        Raises:
            ValueError: If JSON parsing or schema validation fails
        """
        text = self._extract_text_content_from_message(response)

        import json

        try:
            payload = json.loads(text)
        except Exception as exc:
            label = getattr(self.schema, "__name__", "response_format")
            msg = (
                f"Native structured output expected valid JSON for {label}, "
                f"but parsing failed: {exc}."
            )
            raise ValueError(msg) from exc

        # Decoded JSON still has to be validated against the schema.
        return _parse_with_schema(self.schema, self.schema_kind, payload)

    def _extract_text_content_from_message(self, message: AIMessage) -> str:
        """Collect the textual content of an AIMessage into one string.

        Args:
            message: The AI message to extract text from

        Returns:
            The string content as-is; for list content, the concatenation of the
            text found in its items; otherwise `str()` of the content
        """
        content = message.content
        if isinstance(content, str):
            return content
        if not isinstance(content, list):
            return str(content)

        pieces: list[str] = []
        for block in content:
            if not isinstance(block, dict):
                pieces.append(str(block))
                continue
            if block.get("type") == "text" and "text" in block:
                pieces.append(str(block["text"]))
                continue
            # Other dict blocks contribute only a string-valued "content" entry.
            nested = block.get("content")
            if isinstance(nested, str):
                pieces.append(nested)
        return "".join(pieces)


class AutoStrategy(Generic[SchemaT]):
    """Automatically select the best strategy for structured output."""

    schema: type[SchemaT]
    """Schema for automatic mode, stored exactly as passed to the constructor."""

    def __init__(self, schema: type[SchemaT]) -> None:
        """Store `schema` on the new `AutoStrategy`."""
        self.schema = schema


# Union of the three strategy classes defined in this module, each parameterized
# by the same schema type variable.
ResponseFormat = ToolStrategy[SchemaT] | ProviderStrategy[SchemaT] | AutoStrategy[SchemaT]
