# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import paddle
from paddle import _C_ops, pir
from paddle.framework import in_dynamic_or_pir_mode
from paddle.regularizer import L2Decay

from ..base import core, framework
from .optimizer import Optimizer

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import NotRequired

    from paddle import Tensor
    from paddle.nn.clip import GradientClipBase
    from paddle.regularizer import WeightDecayRegularizer

    from .lr import LRScheduler
    from .optimizer import _ParameterConfig

    class _MomentumParameterConfig(_ParameterConfig):
        momentum: NotRequired[float]
        use_nesterov: NotRequired[bool]
        rescale_grad: NotRequired[float]
        regularization_method: NotRequired[str]
        regularization_coeff: NotRequired[float]


__all__ = []


class Momentum(Optimizer):
    r"""

    Simple Momentum optimizer with velocity state.

    This optimizer has a flag for Nesterov momentum.

    The update equations are as follows:

    .. math::

        & velocity = mu * velocity + gradient

        & if (use\_nesterov):

        &\quad   param = param - (gradient + mu * velocity) * learning\_rate

        & else:

        &\quad   param = param - learning\_rate * velocity
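
    With the velocity initialized to zero, the first update reduces to plain SGD,
    i.e. :math:`param = param - learning\_rate * gradient` (scaled by :math:`1 + mu`
    when Nesterov momentum is enabled).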

    Parameters:

        learning_rate (float|Tensor|LRScheduler, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float data type, or an LRScheduler. The default value is 0.001.
        momentum (float, optional): Momentum factor. The default value is 0.9.
        parameters (list|tuple|None, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. You can also specify different options \
            (such as the learning rate or weight decay) for different parameter groups by passing \
            a list of dict. Note that the learning_rate in a parameter group is a scale factor \
            applied to the base learning_rate. \
            The default value is None in static graph mode, in which case all parameters are updated.
        use_nesterov (bool, optional): Whether to enable Nesterov momentum. The default value is False.
        weight_decay (int|float|WeightDecayRegularizer|None, optional): The regularization strategy. \
            It can be an int or float value used as the coefficient of L2 regularization, or an instance of \
            :ref:`api_paddle_regularizer_L1Decay` or :ref:`api_paddle_regularizer_L2Decay`.
            If a parameter already has a regularizer set via :ref:`api_paddle_ParamAttr`, \
            the regularization set here in the optimizer is ignored for that parameter; \
            otherwise, the setting here takes effect. \
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase|None, optional): Gradient clipping strategy; it must be an instance of
            a class derived from ``GradientClipBase``. There are three clipping strategies
            ( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` ,
            :ref:`api_paddle_nn_ClipGradByValue` ). Default None, meaning there is no gradient clipping.
        multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is False.
        rescale_grad (float, optional): Multiply the gradient by `rescale_grad` before updating. \
            Often chosen to be ``1.0/batch_size``.
        use_multi_tensor (bool, optional): Whether to use the multi-tensor strategy to update all parameters at once. Default is False.
        name (str|None, optional): The default value is None. Normally there is no need for the user
                to set this property. For more information, please refer to
                :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> inp = paddle.uniform([10, 10], dtype="float32", min=-0.1, max=0.1)
            >>> linear = paddle.nn.Linear(10, 10)
            >>> out = linear(inp)
            >>> loss = paddle.mean(out)
            >>> momentum = paddle.optimizer.Momentum(
            ...     learning_rate=0.1,
            ...     parameters=linear.parameters(),
            ...     weight_decay=0.01
            ... )
            >>> loss.backward()
            >>> momentum.step()
            >>> momentum.clear_grad()

            >>> # Note that the effective learning_rate of linear_2's parameters is 0.1 * 0.1 = 0.01.
            >>> linear_1 = paddle.nn.Linear(10, 10)
            >>> linear_2 = paddle.nn.Linear(10, 10)
            >>> inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1)
            >>> out = linear_1(inp)
            >>> out = linear_2(out)
            >>> loss = paddle.mean(out)
            >>> momentum = paddle.optimizer.Momentum(
            ...     learning_rate=0.1,
            ...     parameters=[{ # type: ignore
            ...         'params': linear_1.parameters()
            ...     }, {
            ...         'params': linear_2.parameters(),
            ...         'weight_decay': 0.001,
            ...         'learning_rate': 0.1
            ...     }],
            ...     weight_decay=0.01,
            ...     momentum=0.9
            ... )
            >>> out.backward()
            >>> momentum.step()
            >>> momentum.clear_grad()
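
            >>> # A minimal sketch (illustrative): the same update through the fused
            >>> # multi-tensor path, which issues one merged op per dtype bucket.
            >>> momentum_mt = paddle.optimizer.Momentum(
            ...     learning_rate=0.1,
            ...     parameters=linear_1.parameters() + linear_2.parameters(),
            ...     momentum=0.9,
            ...     use_multi_tensor=True
            ... )
            >>> out = linear_2(linear_1(inp))
            >>> loss = paddle.mean(out)
            >>> loss.backward()
            >>> momentum_mt.step()
            >>> momentum_mt.clear_grad()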

    """

    _velocity_acc_str = "velocity"

    def __init__(
        self,
        learning_rate: float | Tensor | LRScheduler = 0.001,
        momentum: float = 0.9,
        parameters: (
            Sequence[Tensor] | Sequence[_MomentumParameterConfig] | None
        ) = None,
        use_nesterov: bool = False,
        weight_decay: float | WeightDecayRegularizer | None = None,
        grad_clip: GradientClipBase | None = None,
        multi_precision: bool = False,
        rescale_grad: float = 1.0,
        use_multi_tensor: bool = False,
        name: str | None = None,
    ) -> None:
        if learning_rate is None:
            raise ValueError("learning_rate is not set")
        if momentum is None:
            raise ValueError("momentum is not set")

        if isinstance(weight_decay, int):
            weight_decay = float(weight_decay)
        predicate = lambda regular: isinstance(regular, (L2Decay, float))
        if isinstance(parameters, list):
            if isinstance(parameters[0], dict):
                for param_group in parameters:
                    decay = (
                        param_group['weight_decay']
                        if 'weight_decay' in param_group
                        else weight_decay
                    )
                    reg_method, reg_coeff = self._update_regularization(decay)
                    param_group['regularization_method'] = reg_method
                    param_group['regularization_coeff'] = reg_coeff
                    py_regular = None if predicate(decay) else decay
                    param_group['weight_decay'] = py_regular

        py_regular = None if predicate(weight_decay) else weight_decay
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=py_regular,
            grad_clip=grad_clip,
            name=name,
        )
        self.type = "momentum"
        self._momentum = momentum
        self._use_nesterov = bool(use_nesterov)
        (
            self._regularization_method,
            self._regularization_coeff,
        ) = self._update_regularization(weight_decay)
        self._multi_precision = multi_precision
        self._rescale_grad = rescale_grad
        self._master_weights = {}

        self._default_dict = {
            'momentum': momentum,
            'use_nesterov': use_nesterov,
            'rescale_grad': rescale_grad,
            'regularization_method': self._regularization_method,
            'regularization_coeff': self._regularization_coeff,
        }
        self._use_multi_tensor = use_multi_tensor
        if self._use_multi_tensor:
            self._param_dict = self._create_multi_tensor_dict()
            self._velocity_dict = self._create_multi_tensor_dict()
            self._master_weight_dict = self._create_multi_tensor_dict()
            self._master_weight_dict['FP32_DenseTensor'] = None
            self._regularization_method_dict = self._create_multi_tensor_dict()
            self._regularization_coeff_dict = self._create_multi_tensor_dict()

    def _update_regularization(self, weight_decay):
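        # Map the weight_decay setting to the fused op's (method, coeff) pair, e.g.
        #   L2Decay(1e-4) -> ("l2_decay", 1e-4), 0.01 -> ("l2_decay", 0.01),
        #   anything else (L1Decay, None, ...) -> ("", 0.0), i.e. no fused decay.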
        reg_method = ""
        reg_coeff = 0.0

        if isinstance(weight_decay, L2Decay):
            reg_method = "l2_decay"
            reg_coeff = weight_decay._coeff
        if isinstance(weight_decay, float):
            reg_method = "l2_decay"
            reg_coeff = weight_decay
        return reg_method, reg_coeff

    def _create_accumulators(self, block, parameters):
        '''
        if framework.in_dynamic_mode():
            return
        '''
        assert isinstance(block, (framework.Block, paddle.pir.Block))

        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        for p in parameters:
            if p.name in self._already_create_accumulator:
                continue
            if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                master_p = self._create_master_weight(p)
                self._add_accumulator(self._velocity_acc_str, master_p)
                self._already_create_accumulator.add(p.name)
                continue
            if (
                self._is_dtype_fp16_or_bf16(p.dtype)
                and not self._multi_precision
            ):
                warnings.warn(
                    "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using the multi_precision=True option of the Momentum optimizer."
                )
            self._add_accumulator(self._velocity_acc_str, p)
            self._already_create_accumulator.add(p.name)

    def _create_regularization_of_grad(self, param, grad, regularization=None):
        """Create and add backward regularization Operators

        Function helper of append_regularization_ops.
        """
        # If the parameter's ParamAttr regularizer is L2Decay, skip regularization here and
        # instead fuse L2Decay into the momentum op; see _append_optimize_op below.
        if hasattr(param, 'regularizer') and isinstance(
            param.regularizer, L2Decay
        ):
            return grad
        return super()._create_regularization_of_grad(
            param, grad, regularization
        )

    def _append_optimize_op(self, block, param_and_grad):
        if not isinstance(block, (framework.Block, pir.Block)):
            raise TypeError("block is not instance of Block.")
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        velocity_acc = self._get_accumulator_master(
            self._velocity_acc_str, param_and_grad[0]
        )
        lr = self._create_param_lr(param_and_grad)

        # For fusion of momentum and l2decay
        param = param_and_grad[0]
        regularization_method = self._regularization_method
        regularization_coeff = self._regularization_coeff
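        # Per-parameter precedence (matching the branches below):
        #   param.regularizer is L2Decay -> fuse its coeff here, overriding the optimizer-level setting;
        #   any other regularizer        -> it was already applied upstream, so disable fusion;
        #   no regularizer               -> keep the optimizer-level regularization.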
        if hasattr(param, 'regularizer'):
            # We skipped this parameter's L2Decay earlier, so fuse it with momentum here.
            if isinstance(param.regularizer, L2Decay):
                regularization_method = "l2_decay"
                regularization_coeff = param.regularizer._coeff
            # The parameter's regularization was already applied, so avoid applying L2Decay again in momentum.
            elif param.regularizer is not None:
                regularization_method = ""
                regularization_coeff = 0.0

        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
            param_and_grad[0].dtype
        )
        master_weight = (
            self._master_weights[param_and_grad[0].name]
            if find_master
            else None
        )

        if in_dynamic_or_pir_mode():
            if isinstance(param_and_grad, dict):
                self._update_regularization(param_and_grad['weight_decay'])
            return _C_ops.momentum_(
                param_and_grad[0],
                param_and_grad[1],
                velocity_acc,
                lr,
                master_weight,
                self._momentum,
                self._use_nesterov,
                regularization_method,
                regularization_coeff,
                find_master,
                self._rescale_grad,
            )
        else:
            attrs = {
                "mu": self._momentum,
                "use_nesterov": self._use_nesterov,
                "regularization_method": regularization_method,
                "regularization_coeff": regularization_coeff,
                "multi_precision": find_master,
                "rescale_grad": self._rescale_grad,
            }

            inputs = {
                "Param": [param_and_grad[0]],
                "Grad": [param_and_grad[1]],
                "Velocity": [velocity_acc],
                "LearningRate": [lr],
            }

            outputs = {
                "ParamOut": [param_and_grad[0]],
                "VelocityOut": [velocity_acc],
            }

            if find_master:
                inputs["MasterParam"] = master_weight
                outputs["MasterParamOut"] = master_weight

            # create the momentum optimize op
            momentum_op = block.append_op(
                type=self.type,
                inputs=inputs,
                outputs=outputs,
                attrs=attrs,
                stop_gradient=True,
            )

            return momentum_op

    def _multi_tensor_init(self, target_block, parameters, param_group_idx):
        """
        All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, bf16, float32).
        This function will be overridden in the corresponding optimizer file.

        Args:
            target_block: the block in which the loss tensor is present
            parameters: list of parameter tensors for the optimizer
        """
        self._create_accumulators(target_block, parameters)
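        # The loop below fills the per-dtype containers (each keyed by
        # 'FP32_DenseTensor' / 'FP16_DenseTensor' and indexed by param_group_idx)
        # with aligned entries: parameters, velocity accumulators, regularization
        # settings and, when multi_precision is on, fp16/bf16 master weights.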
        for param in parameters:
            velocity_acc = self._get_accumulator_master(
                self._velocity_acc_str, param
            )
            regularization_method = self._regularization_method
            regularization_coeff = self._regularization_coeff
            if hasattr(param, 'regularizer'):
                # We skipped this parameter's L2Decay earlier, so fuse it with momentum here.
                if isinstance(param.regularizer, L2Decay):
                    regularization_method = "l2_decay"
                    regularization_coeff = param.regularizer._coeff
                elif param.regularizer is not None:
                    regularization_method = ""
                    regularization_coeff = 0.0
            if param.dtype == paddle.float32:
                self._param_dict['FP32_DenseTensor'][param_group_idx].append(
                    param
                )
                self._velocity_dict['FP32_DenseTensor'][param_group_idx].append(
                    velocity_acc
                )
                # fp32 no master weight
                self._regularization_method_dict['FP32_DenseTensor'][
                    param_group_idx
                ].append(regularization_method)
                self._regularization_coeff_dict['FP32_DenseTensor'][
                    param_group_idx
                ].append(regularization_coeff)
            elif self._is_dtype_fp16_or_bf16(param.dtype):
                self._param_dict['FP16_DenseTensor'][param_group_idx].append(
                    param
                )
                self._velocity_dict['FP16_DenseTensor'][param_group_idx].append(
                    velocity_acc
                )
                if self._multi_precision:
                    self._master_weight_dict['FP16_DenseTensor'][
                        param_group_idx
                    ].append(self._master_weights[param.name])
                else:
                    self._master_weight_dict['FP16_DenseTensor'][
                        param_group_idx
                    ] = None
                self._regularization_method_dict['FP16_DenseTensor'][
                    param_group_idx
                ].append(regularization_method)
                self._regularization_coeff_dict['FP16_DenseTensor'][
                    param_group_idx
                ].append(regularization_coeff)
            else:
                raise ValueError(
                    "multi_tensor_momentum only supports fp32, fp16 or bf16 parameters whose gradients are DENSE_TENSOR."
                )

    def _append_optimize_multi_tensor_op(
        self,
        target_block,
        parameters_and_grads,
        param_group_idx,
    ):
        """
        For Multi Tensor, append optimize merged_operator to block.
        """
        assert isinstance(target_block, framework.Block)

        grad_dict = {'FP32_DenseTensor': [], 'FP16_DenseTensor': []}
        lr_dict = {'FP32_DenseTensor': [], 'FP16_DenseTensor': []}
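        # Bucket gradients and the per-parameter learning rates by dtype,
        # mirroring the grouping done in _multi_tensor_init.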

        if isinstance(parameters_and_grads, list):
            for param_and_grad in parameters_and_grads:
                if param_and_grad[1] is None:
                    continue
                if param_and_grad[0].stop_gradient is False:
                    if (
                        param_and_grad[0].dtype == paddle.float32
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.DENSE_TENSOR
                    ):
                        grad_dict['FP32_DenseTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP32_DenseTensor'].append(lr)
                    elif (
                        self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.DENSE_TENSOR
                    ):
                        grad_dict['FP16_DenseTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP16_DenseTensor'].append(lr)
        else:
            for param_and_grad in parameters_and_grads['params']:
                if param_and_grad[1] is None:
                    continue
                if param_and_grad[0].stop_gradient is False:
                    param_grad_dict = {}
                    param_grad_dict['params'] = param_and_grad
                    param_grad_dict.update(
                        {
                            k: v
                            for k, v in parameters_and_grads.items()
                            if k != 'params'
                        }
                    )
                    param_and_grad = self._update_param_group(param_grad_dict)
                    if (
                        param_and_grad[0].dtype == paddle.float32
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.DENSE_TENSOR
                    ):
                        grad_dict['FP32_DenseTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP32_DenseTensor'].append(lr)
                    elif (
                        self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
                        and param_and_grad[1].type
                        == core.VarDesc.VarType.DENSE_TENSOR
                    ):
                        grad_dict['FP16_DenseTensor'].append(param_and_grad[1])
                        lr = self._create_param_lr(param_and_grad)
                        lr_dict['FP16_DenseTensor'].append(lr)

        multi_tensor_list = ['FP32_DenseTensor', 'FP16_DenseTensor']
        for key in multi_tensor_list:
            if len(self._param_dict[key][param_group_idx]) > 0:
                find_master = (
                    self._multi_precision and key == 'FP16_DenseTensor'
                )

                master_weight = self._master_weight_dict[key]
                master_weight = (
                    master_weight[param_group_idx]
                    if master_weight is not None
                    else None
                )

                if in_dynamic_or_pir_mode():
                    found_inf = self._get_auxiliary_var('found_inf')
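                    # If an inf/nan flag was reported for this step's gradients
                    # (e.g. by the AMP grad scaler), skip the fused update and
                    # just record the flag; otherwise run merged_momentum_.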
                    if found_inf:
                        if isinstance(
                            found_inf, (core.eager.Tensor, paddle.pir.Value)
                        ):
                            self._set_auxiliary_var('found_inf', True)
                    else:
                        if isinstance(
                            found_inf, (core.eager.Tensor, paddle.pir.Value)
                        ):
                            self._set_auxiliary_var('found_inf', False)
                        _, _, _ = _C_ops.merged_momentum_(
                            self._param_dict[key][param_group_idx],
                            grad_dict[key],
                            self._velocity_dict[key][param_group_idx],
                            lr_dict[key],
                            master_weight,
                            self._momentum,
                            self._use_nesterov,
                            self._regularization_method_dict[key][
                                param_group_idx
                            ],
                            self._regularization_coeff_dict[key][
                                param_group_idx
                            ],
                            find_master,
                            self._rescale_grad,
                        )
                else:
                    inputs = {
                        "Param": self._param_dict[key][param_group_idx],
                        "Grad": grad_dict[key],
                        "Velocity": self._velocity_dict[key][param_group_idx],
                        "LearningRate": lr_dict[key],
                    }
                    outputs = {
                        "ParamOut": self._param_dict[key][param_group_idx],
                        "VelocityOut": self._velocity_dict[key][
                            param_group_idx
                        ],
                    }
                    attrs = {
                        "mu": self._momentum,
                        "use_nesterov": self._use_nesterov,
                        "regularization_method": self._regularization_method_dict[
                            key
                        ][param_group_idx],
                        "regularization_coeff": self._regularization_coeff_dict[
                            key
                        ][param_group_idx],
                    }
                    if find_master:
                        inputs["MasterParam"] = self._master_weight_dict[key][
                            param_group_idx
                        ]
                        outputs["MasterParamOut"] = self._master_weight_dict[
                            key
                        ][param_group_idx]
                        attrs["multi_precision"] = find_master
                    target_block.append_op(
                        type="merged_momentum",
                        inputs=inputs,
                        outputs=outputs,
                        attrs=attrs,
                        stop_gradient=True,
                    )

    def _update_param_group(self, parameters):
        self._momentum = parameters.get(
            'momentum', self._default_dict['momentum']
        )
        self._use_nesterov = parameters.get(
            'use_nesterov', self._default_dict['use_nesterov']
        )
        self._rescale_grad = parameters.get(
            'rescale_grad', self._default_dict['rescale_grad']
        )
        self._regularization_method = parameters.get(
            'regularization_method', self._default_dict['regularization_method']
        )
        self._regularization_coeff = parameters.get(
            'regularization_coeff', self._default_dict['regularization_coeff']
        )
        parameters = parameters.get('params')
        return parameters
