from __future__ import annotations

import logging
import sys
from typing import Any, Literal, cast

import httpx
import orjson

logger = logging.getLogger(__name__)


class LangGraphError(Exception):
    """Root exception type for the errors defined in this module."""


class APIError(httpx.HTTPStatusError, LangGraphError):
    """Base error carrying the request, message, and decoded body of a failure.

    Attributes:
        message: Human-readable description of the failure.
        request: The request associated with the error.
        body: Decoded error payload, or ``None`` when there is none.
        code: ``body["code"]`` when ``body`` is a dict and the value is a
            string, otherwise ``None``.
        param: ``body["param"]`` under the same rule as ``code``.
        type: ``body["type"]`` under the same rule as ``code``.
    """

    message: str
    request: httpx.Request

    body: object | None
    code: str | None
    param: str | None
    type: str | None

    def __init__(
        self,
        message: str,
        response_or_request: httpx.Response | httpx.Request,
        *,
        body: object | None,
    ) -> None:
        """Initialize the error from either a response or a bare request.

        Args:
            message: Human-readable description of the failure.
            response_or_request: The response that triggered the error, or
                the request alone when no response is available. With a bare
                request, ``None`` is passed to the httpx initializer as the
                response.
            body: Decoded error payload, if any.
        """
        if isinstance(response_or_request, httpx.Response):
            req = response_or_request.request
            response = response_or_request
        else:
            req = response_or_request
            response = None

        httpx.HTTPStatusError.__init__(self, message, request=req, response=response)  # type: ignore[arg-type]
        # Forward the message: LangGraphError has no __init__ of its own, so
        # this resolves to Exception.__init__, which replaces ``self.args``
        # with whatever it is given. Calling it with no arguments would wipe
        # the message and make ``str(err)`` an empty string.
        LangGraphError.__init__(self, message)

        self.request = req
        self.message = message
        self.body = body

        if isinstance(body, dict):
            b = cast("dict[str, Any]", body)
            # Best-effort extraction of common fields if present; anything
            # that is not a string is treated as absent.
            code_val = b.get("code")
            self.code = code_val if isinstance(code_val, str) else None
            param_val = b.get("param")
            self.param = param_val if isinstance(param_val, str) else None
            t = b.get("type")
            self.type = t if isinstance(t, str) else None
        else:
            self.code = None
            self.param = None
            self.type = None


class APIResponseValidationError(APIError):
    """Error for a response whose data does not fit the expected schema."""

    response: httpx.Response
    status_code: int

    def __init__(
        self,
        response: httpx.Response,
        body: object | None,
        *,
        message: str | None = None,
    ) -> None:
        """Build the error from the offending response and its decoded body.

        Args:
            response: The response whose data failed validation.
            body: Decoded payload of that response, if any.
            message: Optional override; a missing or empty value falls back
                to a generic schema-mismatch message.
        """
        if not message:
            message = "Data returned by API invalid for expected schema."
        super().__init__(message, response, body=body)
        self.response = response
        self.status_code = response.status_code


class APIStatusError(APIError):
    """Error built from a response, exposing its status code and request id.

    Attributes:
        response: The response the error was built from.
        status_code: ``response.status_code``.
        request_id: Value of the ``x-request-id`` response header, or
            ``None`` when the header is absent.
    """

    response: httpx.Response
    status_code: int
    request_id: str | None

    def __init__(
        self, message: str, *, response: httpx.Response, body: object | None
    ) -> None:
        """Store the response details alongside the base error fields."""
        super().__init__(message, response_or_request=response, body=body)
        self.response = response
        self.status_code = response.status_code
        headers = response.headers
        self.request_id = headers.get("x-request-id")


class APIConnectionError(APIError):
    """Error built from a request alone, with no response and no body."""

    def __init__(
        self, *, message: str = "Connection error.", request: httpx.Request
    ) -> None:
        """Initialize from the request that was being sent.

        Args:
            message: Description of the failure.
            request: The request associated with the failure.
        """
        # The request goes in the base class's ``response_or_request`` slot.
        super().__init__(message, request, body=None)


class APITimeoutError(APIConnectionError):
    """Connection error that always carries the message "Request timed out."."""

    def __init__(self, request: httpx.Request) -> None:
        """Build the error for ``request`` with the fixed timeout message."""
        super().__init__(message="Request timed out.", request=request)


class BadRequestError(APIStatusError):
    """Status error for HTTP 400 responses."""

    # Class-level default; APIStatusError.__init__ also sets it per instance.
    status_code: Literal[400] = 400


class AuthenticationError(APIStatusError):
    """Status error for HTTP 401 responses."""

    # Class-level default; APIStatusError.__init__ also sets it per instance.
    status_code: Literal[401] = 401


class PermissionDeniedError(APIStatusError):
    """Status error for HTTP 403 responses."""

    # Class-level default; APIStatusError.__init__ also sets it per instance.
    status_code: Literal[403] = 403


class NotFoundError(APIStatusError):
    """Status error for HTTP 404 responses."""

    # Class-level default; APIStatusError.__init__ also sets it per instance.
    status_code: Literal[404] = 404


class ConflictError(APIStatusError):
    """Status error for HTTP 409 responses."""

    # Class-level default; APIStatusError.__init__ also sets it per instance.
    status_code: Literal[409] = 409


