{% set annotations = true %}
{% include '_header.py.jinja' %}
{% from '_utils.py.jinja' import is_async, maybe_async_def, maybe_await, maybe_async, recursive_types, active_provider with context %}
# -- template client.py.jinja --
import warnings
import logging
from datetime import timedelta
from pathlib import Path
from types import TracebackType
from typing_extensions import override

from pydantic import BaseModel

from . import types, models, errors, actions
from ._base_client import BasePrisma, UseClientDefault, USE_CLIENT_DEFAULT
from .types import DatasourceOverride, HttpConfig, MetricsFormat
from ._types import BaseModelT, PrismaMethod, TransactionId, Datasource
from .bases import _PrismaModel
from ._builder import QueryBuilder, dumps
from .generator.models import EngineType, OptionalValueFromEnvVar, BinaryPaths
from ._compat import removeprefix, model_parse
from ._constants import CREATE_MANY_SKIP_DUPLICATES_UNSUPPORTED, DEFAULT_CONNECT_TIMEOUT, DEFAULT_TX_MAX_WAIT, DEFAULT_TX_TIMEOUT
from ._raw_query import deserialize_raw_results
from ._metrics import Metrics
from .metadata import PRISMA_MODELS, RELATIONAL_FIELD_MAPPINGS
from ._transactions import AsyncTransactionManager, SyncTransactionManager

# re-exports
from ._base_client import SyncBasePrisma, AsyncBasePrisma, load_env as load_env
from ._registry import (
    register as register,
    get_client as get_client,
    RegisteredClient as RegisteredClient,
)


__all__ = (
    'ENGINE_TYPE',
    'SCHEMA_PATH',
    'BINARY_PATHS',
    'Batch',
    'Prisma',
    'Client',
    'load_env',
    'register',
    'get_client',
    # Re-exported from `._registry` with the explicit `X as X` form, exactly like
    # `register` / `get_client`, so it belongs in the public star-export too.
    'RegisteredClient',
)

log: logging.Logger = logging.getLogger(__name__)

# Path of the schema this client was generated from, rendered in at generation time.
SCHEMA_PATH = Path('{{ schema_path.as_posix() }}')
# Location of a `schema.prisma` next to this module; presumably a copy written by the
# generator — confirm (the file's existence is not checked here).
PACKAGED_SCHEMA_PATH = Path(__file__).parent.joinpath('schema.prisma')
# Engine type taken from the generator config at generation time.
ENGINE_TYPE: EngineType = EngineType.{{ generator.config.engine_type }}
# Binary paths captured at generation time, re-parsed into a `BinaryPaths` model.
BINARY_PATHS = model_parse(BinaryPaths, {{ model_dict(binary_paths, by_alias=True) }})


