# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# The following codes are from https://github.com/facebookresearch/xformers

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

import paddle

if TYPE_CHECKING:
    from collections.abc import Sequence


class AttentionBias(ABC):
    """Base class for a bias added to the attention scores before softmax.

    Subclasses implement ``materialize``, which returns the bias as a dense
    ``paddle.Tensor`` of the requested shape and dtype.
    """

    @abstractmethod
    def materialize(self, shape, dtype=paddle.float32):
        raise NotImplementedError


class LowerTriangularMask(AttentionBias):
    """Causal mask: entries above the main diagonal are ``-inf``, the rest 0."""

    def materialize(self, shape, dtype=paddle.float32):
        # Build the mask in float32 when bfloat16 is requested, then cast.
        create_as = dtype if dtype is not paddle.bfloat16 else paddle.float32
        tensor = paddle.full(
            shape=shape, fill_value=float("-inf"), dtype=create_as
        )
        return paddle.triu(tensor, diagonal=1).astype(dtype)

    def add_bias(self, bias):
        # Combine this causal mask with an extra, caller-provided tensor bias.
        return LowerTriangularMaskWithTensorBias(bias)


class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
    """Causal mask combined with an arbitrary additive tensor bias."""

    def __init__(self, bias):
        self._bias = bias

    def materialize(self, shape, dtype=paddle.float32):
        return super().materialize(shape, dtype) + self._bias
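
# Illustrative usage (a sketch, not part of the original xformers port): a
# 4x4 causal mask is 0 on and below the diagonal and -inf above it, so adding
# it to the attention logits before softmax blocks attention to future
# positions. `some_bias` below is a hypothetical, caller-provided tensor.
#
#     mask = LowerTriangularMask().materialize((4, 4))
#     biased = LowerTriangularMask().add_bias(some_bias).materialize((4, 4))
#     # `biased` equals `mask + some_bias`.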


@dataclass
class SeqLenInfo:
    """Layout of variable-length sequences packed along one batch row.

    ``seqstart_py`` holds the cumulative start offset of each sequence (plus a
    trailing total length), ``seqstart`` is the same offsets as an int32
    tensor, and ``max_seqlen`` is the length of the longest sequence.
    """

    seqstart: paddle.Tensor
    max_seqlen: int
    seqstart_py: list[int]

    def intervals(self):
        # Yield the (start, end) offsets of each sequence in the packed tensor.
        yield from zip(self.seqstart_py, self.seqstart_py[1:])

    @classmethod
    def from_seqlens(cls, seqlens):
        seqstart_py = [0]
        max_seqlen = -1
        for seqlen in seqlens:
            max_seqlen = max(max_seqlen, seqlen)
            seqstart_py.append(seqstart_py[-1] + seqlen)
        seqstart = paddle.to_tensor(seqstart_py, dtype=paddle.int32)
        return cls(
            max_seqlen=max_seqlen, seqstart=seqstart, seqstart_py=seqstart_py
        )

    def split(self, x, batch_sizes=None):
        # `x` must be the packed tensor of shape [1, total_len, ...].
        assert self.seqstart_py[-1] == x.shape[1] and x.shape[0] == 1
        if batch_sizes is None:
            batch_sizes = [1] * (len(self.seqstart_py) - 1)
        split_chunks = []
        it = 0
        for batch_size in batch_sizes:
            split_chunks.append(
                self.seqstart_py[it + batch_size] - self.seqstart_py[it]
            )
            it += batch_size
        return [
            tensor.reshape([bs, -1, *tensor.shape[2:]])
            for bs, tensor in zip(batch_sizes, x.split(split_chunks, axis=1))
        ]
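
# Illustrative usage (a sketch, not from the original source): for
# seqlens=[2, 3, 1] the cumulative starts are [0, 2, 5, 6], max_seqlen is 3,
# and intervals() yields where each sequence lives inside the packed
# [1, total_len, ...] tensor.
#
#     info = SeqLenInfo.from_seqlens([2, 3, 1])
#     assert info.seqstart_py == [0, 2, 5, 6]
#     assert list(info.intervals()) == [(0, 2), (2, 5), (5, 6)]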


@dataclass
class PaddedSeqLenInfo(SeqLenInfo):
    """Sequence layout where every sequence occupies a fixed padded stride.

    ``seqstart`` marks the start of each padded block, while ``seqlen`` /
    ``seqlen_py`` store the number of valid (unpadded) tokens per block.
    """

    seqlen: paddle.Tensor
    seqlen_py: Sequence[int]

    def intervals(self):
        # Yield (start, end) offsets covering only the valid part of each block.
        for (start, _), length in zip(super().intervals(), self.seqlen_py):
            yield start, start + length

    @classmethod
    def from_seqlens(cls, seqlens):
        raise NotImplementedError(
            "Please use SeqLenInfo.from_seq_lens() or PaddedSeqLenInfo.from_seq_lens_padded()."
        )

    @classmethod
    def from_seqlens_padded(cls, seqlens, padding):
        assert all(seqlen <= padding for seqlen in seqlens)
        seqstart_py = list(range(0, len(seqlens) * padding + 1, padding))
        return cls(
            seqlen=paddle.to_tensor(seqlens, dtype=paddle.int32),
            seqlen_py=seqlens,
            max_seqlen=max(seqlens),
            seqstart=paddle.to_tensor(seqstart_py, dtype=paddle.int32),
            seqstart_py=seqstart_py,
        )

    def split(self, x, batch_sizes=None):
        raise NotImplementedError
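
# Illustrative usage (a sketch, not from the original source): with
# seqlens=[2, 3] and padding=4 the key blocks start at multiples of the
# padding, but intervals() only spans the valid (unpadded) part of each block.
#
#     info = PaddedSeqLenInfo.from_seqlens_padded([2, 3], padding=4)
#     assert info.seqstart_py == [0, 4, 8]
#     assert list(info.intervals()) == [(0, 2), (4, 7)]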


@dataclass
class BlockDiagonalMask(AttentionBias):
    """Block-diagonal attention bias for packed variable-length sequences.

    Attention is allowed only inside each sequence's own diagonal block;
    every position outside the blocks is masked with ``-inf``.
    """

    q_seqinfo: SeqLenInfo
    k_seqinfo: SeqLenInfo
    _batch_sizes: Sequence[int] | None = None

    def _create_block_mask(self, shape, dtype=paddle.float32):
        # Inside a block everything is allowed (zero bias).
        return paddle.zeros(shape=shape, dtype=dtype)

    def materialize(self, shape, dtype=paddle.float32):
        assert shape[-1] == self.k_seqinfo.seqstart_py[-1]
        assert shape[-2] == self.q_seqinfo.seqstart_py[-1]
        # Start fully masked, then open up each per-sequence block below.
        mask = paddle.full(shape[-2:], fill_value=float("-inf"), dtype=dtype)
        for (q_start, q_end), (k_start, k_end) in zip(
            self.q_seqinfo.intervals(), self.k_seqinfo.intervals()
        ):
            sub_shape = [q_end - q_start, k_end - k_start]
            mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
                sub_shape, dtype
            )
        # Add leading singleton dims and broadcast to the requested shape.
        for _ in range(len(shape) - 2):
            mask = mask.unsqueeze(0)
        return mask.expand(shape)

    @classmethod
    def from_seqlens(cls, q_seqlen, kv_seqlen=None):
        assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
        q_seqinfo = SeqLenInfo.from_seqlens(q_seqlen)
        if kv_seqlen is None or q_seqlen == kv_seqlen:
            k_seqinfo = q_seqinfo
        else:
            k_seqinfo = SeqLenInfo.from_seqlens(kv_seqlen)
        return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)

    @classmethod
    def from_tensor_list(cls, tensors):
        # Pack a list of [batch, seqlen, ...] tensors into a single
        # [1, total_len, ...] tensor and build the matching block-diagonal bias.
        batch_sizes = [tensor.shape[0] for tensor in tensors]
        seqlens = []
        for x in tensors:
            for _ in range(x.shape[0]):
                seqlens.append(x.shape[1])
        block_diag = cls.from_seqlens(seqlens)
        block_diag._batch_sizes = batch_sizes
        concated_tensor = paddle.concat(
            [x.reshape([1, -1, *x.shape[2:]]) for x in tensors], axis=1
        )
        return block_diag, concated_tensor

    @classmethod
    def from_tensor_lists_qkv(cls, tensors_q, tensors_k, tensors_v=None):
        assert len(tensors_q) == len(tensors_k)
        assert tensors_v is None or len(tensors_v) == len(tensors_q)
        batch_sizes = [tensor.shape[0] for tensor in tensors_q]
        q_seqlens, kv_seqlens = [], []
        for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
            assert q.shape[0] == k.shape[0]
            q_seqlens.extend([q.shape[1]] * q.shape[0])
            kv_seqlens.extend([k.shape[1]] * k.shape[0])
            assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
        block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
        block_diag._batch_sizes = batch_sizes
        return (
            block_diag,
            paddle.concat(
                [x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], axis=1
            ),
            paddle.concat(
                [x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], axis=1
            ),
            (
                paddle.concat(
                    [x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v],
                    axis=1,
                )
                if tensors_v is not None
                else None
            ),
        )

    def split_queries(self, tensor):
        return self.q_seqinfo.split(tensor, self._batch_sizes)

    def split_kv(self, tensor):
        return self.k_seqinfo.split(tensor, self._batch_sizes)

    def split(self, tensor):
        assert self.q_seqinfo is self.k_seqinfo
        return self.q_seqinfo.split(tensor, self._batch_sizes)

    def make_causal(self):
        return BlockDiagonalCausalMask(
            q_seqinfo=self.q_seqinfo,
            k_seqinfo=self.k_seqinfo,
            _batch_sizes=self._batch_sizes,
        )
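
# Illustrative usage (a sketch, not from the original source): two sequences
# of lengths 2 and 3 are packed into one row of total length 5; the
# materialized [5, 5] bias is 0 inside the two diagonal blocks and -inf
# everywhere else, so tokens only attend within their own sequence.
#
#     attn_bias = BlockDiagonalMask.from_seqlens([2, 3])
#     mask = attn_bias.materialize((5, 5))
#     # mask[0:2, 0:2] == 0 and mask[2:5, 2:5] == 0; all other entries are -inf.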


@dataclass
class BlockDiagonalCausalMask(BlockDiagonalMask):
    """Block-diagonal bias whose blocks are additionally causal (lower-triangular)."""

    def _create_block_mask(self, shape, dtype=paddle.float32):
        return LowerTriangularMask().materialize(shape=shape, dtype=dtype)
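
# Illustrative usage (a sketch): make_causal() on a BlockDiagonalMask returns
# this class, so each diagonal block becomes lower-triangular and every token
# attends only to earlier positions of its own sequence.
#
#     causal_bias = BlockDiagonalMask.from_seqlens([2, 3]).make_causal()
#     causal_mask = causal_bias.materialize((5, 5))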


@dataclass
class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
    """Causal block-diagonal bias with keys/values padded per sequence.

    Queries are tightly packed (``q_seqinfo``) while each key block occupies a
    fixed padded stride (``k_seqinfo``); ``causal_diagonal`` optionally shifts
    the causal diagonal of each block.
    """

    q_seqinfo: SeqLenInfo
    k_seqinfo: PaddedSeqLenInfo
    causal_diagonal: paddle.Tensor | None = None

    def _create_block_mask(self, shape, offset=0, dtype=paddle.float32):
        # Per-block causal mask whose diagonal is shifted by `offset`.
        create_as = dtype if dtype is not paddle.bfloat16 else paddle.float32
        tensor = paddle.full(shape, dtype=create_as, fill_value=float("-inf"))
        return paddle.triu(tensor, diagonal=1 + offset).astype(dtype)

    def materialize(self, shape, dtype=paddle.float32):
        assert shape[-1] == self.k_seqinfo.seqstart_py[-1]
        assert shape[-2] == self.q_seqinfo.seqstart_py[-1]
        # Start fully masked; each query block gets a (shifted) causal mask
        # over only the valid keys of its padded block.
        mask = paddle.full(shape[-2:], fill_value=float("-inf"), dtype=dtype)
        for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
            zip(self.q_seqinfo.intervals(), self.k_seqinfo.intervals())
        ):
            mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
                (q_end - q_start, k_end - k_start),
                offset=(
                    0
                    if self.causal_diagonal is None
                    else int(self.causal_diagonal[i].item())
                ),
                dtype=dtype,
            )
        for _ in range(len(shape) - 2):
            mask = mask.unsqueeze(0)
        return mask.expand(shape)

    @classmethod
    def from_seqlens(
        cls, q_seqlen, kv_padding, kv_seqlen, causal_diagonal=None
    ):
        assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
        q_seqinfo = SeqLenInfo.from_seqlens(q_seqlen)
        k_seqinfo = PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
        return cls(
            q_seqinfo=q_seqinfo,
            k_seqinfo=k_seqinfo,
            causal_diagonal=causal_diagonal,
        )
