#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
from collections.abc import Iterable, Iterator, Sequence, Sized

import numpy as np

from .dataset import IterableDataset
from .sampler import RandomSampler, Sampler, SequenceSampler


class BatchSampler(Sampler[Sequence[int]]):
    """
    A base implementation of the batch sampler used by `paddle.io.DataLoader`,
    which iterably yields mini-batch indices (a list whose length is the
    mini-batch size and which holds sample indices).

    A batch sampler used by :code:`paddle.io.DataLoader` should be a subclass
    of :code:`paddle.io.BatchSampler`. BatchSampler subclasses should
    implement the following methods:

    :code:`__iter__`: return mini-batch indices iterably.

    :code:`__len__`: get the number of mini-batches in an epoch.


    Args:
        dataset(Dataset, optional): this should be an instance of a subclass of :ref:`api_paddle_io_Dataset`
                or any other python object which implements :code:`__len__`, so that
                BatchSampler can generate indices in the range of :attr:`dataset` length.
                :ref:`api_paddle_io_IterableDataset` is not accepted. Default None, disabled.
        sampler (Sampler, Iterable, optional): this should be a :ref:`api_paddle_io_Sampler` or Iterable
                instance which implements :code:`__iter__` to generate
                sample indices. :attr:`sampler` and :attr:`dataset`
                can not be set at the same time. If :attr:`sampler`
                is set, :attr:`dataset` should not be set. Calling :code:`len()`
                on the batch sampler additionally requires :attr:`sampler` to
                support :code:`len()`. Default None, disabled.
        shuffle(bool, optional): whether to shuffle the index order before generating
                batch indices. Only allowed when :attr:`dataset` is set; it must be
                False when :attr:`sampler` is set. Default False, don't shuffle.
        batch_size(int, optional): number of sample indices in a mini-batch. Default 1, each mini-batch includes 1 sample.
        drop_last(bool, optional): whether to drop the last incomplete batch (fewer than :attr:`batch_size` indices). Default False, keep it.

    See also :ref:`api_paddle_io_DataLoader`.

    Returns:
        BatchSampler: an iterable object for indices iterating

    Examples:

        .. code-block:: python

            >>> import numpy as np
            >>> from paddle.io import RandomSampler, BatchSampler, Dataset

            >>> np.random.seed(2023)
            >>> # init with dataset
            >>> class RandomDataset(Dataset):  # type: ignore[type-arg]
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __getitem__(self, idx):
            ...         image = np.random.random([784]).astype('float32')
            ...         label = np.random.randint(0, 9, (1, )).astype('int64')
            ...         return image, label
            ...
            ...     def __len__(self):
            ...         return self.num_samples
            ...
            >>> bs = BatchSampler(dataset=RandomDataset(100),
            ...                     shuffle=False,
            ...                     batch_size=16,
            ...                     drop_last=False)
            ...
            >>> for batch_indices in bs:
            ...     print(batch_indices)
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
            ...
            [96, 97, 98, 99]
            >>> # init with sampler
            >>> sampler = RandomSampler(RandomDataset(100))
            >>> bs = BatchSampler(sampler=sampler,
            ...                     batch_size=8,
            ...                     drop_last=True)
            ...
            >>> for batch_indices in bs:
            ...     print(batch_indices)
            [56, 12, 68, 0, 82, 66, 91, 44]
            ...
            [53, 17, 22, 86, 52, 3, 92, 33]
    """

    sampler: Sampler[int] | Iterable[int]
    batch_size: int
    shuffle: bool
    drop_last: bool

    def __init__(
        self,
        dataset: Sized | None = None,
        sampler: Sampler | Iterable[int] | None = None,
        shuffle: bool = False,
        batch_size: int = 1,
        drop_last: bool = False,
    ) -> None:
        if dataset is not None:
            # Derive the index source from the dataset length; the dataset
            # must be map-style (sized), never an IterableDataset.
            assert not isinstance(dataset, IterableDataset), (
                "dataset should not be a paddle.io.IterableDataset"
            )
            assert sampler is None, "should not set both dataset and sampler"
            assert isinstance(shuffle, bool), (
                f"shuffle should be a boolean value, but got {type(shuffle)}"
            )
            index_sampler_cls = RandomSampler if shuffle else SequenceSampler
            self.sampler = index_sampler_cls(dataset)
        else:
            # A caller-supplied sampler decides the ordering itself, so
            # shuffling on top of it is rejected.
            assert sampler is not None, (
                "either dataset or sampler should be set"
            )
            assert isinstance(sampler, (Sampler, Iterable)), (
                f"sampler should be either paddle.io.Sampler or Iterable, but got {type(sampler)}"
            )
            assert not shuffle, "shuffle should be False when sampler is set"
            self.sampler = sampler

        assert isinstance(batch_size, int) and batch_size > 0, (
            f"batch_size should be a positive integer, but got {batch_size}"
        )
        self.batch_size = batch_size  # per_device_batch_size or mini_batch_size
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), (
            f"drop_last should be a boolean value, but got {type(drop_last)}"
        )
        self.drop_last = drop_last

        # Each yielded batch holds batch_size * _acc_steps indices.
        # TODO(dev): consider to make it as public argument, acc_steps is only
        # used in auto-parallel
        self._acc_steps = 1

    def __iter__(self) -> Iterator[list[int]]:
        """Yield consecutive lists of ``batch_size * _acc_steps`` indices."""
        chunk_size = self.batch_size * self._acc_steps
        pending = []
        for sample_idx in self.sampler:
            pending.append(sample_idx)
            if len(pending) == chunk_size:
                yield pending
                pending = []
        # Whatever is left is a short final batch; emit it unless dropping.
        if pending and not self.drop_last:
            yield pending

    def __len__(self) -> int:
        """Return the number of batches ``__iter__`` yields per epoch."""
        chunk_size = self.batch_size * self._acc_steps
        total = len(self.sampler)
        if self.drop_last:
            return total // chunk_size
        # Ceiling division: a trailing partial batch still counts.
        return (total + chunk_size - 1) // chunk_size