class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}):
    """Generated Prisma client for this schema.

    Each model in the schema is exposed as an attribute holding an
    `actions.<Model>Actions` instance bound to this client.
    """

    # Note: these property names can be customised using `/// @Python(instance_name: '...')`
    # https://prisma-client-py.readthedocs.io/en/stable/reference/schema-extensions/#instance_name
    {% for model in dmmf.datamodel.models %}
    {{ model.instance_name }}: 'actions.{{ model.name }}Actions[models.{{ model.name }}]'
    {% endfor %}

    __slots__ = (
        {% for model in dmmf.datamodel.models %}
        '{{ model.instance_name }}',
        {% endfor %}
    )

    def __init__(
        self,
        *,
        use_dotenv: bool = True,
        log_queries: bool = False,
        auto_register: bool = False,
        datasource: DatasourceOverride | None = None,
        connect_timeout: int | timedelta = DEFAULT_CONNECT_TIMEOUT,
        http: HttpConfig | None = None,
    ) -> None:
        """Create a client instance.

        `use_dotenv`, `log_queries`, `datasource`, `connect_timeout` and `http` are
        forwarded unchanged to the base client. If `auto_register` is true, the new
        instance is passed to `register()` after all model actions are attached.
        """
        super().__init__(
            http=http,
            use_dotenv=use_dotenv,
            log_queries=log_queries,
            datasource=datasource,
            connect_timeout=connect_timeout,
        )
        # Schema-derived settings that were rendered into this module at generation time.
        self._set_generated_properties(
            schema_path=SCHEMA_PATH,
            engine_type=ENGINE_TYPE,
            prisma_models=PRISMA_MODELS,
            packaged_schema_path=PACKAGED_SCHEMA_PATH,
            relational_field_mappings=RELATIONAL_FIELD_MAPPINGS,
            preview_features=set({{ generator.preview_features }}),
            active_provider='{{ active_provider }}',
            default_datasource_name='{{ datasources[0].name }}',
        )

        {% for model in dmmf.datamodel.models %}
        self.{{ model.instance_name }} = actions.{{ model.name }}Actions[models.{{ model.name }}](self, models.{{ model.name }})
        {% endfor %}

        if auto_register:
            register(self)

    @property
    @override
    def _default_datasource(self) -> Datasource:
        """The schema's first datasource, with its URL resolved via `OptionalValueFromEnvVar.resolve()`."""
        return {
            'name': '{{ datasources[0].name }}',
            'url': OptionalValueFromEnvVar(**{{ model_dict(datasources[0].url, by_alias=True) }}).resolve(),
            {% if datasources[0].source_file_path %}
            'source_file_path': '{{ datasources[0].source_file_path.as_posix() }}',
            {% endif %}
        }

    {% if active_provider != 'mongodb' %}
    {{ maybe_async_def }}execute_raw(self, query: LiteralString, *args: Any) -> int:
        """Execute a raw query, passing `args` as positional query parameters.

        Returns the engine's `result` value converted to `int` (presumably the number
        of affected rows — confirm against the engine protocol).
        """
        resp = {{ maybe_await }}self._execute(
            method='execute_raw',
            arguments={
                'query': query,
                'parameters': args,
            },
            model=None,
        )
        return int(resp['data']['result'])

    @overload
    {{ maybe_async_def }}query_first(
        self,
        query: LiteralString,
        *args: Any,
    ) -> Optional[dict[str, Any]]:
        ...

    @overload
    {{ maybe_async_def }}query_first(
        self,
        query: LiteralString,
        *args: Any,
        model: Type[BaseModelT],
    ) -> Optional[BaseModelT]:
        ...

    {{ maybe_async_def }}query_first(
        self,
        query: LiteralString,
        *args: Any,
        model: Optional[Type[BaseModelT]] = None,
    ) -> Union[Optional[BaseModelT], dict[str, Any]]:
        """This function is the exact same as `query_raw()` but returns the first result.

        If model is given, the returned record is converted to the pydantic model first,
        otherwise a raw dictionary will be returned. In both cases `None` is returned
        when the query yields no rows.
        """
        results: Sequence[Union[BaseModelT, dict[str, Any]]]
        # Two separate calls so that each one matches a concrete `query_raw` overload.
        if model is not None:
            results = {{ maybe_await }}self.query_raw(query, *args, model=model)
        else:
            results = {{ maybe_await }}self.query_raw(query, *args)

        if not results:
            return None

        return results[0]

    @overload
    {{ maybe_async_def }}query_raw(
        self,
        query: LiteralString,
        *args: Any,
    ) -> List[dict[str, Any]]:
        ...

    @overload
    {{ maybe_async_def }}query_raw(
        self,
        query: LiteralString,
        *args: Any,
        model: Type[BaseModelT],
    ) -> List[BaseModelT]:
        ...

    {{ maybe_async_def }}query_raw(
        self,
        query: LiteralString,
        *args: Any,
        model: Optional[Type[BaseModelT]] = None,
    ) -> Union[List[BaseModelT], List[dict[str, Any]]]:
        """Execute a raw SQL query against the database.

        If model is given, each returned record is converted to the pydantic model first,
        otherwise results will be raw dictionaries. In both cases the engine's `result`
        payload is passed through `deserialize_raw_results`.
        """
        resp = {{ maybe_await }}self._execute(
            method='query_raw',
            arguments={
                'query': query,
                'parameters': args,
            },
            model=model,
        )
        result = resp['data']['result']
        if model is not None:
            return deserialize_raw_results(result, model=model)

        return deserialize_raw_results(result)
    {% endif %}

    def batch_(self) -> Batch:
        """Returns a context manager for grouping write queries into a single transaction."""
        return Batch(client=self)

    def tx(
        self,
        *,
        max_wait: Union[int, timedelta] = DEFAULT_TX_MAX_WAIT,
        timeout: Union[int, timedelta] = DEFAULT_TX_TIMEOUT,
    ) -> TransactionManager:
        """Returns a context manager for executing queries within a database transaction.

        Entering the context manager returns a new Prisma instance wrapping all
        actions within a transaction. Queries will be isolated to that Prisma instance and
        will not be committed to the database until the context manager exits.

        By default, Prisma will wait a maximum of 2 seconds to acquire a transaction from the database. You can modify this
        default with the `max_wait` argument which accepts a value in milliseconds or `datetime.timedelta`.

        By default, Prisma will cancel and rollback any transactions that last longer than 5 seconds. You can modify this timeout
        with the `timeout` argument which accepts a value in milliseconds or `datetime.timedelta`.

        Example usage:

        ```py
        {{ maybe_async }}with client.tx() as transaction:
            user1 = {{ maybe_await }}transaction.user.create({'name': 'Robert'})
            user2 = {{ maybe_await }}transaction.user.create({'name': 'Tegan'})
        ```

        In the above example, if the first database call succeeds but the second does not then neither of the records will be created.
        """
        return TransactionManager(
            client=self,
            max_wait=max_wait,
            timeout=timeout,
        )


