# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import enum
import math
import typing
from typing import (
    TYPE_CHECKING,
    Any,
    overload,
)

import paddle
import paddle.nn.functional as F
from paddle.distribution import (
    constraint,
    distribution,
    transformed_distribution,
    variable,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from paddle import Tensor
    from paddle.distribution import Distribution, TransformedDistribution

# Public names exported by ``from ... import *`` for this module.
__all__ = [
    'Transform',
    'AbsTransform',
    'AffineTransform',
    'ChainTransform',
    'ExpTransform',
    'IndependentTransform',
    'PowerTransform',
    'ReshapeTransform',
    'SigmoidTransform',
    'SoftmaxTransform',
    'StackTransform',
    'StickBreakingTransform',
    'TanhTransform',
]


class Type(enum.Enum):
    """Kind of mapping implemented by a transformation.

    Each member's value is its lowercase name.
    """

    BIJECTION = 'bijection'  # one-to-one and onto
    INJECTION = 'injection'  # one-to-one only
    SURJECTION = 'surjection'  # onto only
    OTHER = 'other'  # neither one-to-one nor onto

    @classmethod
    def is_injective(cls, _type):
        """Tell whether ``_type`` denotes a one-to-one mapping.

        A bijection is injective by definition, so both ``BIJECTION`` and
        ``INJECTION`` give ``True``; any other value gives ``False``.
        """
        injective_kinds = (cls.BIJECTION, cls.INJECTION)
        return _type in injective_kinds


class Transform:
    r"""Base class for the transformations of random variables.

    ``Transform`` can be used to represent any differentiable and injective
    function from the subset of :math:`R^n` to subset of :math:`R^m`, generally
    used for transforming a random sample generated by ``Distribution``
    instance.

    Suppose :math:`X` is a K-dimensional random variable with probability
    density function :math:`p_X(x)`. A new random variable :math:`Y = f(X)` may
    be defined by transforming :math:`X` with a suitably well-behaved function
    :math:`f`. It suffices for what follows to note that if `f` is one-to-one and
    its inverse :math:`f^{-1}` has a well-defined Jacobian, then the density of
    :math:`Y` is

    .. math::

        p_Y(y) = p_X(f^{-1}(y)) |det J_{f^{-1}}(y)|

    where det is the matrix determinant operation and :math:`J_{f^{-1}}(y)` is
    the Jacobian matrix of :math:`f^{-1}` evaluated at :math:`y`.
    Taking :math:`x = f^{-1}(y)`, the Jacobian matrix is defined by

    .. math::

        J(y) = \begin{bmatrix}
        {\frac{\partial x_1}{\partial y_1}} &{\frac{\partial x_1}{\partial y_2}}
        &{\cdots} &{\frac{\partial x_1}{\partial y_K}} \\
        {\frac{\partial x_2}{\partial y_1}}  &{\frac{\partial x_2}
        {\partial y_2}}&{\cdots} &{\frac{\partial x_2}{\partial y_K}} \\
        {\vdots} &{\vdots} &{\ddots} &{\vdots}\\
        {\frac{\partial x_K}{\partial y_1}} &{\frac{\partial x_K}{\partial y_2}}
        &{\cdots} &{\frac{\partial x_K}{\partial y_K}}
        \end{bmatrix}

    A ``Transform`` can be characterized by three operations:

        #. forward
           Forward implements :math:`x \rightarrow f(x)`, and is used to convert
           one random outcome into another.
        #. inverse
           Undoes the transformation :math:`y \rightarrow f^{-1}(y)`.
        #. log_det_jacobian
           The log of the absolute value of the determinant of the Jacobian
           matrix (the matrix of all first-order partial derivatives).
           ``forward_log_det_jacobian`` evaluates it for :math:`f` at
           :math:`x`; ``inverse_log_det_jacobian`` evaluates it for
           :math:`f^{-1}` at :math:`y`.

    Subclasses typically implement the following methods:

        * _forward
        * _inverse
        * _forward_log_det_jacobian
        * _inverse_log_det_jacobian (optional)

    At least one of ``_forward_log_det_jacobian`` and
    ``_inverse_log_det_jacobian`` is required; the missing one is derived
    from the other.

    If the transform changes the shape of the input, subclasses must also
    implement:

        * _forward_shape
        * _inverse_shape

    """

    _type = Type.INJECTION

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def _is_injective(cls):
        """Is the transformation type one-to-one or not.

        Returns:
            bool: ``True`` denotes injective. ``False`` denotes non-injective.
        """
        return Type.is_injective(cls._type)

    @overload
    def __call__(self, input: Tensor) -> Tensor: ...

    @overload
    def __call__(self, input: Distribution) -> TransformedDistribution: ...

    @overload
    def __call__(self, input: Transform) -> ChainTransform: ...

    def __call__(self, input) -> Any:
        """Make this instance as a callable object. The return value is
        depending on the input type.

        * If the input is a ``Tensor`` instance, return
          ``self.forward(input)`` .
        * If the input is a ``Distribution`` instance, return
          ``TransformedDistribution(base=input, transforms=[self])`` .
        * If the input is a ``Transform`` instance, return
          ``ChainTransform([self, input])`` .

        Args:
            input (Tensor|Distribution|Transform): The input value.

        Returns:
            [Tensor|TransformedDistribution|ChainTransform]: The return value.
        """
        if isinstance(input, distribution.Distribution):
            return transformed_distribution.TransformedDistribution(
                input, [self]
            )
        if isinstance(input, Transform):
            return ChainTransform([self, input])
        return self.forward(input)

    def forward(self, x: Tensor) -> Tensor:
        """Forward transformation with mapping :math:`y = f(x)`.

        Useful for turning one random outcome into another.

        Args:
            x (Tensor): Input parameter, generally is a sample generated
                from ``Distribution``.

        Returns:
            Tensor: Outcome of forward transformation.

        Raises:
            TypeError: If ``x`` is not a Tensor.
            ValueError: If ``x`` has fewer dimensions than the event rank of
                the domain.
        """
        if not isinstance(
            x, (paddle.base.framework.Variable, paddle.pir.Value)
        ):
            raise TypeError(
                f"Expected 'x' is a Tensor or Real, but got {type(x)}."
            )
        if x.dim() < self._domain.event_rank:
            raise ValueError(
                f'The dimensions of x({x.dim()}) should be '
                f'greater than or equal to {self._domain.event_rank}'
            )
        return self._forward(x)

    def inverse(self, y: Tensor) -> Tensor:
        """Inverse transformation :math:`x = f^{-1}(y)`. It's useful for "reversing"
        a transformation to compute one probability in terms of another.

        Args:
            y (Tensor): Input parameter for inverse transformation.

        Returns:
            Tensor: Outcome of inverse transform.

        Raises:
            TypeError: If ``y`` is not a Tensor.
            ValueError: If ``y`` has fewer dimensions than the event rank of
                the codomain.
        """
        if not isinstance(
            y, (paddle.base.framework.Variable, paddle.pir.Value)
        ):
            raise TypeError(
                f"Expected 'y' is a Tensor or Real, but got {type(y)}."
            )
        if y.dim() < self._codomain.event_rank:
            raise ValueError(
                f'The dimensions of y({y.dim()}) should be '
                f'greater than or equal to {self._codomain.event_rank}'
            )
        return self._inverse(y)

    def forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        """Compute :math:`log|det J_f(x)|`, the log of the absolute value of
        the determinant of the Jacobian matrix of the forward function
        evaluated at ``x``.

        Args:
            x (Tensor): Input tensor, generally is a sample generated from
                ``Distribution``

        Returns:
            Tensor: The log of the absolute value of Jacobian determinant.

        Raises:
            TypeError: If ``x`` is not a Tensor.
            ValueError: If ``x`` has fewer dimensions than the event rank of
                the domain.
            NotImplementedError: If the transform is not injective.
        """
        if not isinstance(
            x, (paddle.base.framework.Variable, paddle.pir.Value)
        ):
            raise TypeError(
                f"Expected 'x' is a Tensor or Real, but got {type(x)}."
            )
        # ``x`` is known to be a tensor at this point, so no second
        # isinstance check is needed before reading its rank.
        if x.dim() < self._domain.event_rank:
            raise ValueError(
                f'The dimensions of x({x.dim()}) should be '
                f'greater than or equal to {self._domain.event_rank}'
            )
        if not self._is_injective():
            raise NotImplementedError(
                "forward_log_det_jacobian can't be implemented for non-injective "
                "transforms."
            )

        return self._call_forward_log_det_jacobian(x)

    def inverse_log_det_jacobian(self, y: Tensor) -> Tensor:
        """Compute :math:`log|det J_{f^{-1}}(y)|`.
        Note that ``forward_log_det_jacobian`` is the negative of this function,
        evaluated at :math:`f^{-1}(y)`.

        Args:
            y (Tensor): The input to the ``inverse`` Jacobian determinant
                evaluation.

        Returns:
            Tensor: The value of :math:`log|det J_{f^{-1}}(y)|`.

        Raises:
            TypeError: If ``y`` is not a Tensor.
            ValueError: If ``y`` has fewer dimensions than the event rank of
                the codomain.
        """
        if not isinstance(
            y, (paddle.base.framework.Variable, paddle.pir.Value)
        ):
            raise TypeError(f"Expected 'y' is a Tensor, but got {type(y)}.")
        if y.dim() < self._codomain.event_rank:
            raise ValueError(
                f'The dimensions of y({y.dim()}) should be '
                f'greater than or equal to {self._codomain.event_rank}'
            )
        return self._call_inverse_log_det_jacobian(y)

    def forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
        """Infer the shape of forward transformation.

        Args:
            shape (Sequence[int]): The input shape.

        Returns:
            Sequence[int]: The output shape.

        Raises:
            TypeError: If ``shape`` is not a Sequence.
        """
        if not isinstance(shape, typing.Sequence):
            raise TypeError(
                f"Expected shape is Sequence[int] type, but got {type(shape)}."
            )
        return self._forward_shape(shape)

    def inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
        """Infer the shape of inverse transformation.

        Args:
            shape (Sequence[int]): The input shape of inverse transformation.

        Returns:
            Sequence[int]: The output shape of inverse transformation.

        Raises:
            TypeError: If ``shape`` is not a Sequence.
        """
        if not isinstance(shape, typing.Sequence):
            raise TypeError(
                f"Expected shape is Sequence[int] type, but got {type(shape)}."
            )
        return self._inverse_shape(shape)

    @property
    def _domain(self) -> variable.Variable:
        """The domain of this transformation"""
        return variable.real

    @property
    def _codomain(self) -> variable.Variable:
        """The codomain of this transformation"""
        return variable.real

    def _forward(self, x: Tensor) -> Tensor:
        """Inner method for public API ``forward``, subclass should
        overwrite this method for supporting forward transformation.
        """
        raise NotImplementedError('Forward not implemented')

    def _inverse(self, y: Tensor) -> Tensor:
        """Inner method of public API ``inverse``, subclass should
        overwrite this method for supporting inverse transformation.
        """
        raise NotImplementedError('Inverse not implemented')

    def _call_forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        """Inner method called by ``forward_log_det_jacobian``.

        Uses ``_forward_log_det_jacobian`` when the subclass defines it;
        otherwise negates ``_inverse_log_det_jacobian`` evaluated at
        ``forward(x)``.
        """
        if hasattr(self, '_forward_log_det_jacobian'):
            return self._forward_log_det_jacobian(x)
        if hasattr(self, '_inverse_log_det_jacobian'):
            return -self._inverse_log_det_jacobian(self.forward(x))
        raise NotImplementedError(
            'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
            'is implemented. One of them is required.'
        )

    def _call_inverse_log_det_jacobian(self, y: Tensor) -> Tensor:
        """Inner method called by ``inverse_log_det_jacobian``.

        Uses ``_inverse_log_det_jacobian`` when the subclass defines it;
        otherwise negates ``_forward_log_det_jacobian`` evaluated at
        ``_inverse(y)``.
        """
        if hasattr(self, '_inverse_log_det_jacobian'):
            return self._inverse_log_det_jacobian(y)
        if hasattr(self, '_forward_log_det_jacobian'):
            return -self._forward_log_det_jacobian(self._inverse(y))
        raise NotImplementedError(
            'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
            'is implemented. One of them is required'
        )

    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
        """Inner method called by ``forward_shape``, which is used to infer the
        forward shape. Subclass should overwrite this method for supporting
        ``forward_shape``.
        """
        return shape

    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
        """Inner method called by ``inverse_shape``, which is used to infer the
        inverse shape. Subclass should overwrite this method for supporting
        ``inverse_shape``.
        """
        return shape


class AbsTransform(Transform):
    r"""Element-wise absolute-value transformation, :math:`y = f(x) = abs(x)`.

    The mapping folds ``(-inf, inf)`` onto ``[0, inf)`` and so is not
    injective. It allows scalar distributions to be transformed by the
    absolute value function.

    Several inputs share one output, so ``inverse`` returns a pair:

    * For ``y`` in ``(0, inf)``, ``AbsTransform.inverse(y)`` returns the set
      inverse ``{x in (-inf, inf) : |x| = y}`` as the tuple ``-y, y``.
    * For ``y`` equal to ``0``, ``AbsTransform.inverse(0)`` returns ``0, 0``.
      This is not the set inverse, which is the singleton ``{0}``. It does,
      however, work together with ``TransformedDistribution`` to produce a
      left semi-continuous pdf.
    * For ``y`` in ``(-inf, 0)``, ``AbsTransform.inverse(y)`` still returns
      ``-y, y``, which is wrong. No check is made, for efficiency.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> abs = paddle.distribution.AbsTransform()

            >>> print(abs.forward(paddle.to_tensor([-1., 0., 1.])))
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [1., 0., 1.])

            >>> print(abs.inverse(paddle.to_tensor([1.])))
            (Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [-1.]), Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [1.]))

            >>> # The |dX/dY| is constant 1. So Log|dX/dY| == 0
            >>> print(abs.inverse_log_det_jacobian(paddle.to_tensor(1.)))
            (Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                    0.), Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                    0.))

            >>> #Special case handling of 0.
            >>> print(abs.inverse(paddle.to_tensor([0.])))
            (Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [0.]), Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [0.]))
            >>> print(abs.inverse_log_det_jacobian(paddle.to_tensor(0.)))
            (Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                    0.), Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                    0.))

    """

    # Many-to-one mapping, so ``forward_log_det_jacobian`` is rejected by the
    # base class.
    _type = Type.SURJECTION

    @property
    def _domain(self) -> variable.Real:
        return variable.real

    @property
    def _codomain(self) -> variable.Positive:
        return variable.positive

    def _forward(self, x: Tensor) -> Tensor:
        return paddle.abs(x)

    def _inverse(self, y: Tensor) -> tuple[Tensor, Tensor]:
        # Both preimages of |x| = y, negative branch first.
        negative_branch = -y
        return negative_branch, y

    def _inverse_log_det_jacobian(self, y: Tensor) -> tuple[Tensor, Tensor]:
        # |dx/dy| == 1 on each branch, so each log-det is a 0-d zero tensor
        # in y's dtype. The same tensor object is returned for both branches.
        log_det = paddle.zeros(shape=[], dtype=y.dtype)
        return (log_det, log_det)


class AffineTransform(Transform):
    r"""Affine transformation :math:`y = \text{loc} + \text{scale} \times x`.

    ``loc`` and ``scale`` are broadcast against the input. The log-determinant
    of the Jacobian is :math:`\log|\text{scale}|`, independent of ``x``.

    Args:
        loc (Tensor): The location parameter.
        scale (Tensor): The scale parameter.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> x = paddle.to_tensor([1., 2.])
            >>> affine = paddle.distribution.AffineTransform(paddle.to_tensor(0.), paddle.to_tensor(1.))

            >>> print(affine.forward(x))
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [1., 2.])
            >>> print(affine.inverse(affine.forward(x)))
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [1., 2.])
            >>> print(affine.forward_log_det_jacobian(x))
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                    0.)
    """

    _type = Type.BIJECTION

    def __init__(self, loc: Tensor, scale: Tensor) -> None:
        tensor_types = (paddle.base.framework.Variable, paddle.pir.Value)
        if not isinstance(loc, tensor_types):
            raise TypeError(f"Expected 'loc' is a Tensor, but got {type(loc)}")
        if not isinstance(scale, tensor_types):
            raise TypeError(
                f"Expected scale is a Tensor, but got {type(scale)}"
            )
        self._loc, self._scale = loc, scale
        super().__init__()

    @property
    def loc(self) -> Tensor:
        return self._loc

    @property
    def scale(self) -> Tensor:
        return self._scale

    @property
    def _domain(self) -> variable.Real:
        return variable.real

    @property
    def _codomain(self) -> variable.Real:
        return variable.real

    def _forward(self, x: Tensor) -> Tensor:
        return self._loc + self._scale * x

    def _inverse(self, y: Tensor) -> Tensor:
        centered = y - self._loc
        return centered / self._scale

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        # dy/dx == scale element-wise, so the result does not depend on x.
        return self._scale.abs().log()

    def _broadcast_with_params(self, shape: Sequence[int]) -> Sequence[int]:
        """Broadcast ``shape`` with ``loc``'s shape, then with ``scale``'s."""
        with_loc = paddle.broadcast_shape(shape, self._loc.shape)
        return tuple(paddle.broadcast_shape(with_loc, self._scale.shape))

    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
        return self._broadcast_with_params(shape)

    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
        # Elementwise map: the inverse broadcasts exactly like the forward.
        return self._broadcast_with_params(shape)


class ChainTransform(Transform):
    r"""Composes multiple transforms in a chain.

    ``forward`` applies the transforms in sequence order, first element
    first. ``inverse`` and ``inverse_shape`` undo them in reverse order.

    Args:
        transforms (Sequence[Transform]): A sequence of transformations.

    Examples:

        .. code-block:: python

            >>> import paddle


            >>> x = paddle.to_tensor([0., 1., 2., 3.])

            >>> chain = paddle.distribution.ChainTransform((
            ...     paddle.distribution.AffineTransform(
            ...         paddle.to_tensor(0.), paddle.to_tensor(1.)),
            ...     paddle.distribution.ExpTransform()
            ... ))
            >>> print(chain.forward(x))
            Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [1.         , 2.71828175 , 7.38905621 , 20.08553696])
            >>> print(chain.inverse(chain.forward(x)))
            Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [0., 1., 2., 3.])
            >>> print(chain.forward_log_det_jacobian(x))
            Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [0., 1., 2., 3.])
            >>> print(chain.inverse_log_det_jacobian(chain.forward(x)))
            Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [ 0., -1., -2., -3.])
    """

    def __init__(self, transforms: Sequence[Transform]) -> None:
        if not isinstance(transforms, typing.Sequence):
            raise TypeError(
                f"Expected type of 'transforms' is Sequence, but got {type(transforms)}"
            )
        if not all(isinstance(t, Transform) for t in transforms):
            raise TypeError(
                "All elements of transforms should be Transform type."
            )

        self.transforms = transforms
        super().__init__()

    def _is_injective(self) -> bool:
        # A composition is injective only if every component is.
        return all(t._is_injective() for t in self.transforms)

    def _forward(self, x: Tensor) -> Tensor:
        for transform in self.transforms:
            x = transform.forward(x)
        return x

    def _inverse(self, y: Tensor) -> Tensor:
        for transform in reversed(self.transforms):
            y = transform.inverse(y)
        return y

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        """Sum the per-transform log-determinants along the chain.

        ``event_rank`` tracks the event rank of the running value. Each
        component's term is summed over the rightmost axes by which the
        running event rank exceeds that component's own domain event rank.
        This way all terms are expressed over the chain's batch shape before
        being added.
        """
        value = 0.0
        event_rank = self._domain.event_rank
        for t in self.transforms:
            value += self._sum_rightmost(
                t.forward_log_det_jacobian(x), event_rank - t._domain.event_rank
            )
            x = t.forward(x)
            event_rank += t._codomain.event_rank - t._domain.event_rank
        return value

    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
        for transform in self.transforms:
            shape = transform.forward_shape(shape)
        return shape

    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
        # The inverse of a composition undoes the last transform first, so
        # walk the chain backwards, mirroring ``_inverse``.
        for transform in reversed(self.transforms):
            shape = transform.inverse_shape(shape)
        return shape

    def _sum_rightmost(self, value: Tensor, n: int) -> Tensor:
        """sum value along rightmost n dim"""
        return value.sum(list(range(-n, 0))) if n > 0 else value

    @property
    def _domain(self) -> variable.Independent:
        domain = self.transforms[0]._domain

        # Compute the lower bound of input dimensions for chain transform.
        #
        # Suppose the dimensions of input tensor is N, and chain [t0,...ti,...tm],
        # ti(in) denotes ti.domain.event_rank, ti(out) denotes ti.codomain.event_rank,
        # delta(ti) denotes (ti(out) - ti(in)).
        # For transform ti, N should satisfy the constraint:
        #   N + delta(t0) + delta(t1)...delta(t(i-1)) >= ti(in)
        # So, for all transform in chain, N should satisfy follow constraints:
        #   t0: N >= t0(in)
        #   t1: N >= t1(in) - delta(t0)
        #   ...
        #   tm: N >= tm(in) - ... - delta(ti) - ... - delta(t0)
        #
        # Above problem can be solved more effectively use dynamic programming.
        # Let N(i) denotes lower bound of transform ti, than the state
        # transition equation is:
        #   N(i) = max{N(i+1)-delta(ti), ti(in)}
        event_rank = self.transforms[-1]._codomain.event_rank
        for t in reversed(self.transforms):
            event_rank -= t._codomain.event_rank - t._domain.event_rank
            event_rank = max(event_rank, t._domain.event_rank)

        return variable.Independent(domain, event_rank - domain.event_rank)

    @property
    def _codomain(self) -> variable.Independent:
        codomain = self.transforms[-1]._codomain

        # Forward counterpart of ``_domain``: propagate the event rank through
        # the chain, never dropping below each component's codomain rank.
        event_rank = self.transforms[0]._domain.event_rank
        for t in self.transforms:
            event_rank += t._codomain.event_rank - t._domain.event_rank
            event_rank = max(event_rank, t._codomain.event_rank)

        return variable.Independent(codomain, event_rank - codomain.event_rank)


class ExpTransform(Transform):
    r"""Element-wise exponential transformation :math:`y = \exp(x)`.

    Maps the real line onto the positive reals. The inverse is the natural
    logarithm, and :math:`\log|dy/dx| = x`.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> exp = paddle.distribution.ExpTransform()
            >>> print(exp.forward(paddle.to_tensor([1., 2., 3.])))
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [2.71828175 , 7.38905621 , 20.08553696])

            >>> print(exp.inverse(paddle.to_tensor([1., 2., 3.])))
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [0.        , 0.69314718, 1.09861231])

            >>> print(exp.forward_log_det_jacobian(paddle.to_tensor([1., 2., 3.])))
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [1., 2., 3.])

            >>> print(exp.inverse_log_det_jacobian(paddle.to_tensor([1., 2., 3.])))
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [ 0.        , -0.69314718, -1.09861231])
    """

    _type = Type.BIJECTION

    def __init__(self) -> None:
        super().__init__()

    @property
    def _domain(self) -> variable.Real:
        return variable.real

    @property
    def _codomain(self) -> variable.Positive:
        return variable.positive

    def _forward(self, x: Tensor) -> Tensor:
        return paddle.exp(x)

    def _inverse(self, y: Tensor) -> Tensor:
        return paddle.log(y)

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        # d/dx exp(x) = exp(x) > 0, hence log|dy/dx| = log(exp(x)) = x.
        return x


class IndependentTransform(Transform):
    r"""
    ``IndependentTransform`` wraps a base transformation and treats some of
    the rightmost batch axes as event axes.

    It is mainly used to enlarge the event axes. Forward and inverse
    transformations are delegated to the base unchanged. The only difference
    is that the log-determinant of the Jacobian is additionally summed over
    the ``reinterpreted_batch_rank`` rightmost dimensions.

    For example, apply ``ExpTransform`` to a Tensor with sample, batch and
    event ``(S,B,E)`` shape semantics, partitioned as
    ``(S=[4], B=[2, 2], E=[3])``, with ``reinterpreted_batch_rank`` equal to
    1. The reinterpreted partition is ``(S=[4], B=[2], E=[2, 3])``.
    ``forward`` and ``inverse`` still return shape ``[4,2,2,3]``, while
    ``inverse_log_det_jacobian`` returns shape ``[4,2]``, because the
    Jacobian determinant reduces over the event dimensions.

    Args:
        base (Transform): The base transformation.
        reinterpreted_batch_rank (int): The num of rightmost batch rank that
            will be reinterpreted as event rank.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> x = paddle.to_tensor([[1., 2., 3.], [4., 5., 6.]])

            >>> # Exponential transform with event_rank = 1
            >>> multi_exp = paddle.distribution.IndependentTransform(
            ...     paddle.distribution.ExpTransform(), 1)
            >>> print(multi_exp.forward(x))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[2.71828175  , 7.38905621  , 20.08553696 ],
                     [54.59814835 , 148.41316223, 403.42880249]])
            >>> print(multi_exp.forward_log_det_jacobian(x))
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [6. , 15.])
    """

    def __init__(self, base: Transform, reinterpreted_batch_rank: int) -> None:
        if not isinstance(base, Transform):
            raise TypeError(
                f"Expected 'base' is Transform type, but get {type(base)}"
            )
        if reinterpreted_batch_rank <= 0:
            raise ValueError(
                f"Expected 'reinterpreted_batch_rank' is grater than zero, but got {reinterpreted_batch_rank}"
            )

        self._base = base
        self._reinterpreted_batch_rank = reinterpreted_batch_rank
        super().__init__()

    def _is_injective(self) -> bool:
        # Reinterpreting axes does not alter the mapping itself.
        return self._base._is_injective()

    def _forward(self, x: Tensor) -> Tensor:
        has_too_few_dims = x.dim() < self._domain.event_rank
        if has_too_few_dims:
            raise ValueError("Input dimensions is less than event dimensions.")
        return self._base.forward(x)

    def _inverse(self, y: Tensor) -> Tensor:
        has_too_few_dims = y.dim() < self._codomain.event_rank
        if has_too_few_dims:
            raise ValueError("Input dimensions is less than event dimensions.")
        return self._base.inverse(y)

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        base_log_det = self._base.forward_log_det_jacobian(x)
        # Reduce over the batch axes that are now part of the event.
        rightmost_axes = list(range(-self._reinterpreted_batch_rank, 0))
        return base_log_det.sum(rightmost_axes)

    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
        return self._base.forward_shape(shape)

    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
        return self._base.inverse_shape(shape)

    @property
    def _domain(self) -> variable.Independent:
        base_domain = self._base._domain
        return variable.Independent(base_domain, self._reinterpreted_batch_rank)

    @property
    def _codomain(self) -> variable.Independent:
        base_codomain = self._base._codomain
        return variable.Independent(
            base_codomain, self._reinterpreted_batch_rank
        )


class PowerTransform(Transform):
    r"""
    Element-wise power transformation :math:`y = x^{\text{power}}`.

    ``power`` is broadcast against the input. The inverse raises to
    ``1 / power``, and the log-determinant of the Jacobian is
    :math:`\log|\text{power} \cdot x^{\text{power} - 1}|`.

    Args:
        power (Tensor): The power parameter.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> x = paddle.to_tensor([1., 2.])
            >>> power = paddle.distribution.PowerTransform(paddle.to_tensor(2.))

            >>> print(power.forward(x))
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [1., 4.])
            >>> print(power.inverse(power.forward(x)))
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [1., 2.])
            >>> print(power.forward_log_det_jacobian(x))
            Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [0.69314718, 1.38629436])
    """

    _type = Type.BIJECTION

    def __init__(self, power: Tensor) -> None:
        is_tensor = isinstance(
            power, (paddle.base.framework.Variable, paddle.pir.Value)
        )
        if not is_tensor:
            raise TypeError(
                f"Expected 'power' is a tensor, but got {type(power)}"
            )
        self._power = power
        super().__init__()

    @property
    def power(self) -> Tensor:
        return self._power

    # NOTE(review): the domain is declared real, but raising a negative x to
    # a non-integer power is not real-valued (and e.g. power=2 is not
    # one-to-one on the reals). Confirm whether this should be positive.
    @property
    def _domain(self) -> variable.Real:
        return variable.real

    @property
    def _codomain(self) -> variable.Positive:
        return variable.positive

    def _forward(self, x: Tensor) -> Tensor:
        return paddle.pow(x, self._power)

    def _inverse(self, y: Tensor) -> Tensor:
        exponent = 1 / self._power
        return paddle.pow(y, exponent)

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        # dy/dx = power * x^(power - 1)
        derivative = self._power * paddle.pow(x, self._power - 1)
        return paddle.log(paddle.abs(derivative))

    def _broadcast_with_power(self, shape: Sequence[int]) -> Sequence[int]:
        """Broadcast ``shape`` against the shape of ``power``."""
        return tuple(paddle.broadcast_shape(shape, self._power.shape))

    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
        return self._broadcast_with_power(shape)

    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
        return self._broadcast_with_power(shape)


class ReshapeTransform(Transform):
    r"""Reshape the event shape of a tensor.

    Note that ``in_event_shape`` and ``out_event_shape`` must have the same
    number of elements.

    Args:
        in_event_shape(Sequence[int]): The input event shape.
        out_event_shape(Sequence[int]): The output event shape.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> x = paddle.ones((1,2,3))
            >>> reshape_transform = paddle.distribution.ReshapeTransform((2, 3), (3, 2))
            >>> print(reshape_transform.forward_shape((1,2,3)))
            (1, 3, 2)
            >>> print(reshape_transform.forward(x))
            Tensor(shape=[1, 3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[[1., 1.],
                    [1., 1.],
                    [1., 1.]]])
            >>> print(reshape_transform.inverse(reshape_transform.forward(x)))
            Tensor(shape=[1, 2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[[1., 1., 1.],
                        [1., 1., 1.]]])
            >>> print(reshape_transform.forward_log_det_jacobian(x))
            Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
                [0.])
    """

    _type = Type.BIJECTION

    def __init__(
        self, in_event_shape: Sequence[int], out_event_shape: Sequence[int]
    ) -> None:
        """Validate and store the input/output event shapes.

        Raises:
            TypeError: If either shape is not a ``Sequence``.
            ValueError: If the two shapes describe different element counts.
        """
        if not isinstance(in_event_shape, typing.Sequence) or not isinstance(
            out_event_shape, typing.Sequence
        ):
            raise TypeError(
                f"Expected type of 'in_event_shape' and 'out_event_shape' is "
                f"Sequence[int], but got 'in_event_shape': {in_event_shape}, "
                f"'out_event_shape': {out_event_shape}"
            )
        # math.prod of an empty shape is 1, i.e. a scalar event.
        in_size = math.prod(in_event_shape)
        out_size = math.prod(out_event_shape)
        if in_size != out_size:
            raise ValueError(
                f"The numel of 'in_event_shape' should be 'out_event_shape', "
                f"but got {in_size}!={out_size}"
            )

        self._in_event_shape = tuple(in_event_shape)
        self._out_event_shape = tuple(out_event_shape)
        super().__init__()

    @property
    def in_event_shape(self) -> tuple[int, ...]:
        """The event shape consumed by ``forward``."""
        return self._in_event_shape

    @property
    def out_event_shape(self) -> tuple[int, ...]:
        """The event shape produced by ``forward``."""
        return self._out_event_shape

    @property
    def _domain(self) -> variable.Independent:
        return variable.Independent(variable.real, len(self._in_event_shape))

    @property
    def _codomain(self) -> variable.Independent:
        return variable.Independent(variable.real, len(self._out_event_shape))

    def _forward(self, x: Tensor) -> Tensor:
        # Keep the leading batch dims and replace the trailing event dims.
        return x.reshape(
            tuple(x.shape)[: x.dim() - len(self._in_event_shape)]
            + self._out_event_shape
        )

    def _inverse(self, y: Tensor) -> Tensor:
        return y.reshape(
            tuple(y.shape)[: y.dim() - len(self._out_event_shape)]
            + self._in_event_shape
        )

    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
        """Map ``batch_shape + in_event_shape`` to ``batch_shape + out_event_shape``.

        Raises:
            ValueError: If ``shape`` is shorter than, or does not end with,
                ``in_event_shape``.
        """
        n_event = len(self._in_event_shape)
        if len(shape) < n_event:
            raise ValueError(
                f"Expected length of 'shape' is not less than {n_event}, but got {len(shape)}"
            )
        # An explicit split index (rather than ``shape[-n_event:]``) keeps an
        # empty event shape working: ``-0`` would select the whole shape.
        split = len(shape) - n_event
        if tuple(shape[split:]) != tuple(self._in_event_shape):
            raise ValueError(
                f"Event shape mismatch, expected: {self._in_event_shape}, but got {shape[split:]}"
            )
        return tuple(shape[:split]) + self._out_event_shape

    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
        """Map ``batch_shape + out_event_shape`` to ``batch_shape + in_event_shape``.

        Raises:
            ValueError: If ``shape`` is shorter than, or does not end with,
                ``out_event_shape``.
        """
        n_event = len(self._out_event_shape)
        if len(shape) < n_event:
            raise ValueError(
                f"Expected 'shape' length is not less than {n_event}, but got {len(shape)}"
            )
        # See ``_forward_shape`` for why an explicit split index is used.
        split = len(shape) - n_event
        if tuple(shape[split:]) != tuple(self._out_event_shape):
            raise ValueError(
                f"Event shape mismatch, expected: {self._out_event_shape}, but got {shape[split:]}"
            )
        return tuple(shape[:split]) + self._in_event_shape

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        # Reshaping only permutes memory layout, so the log-determinant is
        # zero for every batch element.
        shape = x.shape[: x.dim() - len(self._in_event_shape)]
        return paddle.zeros(shape, dtype=x.dtype)


class SigmoidTransform(Transform):
    r"""Sigmoid transformation with mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> x = paddle.ones((2,3))
            >>> t = paddle.distribution.SigmoidTransform()
            >>> print(t.forward(x))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[0.73105860, 0.73105860, 0.73105860],
                     [0.73105860, 0.73105860, 0.73105860]])
            >>> print(t.inverse(t.forward(x)))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[1.00000012, 1.00000012, 1.00000012],
                     [1.00000012, 1.00000012, 1.00000012]])
            >>> print(t.forward_log_det_jacobian(x))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[-1.62652326, -1.62652326, -1.62652326],
                     [-1.62652326, -1.62652326, -1.62652326]])
    """

    # Sigmoid maps the real line one-to-one onto (0, 1), so it is a
    # bijection between ``_domain`` and ``_codomain`` (same as TanhTransform).
    _type = Type.BIJECTION

    @property
    def _domain(self) -> variable.Real:
        return variable.real

    @property
    def _codomain(self) -> variable.Variable:
        return variable.Variable(False, 0, constraint.Range(0.0, 1.0))

    def _forward(self, x: Tensor) -> Tensor:
        return F.sigmoid(x)

    def _inverse(self, y: Tensor) -> Tensor:
        # logit(y) = log(y) - log(1 - y); log1p keeps precision for small y.
        return y.log() - (-y).log1p()

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        # log(sigmoid'(x)) = log(sigmoid(x)) + log(1 - sigmoid(x))
        #                  = -softplus(-x) - softplus(x), which is stable
        #                  for large |x|.
        return -F.softplus(-x) - F.softplus(x)


class SoftmaxTransform(Transform):
    r"""Softmax transformation with mapping :math:`y=\exp(x)` then normalizing.

    It's generally used to convert unconstrained space to simplex. This mapping
    is not injective, so ``forward_log_det_jacobian`` and
    ``inverse_log_det_jacobian`` are not implemented.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> x = paddle.ones((2,3))
            >>> t = paddle.distribution.SoftmaxTransform()
            >>> print(t.forward(x))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[0.33333334, 0.33333334, 0.33333334],
                     [0.33333334, 0.33333334, 0.33333334]])
            >>> print(t.inverse(t.forward(x)))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[-1.09861231, -1.09861231, -1.09861231],
                     [-1.09861231, -1.09861231, -1.09861231]])
    """

    _type = Type.OTHER

    @property
    def _domain(self) -> variable.Independent:
        return variable.Independent(variable.real, 1)

    @property
    def _codomain(self) -> variable.Variable:
        return variable.Variable(False, 1, constraint.simplex)

    def _forward(self, x: Tensor) -> Tensor:
        # Subtract each row's maximum before exponentiating for numerical
        # stability. Paddle's ``Tensor.max`` returns the values tensor itself
        # (not a ``(values, indices)`` tuple), so it must not be indexed with
        # ``[0]``: doing so would subtract only the first row's maximum from
        # every row and let other rows overflow.
        x = (x - x.max(-1, keepdim=True)).exp()
        return x / x.sum(-1, keepdim=True)

    def _inverse(self, y: Tensor) -> Tensor:
        # Softmax is only defined up to an additive constant per row; the
        # log is one valid preimage.
        return y.log()

    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
        if len(shape) < 1:
            raise ValueError(
                f"Expected length of shape is not less than 1, but got {len(shape)}"
            )
        return shape

    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
        if len(shape) < 1:
            raise ValueError(
                f"Expected length of shape is not less than 1, but got {len(shape)}"
            )
        return shape


class StackTransform(Transform):
    r"""``StackTransform`` applies a sequence of transformations along the
    specific axis.

    Args:
        transforms (Sequence[Transform]): The sequence of transformations.
        axis (int, optional): The axis along which will be transformed. default
            value is 0.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> x = paddle.stack(
            ...     (paddle.to_tensor([1., 2., 3.]), paddle.to_tensor([1, 2., 3.])), 1)
            >>> t = paddle.distribution.StackTransform(
            ...     (paddle.distribution.ExpTransform(),
            ...      paddle.distribution.PowerTransform(paddle.to_tensor(2.))),
            ...     1
            ... )
            >>> print(t.forward(x))
            Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[2.71828175 , 1.         ],
                     [7.38905621 , 4.         ],
                     [20.08553696, 9.         ]])

            >>> print(t.inverse(t.forward(x)))
            Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[1., 1.],
                     [2., 2.],
                     [3., 3.]])

            >>> print(t.forward_log_det_jacobian(x))
            Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[1.        , 0.69314718],
                     [2.        , 1.38629436],
                     [3.        , 1.79175949]])
    """

    def __init__(self, transforms: Sequence[Transform], axis: int = 0):
        """Validate and store the per-slice transforms and the stacking axis.

        Raises:
            TypeError: If ``transforms`` is empty or not a sequence of
                ``Transform`` instances, or if ``axis`` is not an int.
        """
        if not transforms or not isinstance(transforms, typing.Sequence):
            raise TypeError(
                f"Expected 'transforms' is Sequence[Transform], but got {type(transforms)}."
            )
        if not all(isinstance(t, Transform) for t in transforms):
            raise TypeError(
                'Expected all element in transforms is Transform Type.'
            )
        if not isinstance(axis, int):
            raise TypeError(f"Expected 'axis' is int, but got {type(axis)}.")

        self._transforms = transforms
        self._axis = axis
        super().__init__()

    def _is_injective(self) -> bool:
        # Injective only if every component transform is.
        return all(t._is_injective() for t in self._transforms)

    @property
    def transforms(self) -> Sequence[Transform]:
        return self._transforms

    @property
    def axis(self) -> int:
        return self._axis

    def _apply_per_slice(self, v: Tensor, apply) -> Tensor:
        """Unstack ``v`` along ``axis``, apply ``apply(transform, slice)`` to
        each slice with its matching transform, and restack the results."""
        self._check_size(v)
        return paddle.stack(
            [
                apply(t, s)
                for s, t in zip(paddle.unstack(v, self._axis), self._transforms)
            ],
            self._axis,
        )

    def _forward(self, x: Tensor) -> Tensor:
        return self._apply_per_slice(x, lambda t, s: t.forward(s))

    def _inverse(self, y: Tensor) -> Tensor:
        return self._apply_per_slice(y, lambda t, s: t.inverse(s))

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        return self._apply_per_slice(
            x, lambda t, s: t.forward_log_det_jacobian(s)
        )

    def _check_size(self, v: Tensor) -> None:
        """Ensure ``axis`` is valid for ``v`` and that ``v`` has exactly one
        slice per transform along it.

        Raises:
            ValueError: If either condition does not hold.
        """
        if not (-v.dim() <= self._axis < v.dim()):
            raise ValueError(
                f'Input dimensions {v.dim()} should be greater than stack '
                f'transform axis {self._axis}.'
            )
        if v.shape[self._axis] != len(self._transforms):
            raise ValueError(
                f'Input size along {self._axis} should be equal to the '
                f'length of transforms.'
            )

    @property
    def _domain(self) -> variable.Stack:
        return variable.Stack([t._domain for t in self._transforms], self._axis)

    @property
    def _codomain(self) -> variable.Stack:
        return variable.Stack(
            [t._codomain for t in self._transforms], self._axis
        )


class StickBreakingTransform(Transform):
    r"""Convert an unconstrained vector to the simplex with one additional
    dimension by the stick-breaking construction.

    Examples:

        .. code-block:: python

            >>> import paddle


            >>> x = paddle.to_tensor([1.,2.,3.])
            >>> t = paddle.distribution.StickBreakingTransform()
            >>> print(t.forward(x))
            Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [0.47536686, 0.41287899, 0.10645414, 0.00530004])
            >>> print(t.inverse(t.forward(x)))
            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [0.99999988, 2.        , 2.99999881])
            >>> print(t.forward_log_det_jacobian(x))
            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                    -9.10835075)
    """

    _type = Type.BIJECTION

    @staticmethod
    def _offset(size: int, dtype) -> Tensor:
        """Return the 1-D tensor ``[size, size - 1, ..., 1]`` with ``dtype``.

        The dtype follows the input so float64 (or other float) inputs do
        not meet a float32 offset tensor.
        """
        return size + 1 - paddle.ones([size], dtype=dtype).cumsum(-1)

    def _forward(self, x: Tensor) -> Tensor:
        # z_k = sigmoid(x_k - log(K - k)) is the fraction of the remaining
        # stick broken off at step k; z_cumprod is the stick left afterwards.
        offset = self._offset(x.shape[-1], x.dtype)
        z = F.sigmoid(x - offset.log())
        z_cumprod = (1 - z).cumprod(-1)
        # Pad only the last axis: append 1 to z and prepend 1 to z_cumprod so
        # y_k = z_k * prod_{j<k}(1 - z_j), with the last entry the leftover.
        batch_pad = [0] * 2 * (len(x.shape) - 1)
        return F.pad(z, batch_pad + [0, 1], value=1) * F.pad(
            z_cumprod, batch_pad + [1, 0], value=1
        )

    def _inverse(self, y: Tensor) -> Tensor:
        y_crop = y[..., :-1]
        offset = self._offset(y_crop.shape[-1], y.dtype)
        # sf_k = 1 - sum_{j<=k} y_j is the stick remaining after step k, so
        # y_k / sf_k = z_k / (1 - z_k) recovers the logit of z_k.
        sf = 1 - y_crop.cumsum(-1)
        x = y_crop.log() - sf.log() + offset.log()
        return x

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        y = self.forward(x)
        offset = self._offset(x.shape[-1], x.dtype)
        x = x - offset.log()
        return (-x + F.log_sigmoid(x) + y[..., :-1].log()).sum(-1)

    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
        if not shape:
            raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
        return (*shape[:-1], shape[-1] + 1)

    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
        if not shape:
            raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
        return (*shape[:-1], shape[-1] - 1)

    @property
    def _domain(self) -> variable.Independent:
        return variable.Independent(variable.real, 1)

    @property
    def _codomain(self) -> variable.Variable:
        return variable.Variable(False, 1, constraint.simplex)


class TanhTransform(Transform):
    r"""Tanh transformation with mapping :math:`y = \tanh(x)`.

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> tanh = paddle.distribution.TanhTransform()

            >>> x = paddle.to_tensor([[1., 2., 3.], [4., 5., 6.]])

            >>> # doctest: +SKIP('random sample')
            >>> print(tanh.forward(x))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                [[0.76159418, 0.96402758, 0.99505472],
                    [0.99932921, 0.99990916, 0.99998784]])
            >>> print(tanh.inverse(tanh.forward(x)))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                [[1.        , 2.        , 2.99999666],
                    [3.99993253, 4.99977016, 6.00527668]])
            >>> print(tanh.forward_log_det_jacobian(x))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[-0.86756170 , -2.65000558 , -4.61865711 ],
                     [-6.61437654 , -8.61379623 , -10.61371803]])
            >>> print(tanh.inverse_log_det_jacobian(tanh.forward(x)))
            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
                    [[0.86756176 , 2.65000558 , 4.61866283 ],
                     [6.61441946 , 8.61399269 , 10.61451530]])
            >>> # doctest: -SKIP
    """

    _type = Type.BIJECTION

    @property
    def _domain(self) -> variable.Real:
        """The whole real line."""
        return variable.real

    @property
    def _codomain(self) -> variable.Variable:
        """The open interval (-1, 1)."""
        interval = constraint.Range(-1.0, 1.0)
        return variable.Variable(False, 0, interval)

    def _forward(self, x: Tensor) -> Tensor:
        """Element-wise hyperbolic tangent."""
        return paddle.tanh(x)

    def _inverse(self, y: Tensor) -> Tensor:
        """Element-wise inverse hyperbolic tangent."""
        return paddle.atanh(y)

    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
        """Return ``log(1 - tanh(x)**2)`` in the stable form
        ``2 * (log 2 - x - softplus(-2x))``.

        ``_inverse_log_det_jacobian`` is deliberately not implemented: the
        inverse is derived from this method, because evaluating
        ``-log1p(-y**2)`` directly loses precision as ``|y|`` approaches 1.

        See details: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
        """
        log_two = math.log(2.0)
        return 2.0 * (log_two - x - F.softplus(-2.0 * x))
