{% include '_header.py.jinja' %}
{% from '_utils.py.jinja' import recursive_types with context %}
# -- template models.py.jinja --
import os
import logging
import inspect
import warnings
from collections import OrderedDict

from pydantic import BaseModel, Field

from . import types, enums, errors, fields, bases
from ._types import FuncType
from ._compat import model_rebuild, field_validator
from ._builder import serialize_base64
from .generator import partial_models_ctx, PartialModelField


# module-level logger, named after this module
log: logging.Logger = logging.getLogger(__name__)

# names of every partial type registered so far through a model's `create_partial()`;
# that method consults this set to reject duplicate names and to check that the
# partial types named in `relations` have already been created
_created_partial_types: Set[str] = set()

{% for model in dmmf.datamodel.models %}
class {{ model.name }}(bases.Base{{ model.name }}):
    {% if model.documentation is none %}
    """Represents a {{ model.name }} record"""
    {% else %}
    """{{ format_documentation(model.documentation) }}"""
    {% endif %}

    {% for field in model.all_fields %}
    {{ field.name }}:
            {%- if field.is_required and not field.relation_name -%}
                {{ ' ' }}{{ field.python_type }}
            {% else -%}
                {{ ' ' }}Optional[{{ field.python_type }}] = None
            {% endif %}
    {% if not field.documentation is none %}
    """{{ format_documentation(field.documentation) }}"""

    {% endif %}
    {% endfor %}

    {% if not recursive_types %}
    # take *args and **kwargs so that other metaclasses can define arguments
    def __init_subclass__(
        cls,
        *args: Any,
        warn_subclass: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Accept arbitrary class arguments when this model is subclassed.

        The extra positional and keyword arguments are accepted but not
        forwarded to the parent hook. Passing any non-None `warn_subclass`
        only emits a DeprecationWarning.
        """
        super().__init_subclass__()
        if warn_subclass is not None:
            warnings.warn(
                'The `warn_subclass` argument is deprecated as it is no longer necessary and will be removed in the next release',
                DeprecationWarning,
                stacklevel=3,
            )
    {% endif %}

    {% if model.required_array_fields|list|length > 0 %}
    @field_validator({{ model.required_array_fields | map(attribute="name") | map('quote') | join(", ") }}, pre=True, allow_reuse=True)
    @classmethod
    def _transform_required_list_fields(cls, value: object) -> object:
        """Replace a `None` value for a required list field with an empty list."""
        # When using raw queries, some databases will return `None` for an array field that has not been set yet.
        #
        # In our case we want to use an empty list instead as that is the internal Prisma behaviour and we want
        # to use the same consistent structure between the core ORM and raw queries. For example, if we updated
        # our type definitions to include `None` for `List` fields then it would be misleading as it will only
        # ever be `None` in raw queries.
        if value is None:
            return []

        return value
    {% endif %}

    @staticmethod
    def create_partial(
        name: str,
        include: Optional[Iterable['types.{{ model.name }}Keys']] = None,
        exclude: Optional[Iterable['types.{{ model.name }}Keys']] = None,
        required: Optional[Iterable['types.{{ model.name }}Keys']] = None,
        optional: Optional[Iterable['types.{{ model.name }}Keys']] = None,
        relations: Optional[Mapping['types.{{ model.name }}RelationalFieldKeys', str]] = None,
        exclude_relational_fields: bool = False,
    ) -> None:
        """Register a partial type called `name` derived from the {{ model.name }} fields.

        The selected field definitions are appended to the list held by
        `partial_models_ctx` and `name` is recorded as a created partial type.

        Args:
            name: Name of the partial type; must not have been used before.
            include: Only these fields are kept, in the order given.
            exclude: Every field except these is kept.
            required: Fields to mark as not optional.
            optional: Fields to mark as optional.
            relations: Maps a relational field to the name of an already
                created partial type that should be used as its type.
            exclude_relational_fields: Drop every relational field.

        Raises:
            RuntimeError: `PRISMA_GENERATOR_INVOCATION` is not set in the environment.
            TypeError: `include` is combined with `exclude` or with
                `exclude_relational_fields=True`.
            ValueError: `name` is already taken, a field is both required and
                optional, `relations` is combined with
                `exclude_relational_fields`, a field name is unknown or not
                part of the partial, a relation targets an unknown partial
                type, or the model has no relational fields.
            errors.UnknownRelationalFieldError: a key of `relations` is not a
                relational field of this model.
        """
        if not os.environ.get('PRISMA_GENERATOR_INVOCATION'):
            raise RuntimeError(
                'Attempted to create a partial type outside of client generation.'
            )

        if name in _created_partial_types:
            raise ValueError(f'Partial type "{name}" has already been created.')

        if include is not None:
            if exclude is not None:
                raise TypeError('Exclude and include are mutually exclusive.')
            if exclude_relational_fields is True:
                raise TypeError('Include and exclude_relational_fields=True are mutually exclusive.')

        # The key arguments are only promised to be iterables and each one is
        # traversed more than once below, so materialise them a single time;
        # a one-shot iterator would otherwise be exhausted after its first pass
        # and the later passes would silently see no keys at all.
        include_keys = list(include or ())
        exclude_keys = list(exclude or ())
        required_keys = list(required or ())
        optional_keys = list(optional or ())

        shared = set(required_keys) & set(optional_keys)
        if shared:
            raise ValueError(f'Cannot make the same field(s) required and optional {shared}')

        if exclude_relational_fields and relations:
            raise ValueError(
                'exclude_relational_fields and relations are mutually exclusive'
            )

        fields: Dict['types.{{ model.name }}Keys', PartialModelField] = OrderedDict()

        # any unknown key used below surfaces as a KeyError, which is converted
        # into a single user facing ValueError at the end of this block
        try:
            if include_keys:
                for field in include_keys:
                    fields[field] = _{{ model.name }}_fields[field].copy()
            elif exclude_keys:
                # validate first so that a misspelt field is reported instead
                # of being silently ignored
                for field in exclude_keys:
                    if field not in _{{ model.name }}_fields:
                        raise KeyError(field)

                excluded = set(exclude_keys)
                fields = {
                    key: data.copy()
                    for key, data in _{{ model.name }}_fields.items()
                    if key not in excluded
                }
            else:
                fields = {
                    key: data.copy()
                    for key, data in _{{ model.name }}_fields.items()
                }

            for field in required_keys:
                fields[field]['optional'] = False

            for field in optional_keys:
                fields[field]['optional'] = True

            {% if model.has_relational_fields %}
            if exclude_relational_fields:
                fields = {
                    key: data
                    for key, data in fields.items()
                    if key not in _{{ model.name }}_relational_fields
                }
            {% endif %}

            if relations:
                {% if model.has_relational_fields %}
                for field, type_ in relations.items():
                    if field not in _{{ model.name }}_relational_fields:
                        raise errors.UnknownRelationalFieldError('{{ model.name }}', field)

                    # TODO: this method of validating types is not ideal
                    # as it means we cannot create two partial types that
                    # reference each other
                    if type_ not in _created_partial_types:
                        raise ValueError(
                            f'Unknown partial type: "{type_}". '
                            f'Did you remember to generate the {type_} type before this one?'
                        )

                    # TODO: support non prisma.partials models
                    info = fields[field]
                    if info['is_list']:
                        info['type'] = f'List[\'partials.{type_}\']'
                    else:
                        info['type'] = f'\'partials.{type_}\''
                {% else %}
                raise ValueError('Model: "{{ model.name }}" has no relational fields.')
                {% endif %}
        except KeyError as exc:
            raise ValueError(
                f'{exc.args[0]} is not a valid {{ model.name }} / {name} field.'
            ) from None

        models = partial_models_ctx.get()
        models.append(
            {
                'name': name,
                'fields': cast(Mapping[str, PartialModelField], fields),
                'from_model': '{{ model.name }}',
            }
        )
        _created_partial_types.add(name)


{% endfor %}

{% for model in dmmf.datamodel.models %}
# names of the relational fields of this model; `create_partial()` uses this set
# to drop relational fields and to validate the keys of its `relations` argument
{% if model.has_relational_fields -%}
    _{{ model.name }}_relational_fields: Set[str] = {
        {% for field in model.relational_fields %}
        '{{ field.name }}',
        {% endfor %}
    }
{% else -%}
    _{{ model.name }}_relational_fields: Set[str] = set()  # pyright: ignore[reportUnusedVariable]
{% endif %}
# per-field metadata keyed by field name; `create_partial()` copies these entries
# and then adjusts the 'optional' and 'type' values of the copies
# NOTE(review): the field documentation is interpolated verbatim between triple
# single quotes, so text containing ''' or backslashes would presumably corrupt
# the generated literal - confirm whether it is escaped before reaching here
_{{ model.name }}_fields: Dict['types.{{ model.name }}Keys', PartialModelField] = OrderedDict(
    [
        {% for field in model.all_fields %}
        ('{{ field.name }}', {
            'name': '{{ field.name }}',
            'is_list': {{ field.is_list }},
            'optional': {{ field.is_optional }},
            'type': {{ field.python_type_as_string }},
            'is_relational': {{ field.relation_name is not none }},
            'documentation': {% if field.documentation is none %}None{% else %}'''{{ field.documentation }}'''{% endif %},
        }),
        {% endfor %}
    ],
)

{% endfor %}


# we have to import ourselves as relation types are namespaced to models
# e.g. models.Post
from . import models, actions

# required to support relationships between models: every model is rebuilt here,
# after all of the model classes above have been defined
# NOTE(review): `model_rebuild` comes from `._compat`; presumably it resolves the
# forward references used by the relational field annotations - confirm there
{% for model in dmmf.datamodel.models %}
model_rebuild({{ model.name }})
{% endfor %}