# Transaction manager specialised to this generated client. The sync or async variant is
# chosen at generation time via `is_async`; `Prisma.tx()` returns an instance of it.
TransactionManager = {% if is_async %}AsyncTransactionManager{% else %}SyncTransactionManager{% endif %}[Prisma]


# TODO: this should return the results as well
# TODO: don't require copy-pasting arguments between actions and batch actions
class Batch:
    """Queues write queries and sends them to the engine in one request.

    Queries are added through the per-model `<Model>BatchActions` attributes (and
    `execute_raw` for non-MongoDB providers). They are only sent when `commit()` is
    called, or when the context manager exits without an exception.
    """
    {% for model in dmmf.datamodel.models %}
    {{ model.instance_name }}: '{{ model.name }}BatchActions'
    {% endfor %}

    def __init__(self, client: Prisma) -> None:
        # Double-underscore names are mangled to `_Batch__client` / `_Batch__queries`.
        self.__client = client
        self.__queries: List[str] = []
        # Copied from the client so batch actions can check provider support
        # (used by `create_many` for its `skip_duplicates` check).
        self._active_provider = client._active_provider
        {% for model in dmmf.datamodel.models %}
        self.{{ model.instance_name }} = {{ model.name }}BatchActions(self)
        {% endfor %}

    def _add(self, **kwargs: Any) -> None:
        """Build a query string from `kwargs` with `QueryBuilder` and append it to the queue.

        The module-level model metadata (`PRISMA_MODELS`, `RELATIONAL_FIELD_MAPPINGS`)
        is always supplied to the builder.
        """
        builder = QueryBuilder(
            **kwargs,
            prisma_models=PRISMA_MODELS,
            relational_field_mappings=RELATIONAL_FIELD_MAPPINGS,
        )
        self.__queries.append(builder.build_query())

    {{ maybe_async_def }}commit(self) -> None:
        """Execute the queries.

        All queued queries are sent to the engine in a single batch payload with
        `'transaction': True`, using the client's `_tx_id`. The queue is cleared
        before the request is made, so queries are not retained if the request fails.
        """
        # TODO: normalise this, we should still call client._execute
        queries = self.__queries
        self.__queries = []

        # NOTE(review): an empty queue still produces a request with `'batch': []` —
        # confirm whether an early return is wanted here.
        payload = {
            'batch': [
                {
                    'query': query,
                    'variables': {},
                }
                for query in queries
            ],
            'transaction': True,
        }
        {{ maybe_await }}self.__client._engine.query(
            dumps(payload),
            tx_id=self.__client._tx_id,
        )

    {% if active_provider != 'mongodb' %}
    def execute_raw(self, query: LiteralString, *args: Any) -> None:
        """Queue an `execute_raw` query, passing `args` as positional query parameters."""
        self._add(
            method='execute_raw',
            arguments={
                'query': query,
                'parameters': args,
            }
        )
    {% endif %}

    {% if is_async %}
    async def __aenter__(self) -> 'Batch':
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Only commit when the block exited cleanly; on an exception the queued
        # queries are never sent and the exception propagates.
        if exc is None:
            await self.commit()
    {% else %}
    def __enter__(self) -> 'Batch':
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Only commit when the block exited cleanly; on an exception the queued
        # queries are never sent and the exception propagates.
        if exc is None:
            self.commit()
    {% endif %}

