#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
from typing import TYPE_CHECKING, Any

import paddle

from . import unique_name
from .dygraph_utils import _append_activation_in_dygraph
from .framework import (
    Parameter,
    dtype_is_floating,
    in_dygraph_mode,
    in_pir_mode,
)
from .layer_helper_base import LayerHelperBase
from .param_attr import ParamAttr

if TYPE_CHECKING:
    from collections.abc import Generator

    from paddle import Tensor
    from paddle.base.framework import Operator


class LayerHelper(LayerHelperBase):
    """Helper used by layer functions to read their arguments and emit ops.

    The keyword arguments passed at construction are stored in
    ``self.kwargs``; the methods below look up entries such as ``'input'``,
    ``'param_attr'``, ``'bias_attr'``, ``'act'``, ``'use_cudnn'`` and
    ``'name'`` from that dictionary.
    """

    def __init__(self, layer_type: str, **kwargs: Any) -> None:
        """Store ``kwargs`` and make sure a ``'name'`` entry is present.

        Args:
            layer_type: Type name of the layer. It is also the key passed to
                the name generator when no name was supplied.
            **kwargs: Layer arguments, kept as ``self.kwargs``.
        """
        self.kwargs = kwargs
        # TODO(panyx0718, minqiyang): dygraph mode
        # can not use both `layer_type` and `name`. Deprecate LayerHelper
        # and write a Helper for dygraph mode.
        if self.kwargs.get('name') is None:
            if in_dygraph_mode():
                self.kwargs['name'] = unique_name.generate(layer_type)
            else:
                # NOTE(review): `main_program` is read before the base
                # initializer runs; presumably LayerHelperBase exposes it
                # without needing instance state -- confirm.
                self.kwargs['name'] = (
                    self.main_program._name_generator.generate(layer_type)
                )

        super().__init__(self.kwargs['name'], layer_type=layer_type)

    def append_op(self, *args: Any, **kwargs: Any) -> Operator:
        """Append an operator to the current block of the main program.

        All arguments are forwarded unchanged to the block's ``append_op``.
        """
        return self.main_program.current_block().append_op(*args, **kwargs)

    def multiple_input(self, input_param_name: str = 'input') -> list[Tensor]:
        """Return the inputs stored under ``input_param_name`` as a list.

        A list or tuple value is converted element by element with
        ``self.to_variable``; any other value is treated as a single input.
        A missing key yields an empty list.
        """
        inputs = self.kwargs.get(input_param_name, [])
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]
        return [self.to_variable(inp) for inp in inputs]

    def input(self, input_param_name: str = 'input') -> Tensor:
        """Return the single input stored under ``input_param_name``.

        Raises:
            ValueError: If the entry does not hold exactly one input.
        """
        inputs = self.multiple_input(input_param_name)
        if len(inputs) != 1:
            # Must be a real exception object: raising a bare string is
            # itself an error in Python 3 and would hide this message.
            raise ValueError(f"{self.layer_type} layer only takes one input")
        return inputs[0]

    @property
    def param_attr(self) -> ParamAttr:
        """The ``'param_attr'`` argument converted by ``ParamAttr._to_attr``."""
        return ParamAttr._to_attr(self.kwargs.get('param_attr', None))

    @property
    def bias_attr(self) -> ParamAttr:
        """The ``'bias_attr'`` argument converted by ``ParamAttr._to_attr``."""
        return ParamAttr._to_attr(self.kwargs.get('bias_attr', None))

    # TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of param_attr
    def multiple_param_attr(self, length: int) -> list[ParamAttr]:
        """Return a list of parameter attributes matching ``length`` inputs.

        A single attribute is replicated (deep-copied) ``length`` times so
        that each input gets its own independent attribute object.

        Args:
            length: Number of attributes required.

        Raises:
            ValueError: If several attributes were given and their count
                differs from ``length``.
        """
        param_attr = self.param_attr
        if isinstance(param_attr, ParamAttr):
            param_attr = [param_attr]

        if len(param_attr) not in (1, length):
            raise ValueError("parameter number mismatch")
        if len(param_attr) == 1 and length != 1:
            template = param_attr[0]
            param_attr = [copy.deepcopy(template) for _ in range(length)]
        return param_attr

    def iter_inputs_and_params(
        self, input_param_name: str = 'input'
    ) -> Generator[tuple[Tensor, ParamAttr], None, None]:
        """Yield ``(input, param_attr)`` pairs for ``input_param_name``.

        Raises:
            ValueError: Propagated from ``multiple_param_attr`` when the
                number of attributes does not fit the number of inputs.
        """
        inputs = self.multiple_input(input_param_name)
        param_attrs = self.multiple_param_attr(len(inputs))
        yield from zip(inputs, param_attrs)

    def input_dtype(
        self, input_param_name: str = 'input'
    ) -> None | paddle.dtype:
        """Return the dtype shared by all inputs under ``input_param_name``.

        Returns:
            The common dtype, or ``None`` when there are no inputs.

        Raises:
            ValueError: If two inputs have different dtypes.
        """
        inputs = self.multiple_input(input_param_name)
        dtype = None
        for each in inputs:
            if dtype is None:
                dtype = each.dtype
            elif dtype != each.dtype:
                raise ValueError(f"Data Type mismatch: {dtype} to {each.dtype}")
        return dtype

    def get_parameter(self, name: str) -> Tensor:
        """Look up ``name`` in the main program's global block.

        Raises:
            ValueError: If the variable found is not a ``Parameter``.
        """
        param = self.main_program.global_block().var(name)
        if not isinstance(param, Parameter):
            raise ValueError(f"no Parameter name {name} found")
        return param

    # TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of bias_attr
    def append_bias_op(
        self, input_var: Tensor, dim_start: int = 1, dim_end: int | None = None
    ) -> Tensor:
        """Add a bias to ``input_var`` and return the result.

        If ``bias_attr`` is falsy (the user did not request a bias),
        ``input_var`` is returned unchanged and nothing is created.

        Args:
            input_var: The input variable. Per the original documentation
                its rank is expected to be at least 2.
            dim_start: First dimension of ``input_var`` covered by the bias;
                also passed as the ``axis`` attribute of the add operator.
            dim_end: End (exclusive) of the covered dimensions. The bias has
                shape ``input_var.shape[dim_start:dim_end]``.

        Returns:
            ``input_var`` itself when no bias is configured, otherwise the
            sum of ``input_var`` and the newly created bias parameter.
        """
        size = list(input_var.shape[dim_start:dim_end])
        bias_attr = self.bias_attr
        if not bias_attr:
            return input_var

        b = self.create_parameter(
            attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True
        )
        if in_pir_mode():
            return input_var + b
        tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
        self.append_op(
            type='elementwise_add',
            inputs={'X': [input_var], 'Y': [b]},
            outputs={'Out': [tmp]},
            attrs={'axis': dim_start},
        )
        return tmp

    # TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of act
    def append_activation(self, input_var: Tensor) -> Tensor:
        """Apply the activation named by the ``'act'`` argument.

        Args:
            input_var: Variable to activate.

        Returns:
            ``input_var`` unchanged when ``'act'`` is ``None``; otherwise
            the output of the activation.

        Raises:
            TypeError: If ``'act'`` is neither ``None`` nor a string.
        """
        act_type = self.kwargs.get('act', None)
        if act_type is None:
            return input_var
        if not isinstance(act_type, str):
            raise TypeError(str(act_type) + " should be unicode or str")

        # Only a truthy `use_cudnn` is forwarded; otherwise it stays None
        # and is left out of the operator attributes.
        use_cudnn = None
        attrs = {}
        requested_cudnn = self.kwargs.get('use_cudnn')
        if requested_cudnn:
            use_cudnn = requested_cudnn
            attrs['use_cudnn'] = use_cudnn

        if in_dygraph_mode():
            return _append_activation_in_dygraph(input_var, act_type, use_cudnn)
        if in_pir_mode():
            return paddle.pir_utils.append_activation_in_pir(
                input_var, act_type, use_cudnn
            )

        tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
        self.append_op(
            type=act_type,
            inputs={"X": [input_var]},
            outputs={"Out": [tmp]},
            attrs=attrs,
        )
        return tmp

    # TODO (jiabin): should we remove this since it has never be used
    def _get_default_initializer(self, dtype):
        """Return ``XavierUniform`` for ``None``/floating dtypes, else ``Constant``."""
        if dtype is None or dtype_is_floating(dtype) is True:
            return paddle.nn.initializer.XavierUniform()
        else:
            # For integer and boolean types, initialize with all zeros
            return paddle.nn.initializer.Constant()

    # TODO (jiabin): reconstruct this in LayerObjHelper and avoid dependency of kwargs
    def is_instance(self, param_name: str, cls: Any) -> None:
        """Check that the argument ``param_name`` is an instance of ``cls``.

        Args:
            param_name: Key looked up in ``self.kwargs`` (``None`` if absent).
            cls: A type, or a tuple of types, as accepted by ``isinstance``.

        Raises:
            TypeError: If the stored value is not an instance of ``cls``.
        """
        param = self.kwargs.get(param_name, None)
        if isinstance(param, cls):
            return

        # `cls` may be a tuple of types, which has no `__name__` of its own.
        candidates = cls if isinstance(cls, tuple) else (cls,)
        expected = " or ".join(
            getattr(candidate, '__name__', repr(candidate))
            for candidate in candidates
        )
        raise TypeError(
            f"The input {param_name} parameter of method {self.layer_type} "
            f"must be {expected}"
        )
