#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import warnings
from collections import defaultdict
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    TypedDict,
)

import numpy as np

import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.base import core, unique_name
from paddle.base.data_feeder import check_type
from paddle.base.framework import Operator, _dygraph_tracer, in_pir_mode
from paddle.framework import in_dynamic_mode

from .auto_cast import amp_global_state

if TYPE_CHECKING:
    from paddle import Tensor
    from paddle.static.amp.decorator import OptimizerWithMixedPrecision
    from paddle.optimizer import Optimizer

    class _ScaleStateDict(TypedDict):
        scale: Tensor
        incr_ratio: float
        decr_ratio: float
        incr_every_n_steps: int
        decr_every_n_nan_or_inf: int
        incr_count: int
        decr_count: int
        use_dynamic_loss_scaling: bool


class OptimizerState(Enum):
    """Per-optimizer scaling state: UNSCALED after the gradients have been
    unscaled, STEPPED after the optimizer step, and INIT otherwise."""

    INIT = 0
    UNSCALED = 1
    STEPPED = 2


def _refresh_optimizer_state():
    """Return a fresh per-optimizer state dict, reset to ``OptimizerState.INIT``."""
    return {"state": OptimizerState.INIT}


class AmpScaler:
    """
    AmpScaler is used for Auto-Mixed-Precision training/inference in imperative
    mode. It controls the scaling of the loss, which helps avoid numerical overflow.
    The object of this class provides `scale()`, `unscale_()`, `minimize()` and `get`/`set` APIs for its parameters.

    `scale()` multiplies the loss by a scale ratio.
    `unscale_()` unscales the gradients of parameters, i.e. multiplies them by 1/(scale ratio).
    `minimize()` is similar to `optimizer.minimize()`: it performs the parameter update and also updates the loss scaling.

    Commonly, it is used together with `amp_guard` to achieve Auto-Mixed-Precision in
    imperative mode.

    Args:
        enable(bool, optional): Enable loss scaling or not. Default is True.
        init_loss_scaling (float, optional): The initial loss scaling factor. Default is 2**15.
        incr_ratio(float, optional): The multiplier to use when increasing the loss
                        scaling. Default is 2.0.
        decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing
                        the loss scaling. Default is 0.5.
        incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
                                steps with finite gradients. Default is 1000.
        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
                                    accumulated steps with nan or inf gradients. Default is 1.
        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
    Returns:
        An AmpScaler object.

    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> import paddle

            >>> data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
            >>> model = paddle.nn.Conv2D(3, 2, 3)
            >>> optimizer = paddle.optimizer.SGD(
            ...         learning_rate=0.01, parameters=model.parameters())
            >>> scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
            >>> data = paddle.to_tensor(data)
            >>> with paddle.amp.amp_guard():
            ...     conv = model(data)
            ...     loss = paddle.mean(conv)
            ...     scaled = scaler.scale(loss)
            ...     scaled.backward()
            ...     scaler.minimize(optimizer, scaled)
    """

    def __init__(
        self,
        enable: bool = True,
        init_loss_scaling: float = 2.0**15,
        incr_ratio: float = 2.0,
        decr_ratio: float = 0.5,
        incr_every_n_steps: int = 1000,
        decr_every_n_nan_or_inf: int = 1,
        use_dynamic_loss_scaling: bool = True,
    ) -> None:
        if in_dynamic_mode():
            tracer = _dygraph_tracer()
            if not tracer:
                raise ValueError(
                    "current_tracer is None, maybe it is not in imperative mode."
                )

            if enable and not (
                tracer._expected_place.is_gpu_place()
                or tracer._expected_place.is_xpu_place()
                or tracer._expected_place.is_custom_place()
            ):
                warnings.warn(
                    f'AmpScaler can only be enabled on CUDAPlace, XPUPlace and CustomPlace, current place is {tracer._expected_place}, so it has no effect.'
                )
                enable = False

        self._enable = enable
        self._use_dynamic_loss_scaling = False
        self._init_loss_scaling = 1.0
        self._scale = None

        if self._enable:
            assert incr_ratio > 1.0, "The incr_ratio must be > 1.0."
            assert decr_ratio < 1.0, "The decr_ratio must be < 1.0."

            self._init_loss_scaling = init_loss_scaling
            self._incr_ratio = incr_ratio
            self._decr_ratio = decr_ratio
            self._incr_every_n_steps = incr_every_n_steps
            self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
            self._incr_count = 0
            self._decr_count = 0
            self._use_dynamic_loss_scaling = use_dynamic_loss_scaling

            if in_pir_mode():
                self._scale = paddle.pir.core.create_persistable_value(
                    dtype='float32',
                    shape=[1],
                    name=unique_name.generate("loss_scaling"),
                    initializer=paddle.nn.initializer.ConstantInitializer(
                        value=self._init_loss_scaling
                    ),
                )
            else:
                self._found_inf = paddle.to_tensor(
                    np.array([0]).astype(np.bool_)
                )
                self._temp_found_inf_value_false = paddle.to_tensor(
                    np.array([0]).astype(np.bool_)
                )
                self._temp_found_inf_fp16 = paddle.to_tensor(
                    np.array([0]).astype(np.bool_)
                )
                self._temp_found_inf_bf16 = paddle.to_tensor(
                    np.array([0]).astype(np.bool_)
                )
                self._temp_found_inf_fp32 = paddle.to_tensor(
                    np.array([0]).astype(np.bool_)
                )
                self._scale = paddle.to_tensor(
                    np.array([self._init_loss_scaling]).astype(np.float32)
                )
                self._cache_found_inf = None
                self._optimizer_states = defaultdict(_refresh_optimizer_state)

    def scale(self, var: Tensor) -> Tensor:
        """
        Multiplies a Tensor by the scale factor and returns scaled outputs.
        If this instance of :class:`AmpScaler` is not enabled, the output is returned unmodified.

        Args:
            var (Tensor):  The Tensor to scale.
        Returns:
            The scaled Tensor or original Tensor.

        Examples:

            .. code-block:: python

                >>> import numpy as np
                >>> import paddle

                >>> data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
                >>> model = paddle.nn.Conv2D(3, 2, 3)
                >>> optimizer = paddle.optimizer.SGD(
                ...         learning_rate=0.01, parameters=model.parameters())
                >>> scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
                >>> data = paddle.to_tensor(data)
                >>> with paddle.amp.amp_guard():
                ...     conv = model(data)
                ...     loss = paddle.mean(conv)
                ...     scaled = scaler.scale(loss)
                ...     scaled.backward()
                ...     scaler.minimize(optimizer, scaled)
        """
        check_type(
            var,
            "var",
            (paddle.Tensor, paddle.pir.Value),
            'AmpScaler.scale()',
        )

        if (
            self._enable
            and amp_global_state().amp_dtype != 'float16'
            and self._use_dynamic_loss_scaling
        ):
            self._enable = False
            self._use_dynamic_loss_scaling = False
            self._init_loss_scaling = 1.0
            warnings.warn(
                f'It is not recommended to use dynamic loss scaling for {amp_global_state().amp_dtype}, so GradScaler is disabled by default.'
            )

        if in_pir_mode():
            if var.dtype != core.DataType.FLOAT32:
                var = var.astype('float32')
            if not self._use_dynamic_loss_scaling:
                return var
            scale_out = paddle._C_ops.multiply(var, self._scale)
            multiply_op = scale_out.get_defining_op()
            src_var_op = var.get_defining_op()
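            # Rebuild the multiply op's dist attribute with the chunk_id of the op that
            # produced `var`, so the scaled loss keeps the source op's chunk placement.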
            if multiply_op.dist_attr and src_var_op.dist_attr:
                multiply_op.dist_attr = (
                    paddle.base.libpaddle.pir.create_op_dist_attribute(
                        multiply_op.dist_attr.process_mesh,
                        multiply_op.dist_attr.operands(),
                        multiply_op.dist_attr.results(),
                        src_var_op.dist_attr.chunk_id,
                    )
                )
            return scale_out

        # NOTE(lizhiyu): We hack here to avoid changing the `dist_attr` of `self._scale` of 'no-calculation-rank'
        if not self._enable or not var._is_initialized():
            return var

        return var * self._scale

    def minimize(
        self,
        optimizer: Optimizer | OptimizerWithMixedPrecision,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[Operator], list[tuple[Tensor, Tensor]]]:
        """
        This function is similar to `Optimizer.minimize()`, which performs the parameter update.

        If the scaled gradients of the parameters contain NaN or Inf, the parameter update is skipped.
        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of the parameters, then updates the parameters.

        Finally, the loss scaling ratio is updated.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.
            args:  Arguments, which will be forwarded to `Optimizer.minimize()`.
            kwargs: Keyword arguments, which will be forwarded to `Optimizer.minimize()`.

        Examples:

            .. code-block:: python

                >>> import numpy as np
                >>> import paddle

                >>> data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
                >>> model = paddle.nn.Conv2D(3, 2, 3)
                >>> optimizer = paddle.optimizer.SGD(
                ...     learning_rate=0.01,
                ...     parameters=model.parameters()
                ... )
                >>> scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
                >>> data = paddle.to_tensor(data)
                >>> with paddle.amp.amp_guard():
                ...     conv = model(data)
                ...     loss = paddle.mean(conv)
                ...     scaled = scaler.scale(loss)
                ...     scaled.backward()
                ...     scaler.minimize(optimizer, scaled)
        """

        if in_pir_mode():
            assert isinstance(
                optimizer,
                paddle.static.amp.decorator.OptimizerWithMixedPrecision,
            )
            optimizer._use_dynamic_loss_scaling = self._use_dynamic_loss_scaling
            optimizer._init_loss_scaling = self._init_loss_scaling
            optimizer._loss_scaling = self._scale
            optimizer._scaled_loss = args[0]
            if self._use_dynamic_loss_scaling:
                optimizer._incr_every_n_steps = self._incr_every_n_steps
                optimizer._decr_every_n_nan_or_inf = (
                    self._decr_every_n_nan_or_inf
                )
                optimizer._incr_ratio = self._incr_ratio
                optimizer._decr_ratio = self._decr_ratio
                optimizer._num_good_steps = None
                optimizer._num_bad_steps = None
            return optimizer.minimize(*args, **kwargs)

        if not self._enable:
            return optimizer.minimize(*args, **kwargs)

        optimizer_state = self._optimizer_states[id(optimizer)]

        #  unscale the grad
        if optimizer_state["state"] is OptimizerState.INIT:
            self._unscale(optimizer)

        optimize_ops, params_grads = (None, None)

        if hasattr(optimizer, "_set_auxiliary_var"):
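            # Hand found_inf to the optimizer so it can skip the parameter update itself
            # when the scaled gradients overflowed.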
            optimizer._set_auxiliary_var('found_inf', self._found_inf)
            optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
            # TODO: Fix to _cache_found_inf after PaddleNLP update
            self._cache_found_inf = optimizer._get_auxiliary_var('found_inf')
        else:
            if self._found_inf:
                self._cache_found_inf = True
            else:
                optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
                self._cache_found_inf = False

        if self._use_dynamic_loss_scaling:
            # update the scale
            self._update()

        self._optimizer_states = defaultdict(_refresh_optimizer_state)

        return optimize_ops, params_grads

    def _unscale(self, optimizer):
        """
        Unscale the gradients of parameters, i.e. multiply them by 1/(loss scaling ratio).
        If this instance of :class:`AmpScaler` is not enabled, the gradients are returned unmodified.
        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.
        Returns:
            The unscaled gradients or the original gradients.
        """
        if not self._enable:
            return

        optimizer_state = self._optimizer_states[id(optimizer)]

        if optimizer_state["state"] is OptimizerState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["state"] is OptimizerState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        if getattr(optimizer, '_param_groups', None) and isinstance(
            optimizer._param_groups[0], dict
        ):
            param_grads = []
            param_grads_fp16 = []
            param_grads_bf16 = []
            param_grads_fp32 = []
            for group in optimizer._param_groups:
                for param in group['params']:
                    if param._grad_ivar() is not None:
                        param_grads.append(param._grad_ivar())
                        if param._grad_ivar().dtype == paddle.float16:
                            param_grads_fp16.append(param._grad_ivar())
                        elif param._grad_ivar().dtype == paddle.bfloat16:
                            param_grads_bf16.append(param._grad_ivar())
                        else:
                            param_grads_fp32.append(param._grad_ivar())
        else:
            if in_dynamic_mode():
                # It is very time-consuming to call c++ functions in a loop on the python side.
                # We put this part of the code on the c++ side to improve the speed in eager mode.
                (
                    param_grads_fp16,
                    param_grads_bf16,
                    param_grads_fp32,
                ) = core.eager.get_grads_lists(optimizer._parameter_list)
            else:
                # Keep the original code to support legacy mode.
                # Delete the else branch when the legacy mode exits.
                param_grads = [
                    param._grad_ivar()
                    for param in optimizer._parameter_list
                    if param._grad_ivar() is not None
                ]
                param_grads_fp16 = [
                    param
                    for param in param_grads
                    if param.dtype == paddle.float16
                ]
                param_grads_bf16 = [
                    param
                    for param in param_grads
                    if param.dtype == paddle.bfloat16
                ]
                param_grads_fp32 = [
                    param
                    for param in param_grads
                    if param.dtype == paddle.float32
                ]
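        # Unscale each dtype bucket in place with check_finite_and_unscale; every call
        # also writes a per-bucket found_inf flag, which is OR-ed into self._found_inf
        # so a single NaN/Inf gradient in any bucket marks the step as skipped.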
        self._found_inf = self._temp_found_inf_value_false
        if len(param_grads_fp16):
            _legacy_C_ops.check_finite_and_unscale(
                param_grads_fp16,
                self._scale,
                param_grads_fp16,
                self._temp_found_inf_fp16,
            )
            self._found_inf = _C_ops.bitwise_or(
                self._found_inf, self._temp_found_inf_fp16
            )
        if len(param_grads_bf16):
            _legacy_C_ops.check_finite_and_unscale(
                param_grads_bf16,
                self._scale,
                param_grads_bf16,
                self._temp_found_inf_bf16,
            )
            self._found_inf = _C_ops.bitwise_or(
                self._found_inf, self._temp_found_inf_bf16
            )
        if len(param_grads_fp32):
            _legacy_C_ops.check_finite_and_unscale(
                param_grads_fp32,
                self._scale,
                param_grads_fp32,
                self._temp_found_inf_fp32,
            )
            self._found_inf = _C_ops.bitwise_or(
                self._found_inf, self._temp_found_inf_fp32
            )

        optimizer_state["state"] = OptimizerState.UNSCALED

    def _update(self):
        """
        Updates the loss_scaling.
        """
        if not self._enable:
            return

        if self._cache_found_inf:
            self._incr_count = 0
            self._decr_count = self._decr_count + 1
            if self._decr_count == self._decr_every_n_nan_or_inf:
                print(
                    f'Found inf or nan, current scale is: {float(self._scale)}, decrease to: {float(self._scale)}*{float(self._decr_ratio)}'
                )
                self._scale = self._scale * self._decr_ratio
                self._decr_count = 0
        else:
            self._decr_count = 0
            self._incr_count = self._incr_count + 1
            if self._incr_count == self._incr_every_n_steps:
                self._scale = self._scale * self._incr_ratio
                self._incr_count = 0

        return

    def is_enable(self) -> bool:
        """
        Whether loss scaling is enabled.

        Returns:
            bool: True if loss scaling is enabled, otherwise False.
        """
        return self._enable

    def is_use_dynamic_loss_scaling(self) -> bool:
        """
        Whether to use dynamic loss scaling.

        Returns:
            bool: False if fixed loss scaling is used, True if the loss scaling is updated dynamically.
        """
        return self._use_dynamic_loss_scaling

    def get_init_loss_scaling(self) -> float:
        """
        Return the initial loss scaling factor.

        Returns:
            float:  the initial loss scaling factor.
        """
        return self._init_loss_scaling

    def set_init_loss_scaling(self, new_init_loss_scaling: float) -> None:
        """
        Set the initial loss scaling factor by `new_init_loss_scaling`.

        Args:
            new_init_loss_scaling(float):  The new_init_loss_scaling used to update the initial loss scaling factor.
        """
        self._init_loss_scaling = new_init_loss_scaling
        self._scale = paddle.to_tensor(
            np.array([self._init_loss_scaling]).astype(np.float32)
        )

    def get_incr_ratio(self) -> float:
        """
        Return the multiplier to use when increasing the loss scaling.

        Returns:
            float:  the multiplier to use when increasing the loss scaling.
        """
        return self._incr_ratio

    def set_incr_ratio(self, new_incr_ratio: float) -> None:
        """
        Set the multiplier to use when increasing the loss scaling to `new_incr_ratio`; `new_incr_ratio` should be > 1.0.

        Args:
            new_incr_ratio(float):  The new_incr_ratio used to update the multiplier to use when increasing the loss scaling.
        """
        assert new_incr_ratio > 1.0, "The new_incr_ratio must be > 1.0."
        self._incr_ratio = new_incr_ratio

    def get_decr_ratio(self) -> float:
        """
        Get the less-than-one-multiplier to use when decreasing the loss scaling.

        Returns:
            float:  the less-than-one-multiplier to use when decreasing the loss scaling.
        """
        return self._decr_ratio

    def set_decr_ratio(self, new_decr_ratio: float) -> None:
        """
        Set the less-than-one-multiplier to use when decreasing the loss scaling to `new_decr_ratio`; `new_decr_ratio` should be < 1.0.

        Args:
            new_decr_ratio(float):  The new_decr_ratio used to update the less-than-one-multiplier to use when decreasing the loss scaling.
        """
        assert new_decr_ratio < 1.0, "The new_decr_ratio must be < 1.0."
        self._decr_ratio = new_decr_ratio

    def get_incr_every_n_steps(self) -> int:
        """
        Return the number `n`: the loss scaling is increased every `n` consecutive steps with finite gradients.

        Returns:
            int:  the number `n`; the loss scaling is increased every `n` consecutive steps with finite gradients.
        """
        return self._incr_every_n_steps

    def set_incr_every_n_steps(self, new_incr_every_n_steps: int) -> None:
        """
        Set the number `n` to `new_incr_every_n_steps`; the loss scaling is increased every `n` consecutive steps with finite gradients.

        Args:
            new_incr_every_n_steps(int):  The new value of `n`; the loss scaling is increased every `n` consecutive steps with finite gradients.
        """
        self._incr_every_n_steps = new_incr_every_n_steps

    def get_decr_every_n_nan_or_inf(self) -> int:
        """
        Return the number `n`: the loss scaling is decreased every `n` accumulated steps with NaN or Inf gradients.

        Returns:
            int:  the number `n`; the loss scaling is decreased every `n` accumulated steps with NaN or Inf gradients.
        """
        return self._decr_every_n_nan_or_inf

    def set_decr_every_n_nan_or_inf(
        self, new_decr_every_n_nan_or_inf: int
    ) -> None:
        """
        Set the number `n` to `new_decr_every_n_nan_or_inf`; the loss scaling is decreased every `n` accumulated steps with NaN or Inf gradients.

        Args:
            new_decr_every_n_nan_or_inf(int):  The new value of `n`; the loss scaling is decreased every `n` accumulated steps with NaN or Inf gradients.
        """
        self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf

    def state_dict(self) -> _ScaleStateDict:
        """
        Returns the state of the scaler as a `dict`. If this instance is not enabled, an empty dict is returned.

        Returns:
            A dict of the scaler's state, including:
            scale (tensor): The loss scaling factor.
            incr_ratio(float): The multiplier to use when increasing the loss scaling.
            decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling.
            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
            incr_count(int): The number of recent consecutive unskipped steps.
            decr_count(int): The number of recent consecutive skipped steps.
            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
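
        Examples:

            .. code-block:: python

                >>> import paddle
                >>> # A minimal sketch of saving the scaler state; note that on CPU the
                >>> # scaler is disabled and an empty dict is returned.
                >>> scaler = paddle.amp.AmpScaler(init_loss_scaling=1024)
                >>> scaler_state = scaler.state_dict()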
        """
        return (
            {
                "scale": self._scale.numpy(),
                "incr_ratio": self._incr_ratio,
                "decr_ratio": self._decr_ratio,
                "incr_every_n_steps": self._incr_every_n_steps,
                "decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf,
                "incr_count": self._incr_count,
                "decr_count": self._decr_count,
                "use_dynamic_loss_scaling": self._use_dynamic_loss_scaling,
            }
            if self._enable
            else {}
        )

    def load_state_dict(self, state_dict: _ScaleStateDict) -> None:
        """
        Loads the scaler state.

        Args:
           state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`.
        """
        if not self._enable:
            return

        if len(state_dict) == 0:
            raise RuntimeError(
                "The input state dict is empty, possibly because it was saved "
                "from a disabled instance of GradScaler."
            )

        self._init_loss_scaling = state_dict["scale"][0]
        self._scale = paddle.to_tensor(
            np.array([self._init_loss_scaling]).astype(np.float32)
        )
        self._incr_ratio = state_dict["incr_ratio"]
        self._decr_ratio = state_dict["decr_ratio"]
        self._incr_every_n_steps = state_dict["incr_every_n_steps"]
        self._decr_every_n_nan_or_inf = state_dict["decr_every_n_nan_or_inf"]
        self._incr_count = state_dict["incr_count"]
        self._decr_count = state_dict["decr_count"]
        self._use_dynamic_loss_scaling = state_dict["use_dynamic_loss_scaling"]


class GradScaler(AmpScaler):
    """
    GradScaler is used for Auto-Mixed-Precision training in dynamic graph mode.
    It controls the scaling of the loss, which helps avoid numerical overflow.
    The object of this class provides `scale()`, `unscale_()`, `minimize()`, `step()`, `update()` and `get`/`set` APIs for its parameters.

    `scale()` multiplies the loss by a scale ratio.
    `unscale_()` unscales the gradients of parameters, i.e. multiplies them by 1/(scale ratio).
    `minimize()` is similar to `optimizer.minimize()`: it performs the parameter update and also updates the loss scaling; it is equivalent to `step()` + `update()`.
    `step()` is similar to `optimizer.step()`, which performs the parameter update.
    `update()` updates the loss scaling.


    Commonly, it is used together with `paddle.amp.auto_cast` to achieve Auto-Mixed-Precision in
    dynamic graph mode.

    Args:
        enable(bool, optional): Enable loss scaling or not. Default is True.
        init_loss_scaling (float, optional): The initial loss scaling factor. Default is 65536.0.
        incr_ratio(float, optional): The multiplier to use when increasing the loss
                        scaling. Default is 2.0.
        decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing
                        the loss scaling. Default is 0.5.
        incr_every_n_steps(int, optional): Increases loss scaling every n consecutive
                                steps with finite gradients. Default is 2000.
        decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
                                    accumulated steps with nan or inf gradients. Default is 1.
        use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
    Returns:
        A GradScaler object.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
            >>> optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
            >>> scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
            >>> data = paddle.rand([10, 3, 32, 32])

            >>> with paddle.amp.auto_cast():
            ...     conv = model(data)
            ...     loss = paddle.mean(conv)

            >>> scaled = scaler.scale(loss)  # scale the loss
            >>> scaled.backward()            # do backward
            >>> scaler.minimize(optimizer, scaled)  # update parameters
            >>> optimizer.clear_grad()
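
            >>> # Alternatively, minimize() can be split into step() + update(); a
            >>> # minimal sketch of that flow (see step() and update() below):
            >>> with paddle.amp.auto_cast():
            ...     conv = model(data)
            ...     loss = paddle.mean(conv)
            >>> scaler.scale(loss).backward()
            >>> scaler.step(optimizer)       # update parameters
            >>> scaler.update()              # update the loss scaling ratio
            >>> optimizer.clear_grad()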
    """

    def __init__(
        self,
        enable: bool = True,
        init_loss_scaling: float = 2.0**16,
        incr_ratio: float = 2.0,
        decr_ratio: float = 0.5,
        incr_every_n_steps: int = 2000,
        decr_every_n_nan_or_inf: int = 1,
        use_dynamic_loss_scaling: bool = True,
    ) -> None:
        super().__init__(
            enable,
            init_loss_scaling,
            incr_ratio,
            decr_ratio,
            incr_every_n_steps,
            decr_every_n_nan_or_inf,
            use_dynamic_loss_scaling,
        )

    def scale(self, var: Tensor) -> Tensor:
        """
        Multiplies a Tensor by the scale factor and returns scaled outputs.
        If this instance of :class:`GradScaler` is not enabled, the output is returned unmodified.

        Args:
            var (Tensor):  The tensor to scale.
        Returns:
            The scaled tensor or original tensor.

        Examples:

            .. code-block:: python

                >>> import paddle

                >>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                >>> optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                >>> scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                >>> data = paddle.rand([10, 3, 32, 32])

                >>> with paddle.amp.auto_cast():
                ...     conv = model(data)
                ...     loss = paddle.mean(conv)

                >>> scaled = scaler.scale(loss)  # scale the loss
                >>> scaled.backward()            # do backward
                >>> scaler.minimize(optimizer, scaled)  # update parameters
                >>> optimizer.clear_grad()
        """
        return super().scale(var)

    def minimize(
        self,
        optimizer: Optimizer | OptimizerWithMixedPrecision,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[Operator], list[tuple[Tensor, Tensor]]]:
        """
        This function is similar to `optimizer.minimize()`, which performs the parameter update.

        If the scaled gradients of the parameters contain NaN or Inf, the parameter update is skipped.
        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of the parameters, then updates the parameters.

        Finally, the loss scaling ratio is updated.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.
            args:  Arguments, which will be forwarded to `optimizer.minimize()`.
            kwargs: Keyword arguments, which will be forwarded to `optimizer.minimize()`.

        Examples:

            .. code-block:: python

                >>> import paddle

                >>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                >>> optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                >>> scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                >>> data = paddle.rand([10, 3, 32, 32])

                >>> with paddle.amp.auto_cast():
                ...     conv = model(data)
                ...     loss = paddle.mean(conv)

                >>> scaled = scaler.scale(loss)  # scale the loss
                >>> scaled.backward()            # do backward
                >>> scaler.minimize(optimizer, scaled)  # update parameters
                >>> optimizer.clear_grad()
        """
        return super().minimize(optimizer, *args, **kwargs)

    def step(self, optimizer: Optimizer) -> None:
        """
        This function is similar to `optimizer.step()`, which performs the parameter update.

        If the scaled gradients of the parameters contain NaN or Inf, the parameter update is skipped.
        Otherwise, if `unscale_()` has not been called, it first unscales the scaled gradients of the parameters, then updates the parameters.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.

        Examples:

            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU)
                >>> import paddle
                >>> paddle.device.set_device('gpu')

                >>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                >>> optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                >>> scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                >>> data = paddle.rand([10, 3, 32, 32])
                >>> with paddle.amp.auto_cast():
                ...     conv = model(data)
                ...     loss = paddle.mean(conv)
                >>> scaled = scaler.scale(loss)  # scale the loss
                >>> scaled.backward()            # do backward
                >>> scaler.step(optimizer)       # update parameters
                >>> scaler.update()              # update the loss scaling ratio
                >>> optimizer.clear_grad()
        """
        if not self._enable:
            return optimizer.step()

        optimizer_state = self._optimizer_states[id(optimizer)]
        if optimizer_state["state"] is OptimizerState.STEPPED:
            raise RuntimeError(
                "step() has already been called since the last update()."
            )

        #  unscale the grad
        if optimizer_state["state"] is OptimizerState.INIT:
            self._unscale(optimizer)

        if hasattr(optimizer, "_set_auxiliary_var"):
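            # Hand found_inf to the optimizer so it can skip the parameter update itself
            # when the scaled gradients overflowed.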
            optimizer._set_auxiliary_var('found_inf', self._found_inf)
            optimizer.step()
            self._cache_found_inf = optimizer._get_auxiliary_var('found_inf')
        else:
            if self._found_inf:
                self._cache_found_inf = True
            else:
                optimizer.step()
                self._cache_found_inf = False

        optimizer_state["state"] = OptimizerState.STEPPED

        if not self._use_dynamic_loss_scaling:
            self._optimizer_states = defaultdict(_refresh_optimizer_state)

    def update(self) -> None:
        """
        Updates the loss_scaling.

        Examples:

            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU)
                >>> import paddle

                >>> paddle.device.set_device('gpu')
                >>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                >>> optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                >>> scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                >>> data = paddle.rand([10, 3, 32, 32])
                >>> with paddle.amp.auto_cast():
                ...     conv = model(data)
                ...     loss = paddle.mean(conv)
                >>> scaled = scaler.scale(loss)     # scale the loss
                >>> scaled.backward()               # do backward
                >>> scaler.step(optimizer)          # update parameters
                >>> scaler.update()                 # update the loss scaling ratio
                >>> optimizer.clear_grad()
        """
        if not self._enable:
            return
        if self._use_dynamic_loss_scaling:
            self._update()
            self._optimizer_states = defaultdict(_refresh_optimizer_state)
        return

    def unscale_(self, optimizer):
        """
        Unscale the gradients of parameters, i.e. multiply them by 1/(loss scaling ratio).
        If this instance of :class:`GradScaler` is not enabled, the gradients are returned unmodified.

        Args:
            optimizer(Optimizer):  The optimizer used to update parameters.

        Returns:
            The unscaled gradients or the original gradients.

        Examples:

            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU)
                >>> import paddle

                >>> paddle.device.set_device('gpu')
                >>> model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
                >>> optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
                >>> scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                >>> data = paddle.rand([10, 3, 32, 32])
                >>> with paddle.amp.auto_cast():
                ...     conv = model(data)
                ...     loss = paddle.mean(conv)
                >>> scaled = scaler.scale(loss)  # scale the loss
                >>> scaled.backward()            # do backward
                >>> scaler.unscale_(optimizer)    # unscale the gradients
                >>> scaler.step(optimizer)
                >>> scaler.update()
                >>> optimizer.clear_grad()
        """
        return super()._unscale(optimizer)

    def is_enable(self) -> bool:
        """
        Whether loss scaling is enabled.

        Returns:
            bool: True if loss scaling is enabled, otherwise False.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> enable = scaler.is_enable()
                >>> print(enable)
                True
        """
        return super().is_enable()

    def is_use_dynamic_loss_scaling(self) -> bool:
        """
        Whether to use dynamic loss scaling.

        Returns:
            bool: False if fixed loss scaling is used, True if the loss scaling is updated dynamically.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> use_dynamic_loss_scaling = scaler.is_use_dynamic_loss_scaling()
                >>> print(use_dynamic_loss_scaling)
                True
        """
        return super().is_use_dynamic_loss_scaling()

    def get_init_loss_scaling(self) -> float:
        """
        Return the initial loss scaling factor.

        Returns:
            float:  the initial loss scaling factor.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> init_loss_scaling = scaler.get_init_loss_scaling()
                >>> print(init_loss_scaling)
                1024
        """
        return super().get_init_loss_scaling()

    def set_init_loss_scaling(self, new_init_loss_scaling: float) -> None:
        """
        Set the initial loss scaling factor by `new_init_loss_scaling`.

        Args:
            new_init_loss_scaling(float):  The new_init_loss_scaling used to update initial loss scaling factor.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> print(scaler.get_init_loss_scaling())
                1024
                >>> new_init_loss_scaling = 1000
                >>> scaler.set_init_loss_scaling(new_init_loss_scaling)
                >>> print(scaler.get_init_loss_scaling())
                1000
        """
        super().set_init_loss_scaling(new_init_loss_scaling)

    def get_incr_ratio(self) -> float:
        """
        Return the multiplier to use when increasing the loss scaling.

        Returns:
            float:  the multiplier to use when increasing the loss scaling.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> incr_ratio = scaler.get_incr_ratio()
                >>> print(incr_ratio)
                2.0
        """
        return super().get_incr_ratio()

    def set_incr_ratio(self, new_incr_ratio: float) -> None:
        """
        Set the multiplier to use when increasing the loss scaling to `new_incr_ratio`; `new_incr_ratio` should be > 1.0.

        Args:
            new_incr_ratio(float):  The new_incr_ratio used to update the multiplier to use when increasing the loss scaling.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> print(scaler.get_incr_ratio())
                2.0
                >>> new_incr_ratio = 3.0
                >>> scaler.set_incr_ratio(new_incr_ratio)
                >>> print(scaler.get_incr_ratio())
                3.0
        """
        super().set_incr_ratio(new_incr_ratio)

    def get_decr_ratio(self) -> float:
        """
        Get the less-than-one-multiplier to use when decreasing the loss scaling.

        Returns:
            float:  the less-than-one-multiplier to use when decreasing the loss scaling.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> decr_ratio = scaler.get_decr_ratio()
                >>> print(decr_ratio)
                0.5
        """
        return super().get_decr_ratio()

    def set_decr_ratio(self, new_decr_ratio: float) -> None:
        """
        Set the less-than-one-multiplier to use when decreasing the loss scaling to `new_decr_ratio`; `new_decr_ratio` should be < 1.0.

        Args:
            new_decr_ratio(float):  The new_decr_ratio used to update the less-than-one-multiplier to use when decreasing the loss scaling.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> print(scaler.get_decr_ratio())
                0.5
                >>> new_decr_ratio = 0.1
                >>> scaler.set_decr_ratio(new_decr_ratio)
                >>> print(scaler.get_decr_ratio())
                0.1
        """
        super().set_decr_ratio(new_decr_ratio)

    def get_incr_every_n_steps(self) -> int:
        """
        Return the number `n`: the loss scaling is increased every `n` consecutive steps with finite gradients.

        Returns:
            int:  the number `n`; the loss scaling is increased every `n` consecutive steps with finite gradients.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> incr_every_n_steps = scaler.get_incr_every_n_steps()
                >>> print(incr_every_n_steps)
                1000
        """
        return super().get_incr_every_n_steps()

    def set_incr_every_n_steps(self, new_incr_every_n_steps: int) -> None:
        """
        Set the number `n` to `new_incr_every_n_steps`; the loss scaling is increased every `n` consecutive steps with finite gradients.

        Args:
            new_incr_every_n_steps(int):  The new value of `n`; the loss scaling is increased every `n` consecutive steps with finite gradients.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> print(scaler.get_incr_every_n_steps())
                1000
                >>> new_incr_every_n_steps = 2000
                >>> scaler.set_incr_every_n_steps(new_incr_every_n_steps)
                >>> print(scaler.get_incr_every_n_steps())
                2000
        """
        super().set_incr_every_n_steps(new_incr_every_n_steps)

    def get_decr_every_n_nan_or_inf(self) -> int:
        """
        Return the number `n`: the loss scaling is decreased every `n` accumulated steps with NaN or Inf gradients.

        Returns:
            int: the number `n`; the loss scaling is decreased every `n` accumulated steps with NaN or Inf gradients.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> decr_every_n_nan_or_inf = scaler.get_decr_every_n_nan_or_inf()
                >>> print(decr_every_n_nan_or_inf)
                2
        """
        return super().get_decr_every_n_nan_or_inf()

    def set_decr_every_n_nan_or_inf(
        self, new_decr_every_n_nan_or_inf: int
    ) -> None:
        """
        Set the number `n` to `new_decr_every_n_nan_or_inf`; the loss scaling is decreased every `n` accumulated steps with NaN or Inf gradients.

        Args:
            new_decr_every_n_nan_or_inf(int):  The new value of `n`; the loss scaling is decreased every `n` accumulated steps with NaN or Inf gradients.

        Examples:
            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle
                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> print(scaler.get_decr_every_n_nan_or_inf())
                2
                >>> new_decr_every_n_nan_or_inf = 3
                >>> scaler.set_decr_every_n_nan_or_inf(new_decr_every_n_nan_or_inf)
                >>> print(scaler.get_decr_every_n_nan_or_inf())
                3
        """
        super().set_decr_every_n_nan_or_inf(new_decr_every_n_nan_or_inf)

    def state_dict(self) -> _ScaleStateDict:
        """
        Returns the state of the scaler as a `dict`. If this instance is not enabled, an empty dict is returned.

        Returns:
            A dict of the scaler's state, including:
            scale (tensor): The loss scaling factor.
            incr_ratio(float): The multiplier to use when increasing the loss scaling.
            decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling.
            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
            incr_count(int): The number of recent consecutive unskipped steps.
            decr_count(int): The number of recent consecutive skipped steps.
            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.


        Examples:

            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle

                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> scaler_state = scaler.state_dict()
        """
        return super().state_dict()

    def load_state_dict(self, state_dict: _ScaleStateDict) -> None:
        """
        Loads the scaler state.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to `GradScaler.state_dict()`.

        Examples:

            .. code-block:: python

                >>> # doctest: +REQUIRES(env:GPU, env:XPU)
                >>> import paddle

                >>> scaler = paddle.amp.GradScaler(
                ...     enable=True,
                ...     init_loss_scaling=1024,
                ...     incr_ratio=2.0,
                ...     decr_ratio=0.5,
                ...     incr_every_n_steps=1000,
                ...     decr_every_n_nan_or_inf=2,
                ...     use_dynamic_loss_scaling=True
                ... )
                >>> scaler_state = scaler.state_dict()
                >>> scaler.load_state_dict(scaler_state)
        """
        super().load_state_dict(state_dict)