{% for model in dmmf.datamodel.models %}

# NOTE: some arguments are meaningless in this context but are included
# for completeness' sake
class {{ model.name }}BatchActions:
    """Queues `{{ model.name }}` write queries on a `Batch`.

    Every method forwards its arguments to `Batch._add`, so nothing is sent to the
    database until the batch is committed, and no results are returned.
    """
    def __init__(self, batcher: Batch) -> None:
        self._batcher = batcher

    def create(
        self,
        data: types.{{ model.name }}CreateInput,
        include: Optional[types.{{ model.name}}Include] = None
    ) -> None:
        """Queue a `create` query for a single record."""
        self._batcher._add(
            method='create',
            model=models.{{ model.name }},
            arguments={
                'data': data,
                'include': include,
            },
        )

    def create_many(
        self,
        data: List[types.{{ model.name }}CreateWithoutRelationsInput],
        *,
        skip_duplicates: Optional[bool] = None,
    ) -> None:
        """Queue a `create_many` query.

        Raises:
            errors.UnsupportedDatabaseError: immediately (before anything is queued) if
                `skip_duplicates` is truthy and the batch's active provider is listed
                in `CREATE_MANY_SKIP_DUPLICATES_UNSUPPORTED`.
        """
        if skip_duplicates and self._batcher._active_provider in CREATE_MANY_SKIP_DUPLICATES_UNSUPPORTED:
            raise errors.UnsupportedDatabaseError(self._batcher._active_provider, 'create_many_skip_duplicates')

        self._batcher._add(
            method='create_many',
            model=models.{{ model.name }},
            arguments={
                'data': data,
                'skipDuplicates': skip_duplicates,
            },
            root_selection=['count'],
        )

    def delete(
        self,
        where: types.{{ model.name }}WhereUniqueInput,
        include: Optional[types.{{ model.name}}Include] = None,
    ) -> None:
        """Queue a `delete` query for the record matching `where`."""
        self._batcher._add(
            method='delete',
            model=models.{{ model.name }},
            arguments={
                'where': where,
                'include': include,
            },
        )

    def update(
        self,
        data: types.{{ model.name }}UpdateInput,
        where: types.{{ model.name }}WhereUniqueInput,
        include: Optional[types.{{ model.name}}Include] = None
    ) -> None:
        """Queue an `update` query for the record matching `where`."""
        self._batcher._add(
            method='update',
            model=models.{{ model.name }},
            arguments={
                'data': data,
                'where': where,
                'include': include,
            },
        )

    def upsert(
        self,
        where: types.{{ model.name }}WhereUniqueInput,
        data: types.{{ model.name }}UpsertInput,
        include: Optional[types.{{ model.name}}Include] = None,
    ) -> None:
        """Queue an `upsert` query for the record matching `where`.

        `data` is split into its `create` and `update` entries; a missing entry is
        sent as `None` because it is read with `.get()`.
        """
        self._batcher._add(
            method='upsert',
            model=models.{{ model.name }},
            arguments={
                'where': where,
                'include': include,
                'create': data.get('create'),
                'update': data.get('update'),
            },
        )

    def update_many(
        self,
        data: types.{{ model.name }}UpdateManyMutationInput,
        where: types.{{ model.name }}WhereInput,
    ) -> None:
        """Queue an `update_many` query for all records matching `where`."""
        self._batcher._add(
            method='update_many',
            model=models.{{ model.name }},
            arguments={'data': data, 'where': where,},
            root_selection=['count'],
        )

    def delete_many(
        self,
        where: Optional[types.{{ model.name }}WhereInput] = None,
    ) -> None:
        """Queue a `delete_many` query; `where` defaults to `None` (no filter passed)."""
        self._batcher._add(
            method='delete_many',
            model=models.{{ model.name }},
            arguments={'where': where},
            root_selection=['count'],
        )


{% endfor %}

# Alias for `Prisma`, exported under the name `Client` via `__all__`.
Client = Prisma
