"""LLM-based tool selector middleware."""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Literal, Union

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from langchain.tools import BaseTool

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from pydantic import Field, TypeAdapter
from typing_extensions import TypedDict

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)
from langchain.chat_models.base import init_chat_model

# Logger named after this module.
logger = logging.getLogger(__name__)

# Default instructions for the selection model; overridable through the
# `system_prompt` argument of `LLMToolSelectorMiddleware`.
DEFAULT_SYSTEM_PROMPT = (
    "Your goal is to select the most relevant tools "
    "for answering the user's query."
)


@dataclass
class _SelectionRequest:
    """Inputs gathered for a single tool-selection call.

    Built by `LLMToolSelectorMiddleware._prepare_selection_request` and consumed
    by the middleware's sync and async `wrap_model_call` hooks.
    """

    # Candidate tools offered to the selector: the request's `BaseTool`s, minus
    # provider-specific tool dicts and any names listed in `always_include`.
    available_tools: list[BaseTool]
    # The configured system prompt, with ordering/limit instructions appended
    # when `max_tools` is set.
    system_message: str
    # The most recent `HumanMessage` in the conversation; it is the only
    # conversation content shown to the selection model.
    last_user_message: HumanMessage
    # Model used for selection: the middleware's own model when configured,
    # otherwise the model from the incoming request.
    model: BaseChatModel
    # Names of `available_tools`; any other name returned by the selector is
    # treated as an invalid selection.
    valid_tool_names: list[str]


def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
    """Create a structured output schema for tool selection.

    Args:
        tools: Available tools to include in the schema.

    Returns:
        `TypeAdapter` for a schema where each tool name is a `Literal` with its
            description.
    """
    if not tools:
        msg = "Invalid usage: tools must be non-empty"
        raise AssertionError(msg)

    # Create a Union of Annotated Literal types for each tool name with description
    # Example: Union[Annotated[Literal["tool1"], Field(description="...")], ...] noqa: ERA001
    literals = [
        Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools
    ]
    selected_tool_type = Union[tuple(literals)]  # type: ignore[valid-type]  # noqa: UP007

    description = "Tools to use. Place the most relevant tools first."

    class ToolSelectionResponse(TypedDict):
        """Use to select relevant tools."""

        tools: Annotated[list[selected_tool_type], Field(description=description)]  # type: ignore[valid-type]

    return TypeAdapter(ToolSelectionResponse)


def _render_tool_list(tools: list[BaseTool]) -> str:
    """Format tools as markdown list.

    Args:
        tools: Tools to format.

    Returns:
        Markdown string with each tool on a new line.
    """
    return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)