class UnprocessableEntityError(APIStatusError):
    """Status error for HTTP 422 responses."""

    # Class-level default; APIStatusError.__init__ also sets it per instance.
    status_code: Literal[422] = 422


class RateLimitError(APIStatusError):
    """Status error for HTTP 429 responses."""

    # Class-level default; APIStatusError.__init__ also sets it per instance.
    status_code: Literal[429] = 429


class InternalServerError(APIStatusError):
    """Status error that ``_map_status_error`` returns for any status >= 500."""

    pass


def _extract_error_message(body: object | None, fallback: str) -> str:
    """Pick a human-readable message out of a decoded error body.

    The first non-empty string wins, searched in this order: top-level
    ``message``, ``detail``, ``error``, then ``message`` and ``detail`` inside
    a nested ``error`` dict (``{"error": {"message": "..."}}``).

    Args:
        body: Decoded error payload; anything other than a dict is ignored.
        fallback: Returned when no usable string is found.

    Returns:
        The extracted message, or ``fallback``.
    """
    if not isinstance(body, dict):
        return fallback
    payload = cast("dict[str, Any]", body)

    # Collect candidates in priority order; top-level keys come first.
    candidates = [payload.get(name) for name in ("message", "detail", "error")]
    nested = payload.get("error")
    if isinstance(nested, dict):
        inner = cast("dict[str, Any]", nested)
        candidates += [inner.get(name) for name in ("message", "detail")]

    for candidate in candidates:
        if isinstance(candidate, str) and candidate:
            return candidate
    return fallback


async def _adecode_error_body(r: httpx.Response) -> object | None:
    """Asynchronously read a response body and decode it on a best-effort basis.

    Returns:
        The value parsed by ``orjson.loads`` when parsing succeeds; otherwise
        the result of ``.decode()`` on the raw data; otherwise ``None``.
        ``None`` is also returned when reading fails or the body is empty.
    """
    try:
        raw = await r.aread()
    except Exception:
        # Deliberately best-effort: an unreadable body counts as no body.
        raw = None
    if not raw:
        return None

    try:
        return orjson.loads(raw)
    except Exception:
        # Not JSON; fall through to the plain-text attempt.
        pass

    try:
        return raw.decode()
    except Exception:
        return None


def _decode_error_body(r: httpx.Response) -> object | None:
    """Read a response body and decode it on a best-effort basis.

    Returns:
        The value parsed by ``orjson.loads`` when parsing succeeds; otherwise
        the result of ``.decode()`` on the raw data; otherwise ``None``.
        ``None`` is also returned when reading fails or the body is empty.
    """
    try:
        raw = r.read()
    except Exception:
        # Deliberately best-effort: an unreadable body counts as no body.
        raw = None
    if not raw:
        return None

    try:
        return orjson.loads(raw)
    except Exception:
        # Not JSON; fall through to the plain-text attempt.
        pass

    try:
        return raw.decode()
    except Exception:
        return None


def _map_status_error(response: httpx.Response, body: object | None) -> APIStatusError:
    """Build the ``APIStatusError`` subclass that matches a response status.

    Args:
        response: The response to convert into an error.
        body: Decoded error payload, used to derive the message.

    Returns:
        A dedicated subclass for 400, 401, 403, 404, 409, 422 and 429,
        ``InternalServerError`` for any status >= 500, and a plain
        ``APIStatusError`` otherwise. The exception is returned, not raised.
    """
    status = response.status_code
    default_text = f"{status} {response.reason_phrase or 'HTTP Error'}"
    message = _extract_error_message(body, default_text)

    exact_matches: dict[int, type[APIStatusError]] = {
        400: BadRequestError,
        401: AuthenticationError,
        403: PermissionDeniedError,
        404: NotFoundError,
        409: ConflictError,
        422: UnprocessableEntityError,
        429: RateLimitError,
    }
    error_cls = exact_matches.get(status)
    if error_cls is None:
        # No dedicated class: split on server-side vs. everything else.
        error_cls = InternalServerError if status >= 500 else APIStatusError
    return error_cls(message, response=response, body=body)


async def _araise_for_status_typed(r: httpx.Response) -> None:
    """Raise a typed ``APIStatusError`` when the response status is >= 400.

    Responses with a status below 400 pass through and ``None`` is returned.

    Raises:
        APIStatusError: Or one of its subclasses, as chosen by
            ``_map_status_error`` from the status code and decoded body.
    """
    if r.status_code >= 400:
        error = _map_status_error(r, await _adecode_error_body(r))
        # Interpreters older than 3.11 have no exception notes, so the
        # message is written to the log as well.
        if sys.version_info < (3, 11):
            logger.error(f"Error from langgraph-api: {getattr(error, 'message', '')}")
        raise error


def _raise_for_status_typed(r: httpx.Response) -> None:
    """Raise a typed ``APIStatusError`` when the response status is >= 400.

    Responses with a status below 400 pass through and ``None`` is returned.

    Raises:
        APIStatusError: Or one of its subclasses, as chosen by
            ``_map_status_error`` from the status code and decoded body.
    """
    if r.status_code >= 400:
        error = _map_status_error(r, _decode_error_body(r))
        # Interpreters older than 3.11 have no exception notes, so the
        # message is written to the log as well.
        if sys.version_info < (3, 11):
            logger.error(f"Error from langgraph-api: {getattr(error, 'message', '')}")
        raise error