class _InfiniteIterableSampler(Sampler[Sequence[None]]):
    """
    Endless batch sampler for a :ref:`api_paddle_io_IterableDataset`.

    An iterable dataset has no indices to hand out, so every item this
    sampler yields is a list of ``batch_size`` ``None`` placeholders, and it
    never stops. Presumably the consumer only uses the list length to decide
    how many samples to pull from the dataset iterator (confirm at the call
    site).

    Args:
        dataset(IterableDataset): the dataset being batched; kept as an attribute.
        batch_size(int, optional): number of placeholders per yielded list. Default 1.
    """

    dataset: IterableDataset
    batch_size: int

    def __init__(self, dataset: IterableDataset, batch_size: int = 1) -> None:
        assert isinstance(dataset, IterableDataset), (
            "dataset should be an instance of paddle.io.IterableDataset"
        )
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[list[None]]:
        # batch_size is re-read on every step, so later changes take effect.
        while True:
            placeholders = [None for _ in range(self.batch_size)]
            yield placeholders


class DistributedBatchSampler(BatchSampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Each process can pass a DistributedBatchSampler instance as a DataLoader
    sampler and load a subset of the original dataset that is exclusive to it.

    If the dataset length is not divisible by :attr:`num_replicas`, indices
    from the start of the index range are repeated as padding, so that every
    replica receives the same number of samples.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset(Dataset): this could be an instance of a subclass of :ref:`api_paddle_io_Dataset`
                     or any other python object which implements
                     `__len__` for BatchSampler to get indices of samples.
        batch_size(int): sample size of each mini-batch.
        num_replicas(int, optional): process number in distributed training.
            If :attr:`num_replicas` is None, :attr:`num_replicas` will be
            retrieved from :ref:`api_paddle_distributed_ParallelEnv` .
            Default None.
        rank(int, optional): the rank of the current process among :attr:`num_replicas`
            processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
            :ref:`api_paddle_distributed_ParallelEnv`. Default None.
        shuffle(bool, optional): whether to shuffle the index order before generating
            batch indices. Default False.
        drop_last(bool, optional): whether to drop this replica's last incomplete batch
            (fewer than :attr:`batch_size` indices). Default False.

    Returns:
        DistributedBatchSampler, return an iterable object for indices iterating.

    Examples:
        .. code-block:: python

            >>> import numpy as np

            >>> from paddle.io import Dataset, DistributedBatchSampler

            >>> # init with dataset
            >>> class RandomDataset(Dataset):  # type: ignore[type-arg]
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __getitem__(self, idx):
            ...         image = np.random.random([784]).astype('float32')
            ...         label = np.random.randint(0, 9, (1, )).astype('int64')
            ...         return image, label
            ...
            ...     def __len__(self):
            ...         return self.num_samples
            ...
            >>> dataset = RandomDataset(100)
            >>> sampler = DistributedBatchSampler(dataset, batch_size=64)

            >>> for data in sampler:
            ...     # do something
            ...     break
    """

    dataset: Sized
    batch_size: int
    drop_last: bool
    nranks: int
    epoch: int
    local_rank: int
    num_samples: int  # indices assigned to each replica per epoch
    total_size: int  # num_samples * nranks, i.e. dataset length plus padding

    def __init__(
        self,
        dataset: Sized,
        batch_size: int,
        num_replicas: int | None = None,
        rank: int | None = None,
        shuffle: bool = False,
        drop_last: bool = False,
    ) -> None:
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, (
            "batch_size should be a positive integer"
        )
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), (
            "drop_last should be a boolean value"
        )

        from paddle.distributed import ParallelEnv

        if num_replicas is not None:
            assert isinstance(num_replicas, int) and num_replicas > 0, (
                "num_replicas should be a positive integer"
            )
            self.nranks = num_replicas
        else:
            self.nranks = ParallelEnv().nranks

        if rank is not None:
            assert isinstance(rank, int) and rank >= 0, (
                "rank should be a non-negative integer"
            )
            self.local_rank = rank
        else:
            self.local_rank = ParallelEnv().local_rank

        self.drop_last = drop_last
        self.epoch = 0
        # Integer ceiling division: exact for any dataset length, unlike a
        # round trip through float division.
        self.num_samples = -(-len(self.dataset) // self.nranks)
        self.total_size = self.num_samples * self.nranks

        # Each yielded batch holds batch_size * _acc_steps indices.
        # TODO(dev): consider to make it as public argument, acc_steps is only
        # used in auto-parallel
        self._acc_steps = 1

    def _subsample_for_rank(self, indices: list[int]) -> list[int]:
        """
        Select this rank's share of the padded (and possibly shuffled) indices.

        Full global batches (``batch_size * nranks`` indices each) are dealt
        out round-robin: rank ``r`` takes the ``r``-th run of ``batch_size``
        indices from every global batch. The leftover tail, which is shorter
        than one global batch, is cut into ``nranks`` equal contiguous pieces,
        and rank ``r`` takes piece ``r``.

        Args:
            indices(list[int]): all ``total_size`` indices for this epoch.

        Returns:
            list[int]: the ``num_samples`` indices owned by ``local_rank``.
        """
        global_batch_size = self.batch_size * self.nranks
        tail_size = self.total_size % global_batch_size
        # total_size is a multiple of nranks, so the tail splits evenly.
        assert tail_size % self.nranks == 0
        local_tail_size = tail_size // self.nranks
        head_end = len(indices) - tail_size

        subsampled = []
        for start in range(
            self.local_rank * self.batch_size, head_end, global_batch_size
        ):
            subsampled.extend(indices[start : start + self.batch_size])

        tail = indices[head_end:]
        tail_start = self.local_rank * local_tail_size
        subsampled.extend(tail[tail_start : tail_start + local_tail_size])
        return subsampled

    def __iter__(self) -> Iterator[list[int]]:
        """
        Yield this replica's batches of ``batch_size * _acc_steps`` indices.

        When :attr:`shuffle` is True, the permutation is seeded with
        :attr:`epoch`, and :attr:`epoch` is then incremented by one.

        Raises:
            AssertionError: if the dataset grew after construction (the
                sampler assumes a constant dataset size).
        """
        local_batch_size = self.batch_size * self._acc_steps
        dataset_size = len(self.dataset)
        assert dataset_size <= self.total_size, (
            "dataset size changed after DistributedBatchSampler was created"
        )

        # Wrap around to the start of the index range to pad up to
        # total_size. An empty dataset has total_size == 0, so the modulo
        # is never evaluated in that case.
        indices = [i % dataset_size for i in range(self.total_size)]

        if self.shuffle:
            np.random.RandomState(self.epoch).shuffle(indices)
            self.epoch += 1

        if self.nranks > 1:
            indices = self._subsample_for_rank(indices)

        assert len(indices) == self.num_samples

        batch_indices = []
        for idx in indices:
            batch_indices.append(idx)
            if len(batch_indices) == local_batch_size:
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self) -> int:
        """Return the number of batches this replica yields per epoch."""
        local_batch_size = self.batch_size * self._acc_steps
        num_samples = self.num_samples
        # Round up when the trailing partial batch is kept.
        num_samples += int(not self.drop_last) * (local_batch_size - 1)
        return num_samples // local_batch_size

    def set_epoch(self, epoch: int) -> None:
        """
        Sets the epoch number. When :attr:`shuffle=True`, this number is used
        as the seed of the shuffle permutation. Setting it is optional: every
        shuffled pass also increments the epoch by one, so the ordering still
        changes from epoch to epoch. Replicas that start from the same epoch
        and iterate the same number of times draw the same permutation.
        If the same number is set at each epoch, this sampler will yield the
        same ordering at all epochs.

        Arguments:
            epoch (int): Epoch number.

        Examples:
            .. code-block:: python

                >>> import numpy as np

                >>> from paddle.io import Dataset, DistributedBatchSampler

                >>> # init with dataset
                >>> class RandomDataset(Dataset):  # type: ignore[type-arg]
                ...     def __init__(self, num_samples):
                ...         self.num_samples = num_samples
                ...
                ...     def __getitem__(self, idx):
                ...         image = np.random.random([784]).astype('float32')
                ...         label = np.random.randint(0, 9, (1, )).astype('int64')
                ...         return image, label
                ...
                ...     def __len__(self):
                ...         return self.num_samples
                ...
                >>> dataset = RandomDataset(100)
                >>> sampler = DistributedBatchSampler(dataset, batch_size=64)

                >>> for epoch in range(10):
                ...     sampler.set_epoch(epoch)
        """
        self.epoch = epoch