class LLMToolSelectorMiddleware(AgentMiddleware):
    """Uses an LLM to select relevant tools before calling the main model.

    When an agent has many tools available, this middleware filters them down
    to only the most relevant ones for the user's query. This reduces token usage
    and helps the main model focus on the right tools.

    Only `BaseTool` instances are candidates for selection. Provider-specific
    tool dicts and tools named in `always_include` are always passed through.

    Examples:
        !!! example "Limit to 3 tools"

            ```python
            from langchain.agents import create_agent
            from langchain.agents.middleware import LLMToolSelectorMiddleware

            middleware = LLMToolSelectorMiddleware(max_tools=3)

            agent = create_agent(
                model="openai:gpt-4o",
                tools=[tool1, tool2, tool3, tool4, tool5],
                middleware=[middleware],
            )
            ```

        !!! example "Use a smaller model for selection"

            ```python
            middleware = LLMToolSelectorMiddleware(model="openai:gpt-4o-mini", max_tools=2)
            ```
    """

    def __init__(
        self,
        *,
        model: str | BaseChatModel | None = None,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        max_tools: int | None = None,
        always_include: list[str] | None = None,
    ) -> None:
        """Initialize the tool selector.

        Args:
            model: Model to use for selection.

                If not provided, uses the agent's main model.

                Can be a model identifier string or `BaseChatModel` instance.
            system_prompt: Instructions for the selection model.
            max_tools: Maximum number of tools to select.

                If the model selects more, only the first `max_tools` will be used.

                If not specified, there is no limit.
            always_include: Tool names to always include regardless of selection.

                These do not count against the `max_tools` limit.
        """
        super().__init__()
        self.system_prompt = system_prompt
        self.max_tools = max_tools
        # Copy so that later mutation of the caller's list cannot silently change
        # which tools are pinned after construction.
        self.always_include = list(always_include) if always_include else []

        if model is None or isinstance(model, BaseChatModel):
            self.model: BaseChatModel | None = model
        else:
            # A model identifier string: resolve it once, up front.
            self.model = init_chat_model(model)

    def _prepare_selection_request(self, request: ModelRequest) -> _SelectionRequest | None:
        """Prepare inputs for tool selection.

        Args:
            request: The incoming model request.

        Returns:
            `_SelectionRequest` with prepared inputs, or `None` if no selection is
                needed (the request has no tools, or no `BaseTool` remains once
                `always_include` tools are set aside).

        Raises:
            ValueError: If a name in `always_include` matches no `BaseTool` in the
                request.
            AssertionError: If the request contains no `HumanMessage`.
        """
        if not request.tools:
            return None

        # Provider-specific tool dicts are not offered to the selector; they are
        # passed through unchanged when the request is rebuilt.
        base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]

        always_include = set(self.always_include)
        if always_include:
            available_tool_names = {tool.name for tool in base_tools}
            missing_tools = [
                name for name in self.always_include if name not in available_tool_names
            ]
            if missing_tools:
                msg = (
                    f"Tools in always_include not found in request: {missing_tools}. "
                    f"Available tools: {sorted(available_tool_names)}"
                )
                raise ValueError(msg)

        # Always-included tools bypass selection, so they are not candidates.
        available_tools = [tool for tool in base_tools if tool.name not in always_include]
        if not available_tools:
            return None

        system_message = self.system_prompt
        if self.max_tools is not None:
            # Ordering matters when truncating, so tell the model about it.
            system_message += (
                f"\nIMPORTANT: List the tool names in order of relevance, "
                f"with the most relevant first. "
                f"If you exceed the maximum number of tools, "
                f"only the first {self.max_tools} will be used."
            )

        # The selector only sees the most recent user turn.
        last_user_message = next(
            (
                message
                for message in reversed(request.messages)
                if isinstance(message, HumanMessage)
            ),
            None,
        )
        if last_user_message is None:
            msg = "No user message found in request messages"
            raise AssertionError(msg)

        return _SelectionRequest(
            available_tools=available_tools,
            system_message=system_message,
            last_user_message=last_user_message,
            model=self.model or request.model,
            valid_tool_names=[tool.name for tool in available_tools],
        )

    def _process_selection_response(
        self,
        response: dict,
        available_tools: list[BaseTool],
        valid_tool_names: list[str],
        request: ModelRequest,
    ) -> ModelRequest:
        """Process the selection response and return filtered `ModelRequest`.

        Args:
            response: Structured output from the selection model. Its `"tools"`
                entry lists the chosen tool names, most relevant first.
            available_tools: Tools that were offered to the selection model.
            valid_tool_names: Names of `available_tools`.
            request: The original model request.

        Returns:
            A copy of `request` whose tools are the selected tools (in their
            original order, deduplicated and capped at `max_tools`), followed by
            the `always_include` tools and then any provider-specific tool dicts.

        Raises:
            ValueError: If the model selected any name outside `valid_tool_names`.
        """
        valid_names = set(valid_tool_names)
        selected_tool_names: set[str] = set()
        invalid_tool_selections = []

        for tool_name in response["tools"]:
            # Non-string entries cannot be tool names. Report them as invalid
            # rather than letting set membership raise on unhashable values.
            if not isinstance(tool_name, str) or tool_name not in valid_names:
                invalid_tool_selections.append(tool_name)
                continue

            # Keep each name once, and stop accepting once `max_tools` is reached.
            # The loop keeps scanning so that every invalid name is reported.
            if tool_name not in selected_tool_names and (
                self.max_tools is None or len(selected_tool_names) < self.max_tools
            ):
                selected_tool_names.add(tool_name)

        if invalid_tool_selections:
            msg = f"Model selected invalid tools: {invalid_tool_selections}"
            raise ValueError(msg)

        selected_tools: list[BaseTool] = [
            tool for tool in available_tools if tool.name in selected_tool_names
        ]
        always_include = set(self.always_include)
        always_included_tools: list[BaseTool] = [
            tool
            for tool in request.tools
            if not isinstance(tool, dict) and tool.name in always_include
        ]
        # Provider-specific tool dicts were never candidates; keep them all.
        provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]

        return request.override(
            tools=[*selected_tools, *always_included_tools, *provider_tools]
        )

    @staticmethod
    def _selection_schema(selection_request: _SelectionRequest) -> dict:
        """Return the JSON schema restricting the selector to the candidate tools."""
        type_adapter = _create_tool_selection_response(selection_request.available_tools)
        return type_adapter.json_schema()

    @staticmethod
    def _selection_messages(
        selection_request: _SelectionRequest,
    ) -> list[dict[str, str] | HumanMessage]:
        """Build the prompt for the selector: system instructions + last user turn."""
        return [
            {"role": "system", "content": selection_request.system_message},
            selection_request.last_user_message,
        ]

    def _apply_selection(
        self,
        response: object,
        selection_request: _SelectionRequest,
        request: ModelRequest,
    ) -> ModelRequest:
        """Check the selector's raw output and return the filtered request.

        Raises:
            AssertionError: If `response` is not a dict.
            ValueError: If the model selected invalid tool names.
        """
        # A JSON schema (not a Pydantic model class) is passed to
        # `with_structured_output`, so a plain dict is expected back.
        if not isinstance(response, dict):
            msg = f"Expected dict response, got {type(response)}"
            raise AssertionError(msg)
        return self._process_selection_response(
            response,
            selection_request.available_tools,
            selection_request.valid_tool_names,
            request,
        )

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Filter tools based on LLM selection before invoking the model via handler."""
        selection_request = self._prepare_selection_request(request)
        if selection_request is None:
            return handler(request)

        structured_model = selection_request.model.with_structured_output(
            self._selection_schema(selection_request)
        )
        response = structured_model.invoke(self._selection_messages(selection_request))
        return handler(self._apply_selection(response, selection_request, request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Filter tools based on LLM selection before invoking the model via handler."""
        selection_request = self._prepare_selection_request(request)
        if selection_request is None:
            return await handler(request)

        structured_model = selection_request.model.with_structured_output(
            self._selection_schema(selection_request)
        )
        response = await structured_model.ainvoke(
            self._selection_messages(selection_request)
        )
        return await handler(self._apply_selection(response, selection_request, request))
