# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import logging
import multiprocessing
import sys
import time
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    AnyStr,
    Callable,
    Protocol,
    TypeVar,
    overload,
)

import paddle

from ..base.framework import (
    _current_expected_place,
    _get_paddle_place,
    _get_paddle_place_list,
)
from ..framework import core, in_dynamic_mode
from .dataloader import BatchSampler, IterableDataset, Subset
from .dataloader.batch_sampler import (
    DistributedBatchSampler,
    _InfiniteIterableSampler,
)
from .dataloader.dataloader_iter import (
    _DataLoaderIterMultiProcess,
    _DataLoaderIterSingleProcess,
    _DatasetKind,
)

if TYPE_CHECKING:
    import numbers
    from collections.abc import Mapping, Sequence

    import numpy.typing as npt

    from paddle import Tensor
    from paddle._typing import PlaceLike
    from paddle._typing.device_like import _Place
    from paddle.io.dataloader.dataloader_iter import _DataLoaderIterBase
    from paddle.io.dataloader.dataset import Dataset

    _K = TypeVar('_K')
    _V = TypeVar('_V')

    class _CollateFn(Protocol):
        @overload
        def __call__(
            self, batch: Sequence[npt.NDArray[Any]] | Sequence[numbers.Number]
        ) -> npt.NDArray[Any]: ...

        @overload
        def __call__(self, batch: Sequence[Tensor]) -> Tensor: ...

        @overload
        def __call__(self, batch: Sequence[AnyStr]) -> AnyStr: ...

        @overload
        def __call__(
            self, batch: Sequence[Mapping[_K, _V]]
        ) -> Mapping[_K, _V]: ...

        @overload
        def __call__(self, batch: Sequence[Sequence[_V]]) -> Sequence[_V]: ...


# NOTE: [ avoid hanging & failed quickly ]
# These value is used in getting data from another process
QUEUE_GET_TIMEOUT = 60

USE_PINNED_MEMORY = None
# AutoTune Flags
USE_AUTOTUNE = False
TUNING_STEPS = 500


def set_autotune_config(use_autotune, tuning_steps=500):
    global USE_AUTOTUNE
    USE_AUTOTUNE = use_autotune
    global TUNING_STEPS
    TUNING_STEPS = tuning_steps


def use_pinned_memory(*args):
    global USE_PINNED_MEMORY
    if len(args) == 0:
        return USE_PINNED_MEMORY
    else:
        assert len(args) == 1 and isinstance(args[0], bool)
        USE_PINNED_MEMORY = args[0]


def _convert_places(places):
    if not isinstance(places, (list, tuple)):
        places = [places]

    ret = []
    for p in places:
        if not isinstance(p, core.Place):
            tmp = core.Place()
            tmp.set_place(p)
            p = tmp

        ret.append(p)
    return ret


class AuToTune:
    def __init__(self, loader):
        self.loader = loader
        self.max_num_worker = multiprocessing.cpu_count() / 2

    def __call__(self):
        # use default loader
        if (not USE_AUTOTUNE) or (not self.need_autotune()):
            return self.loader.num_workers

        # get autotune loader
        auto_tune_loader = self.get_autotune_loader()
        if auto_tune_loader is None:
            return self.loader.num_workers

        # pick the best num_workers
        auto_tune_start = time.time()
        logging.debug("========= DataLoader Auto Tune =========")
        logging.debug(
            "User config for DataLoader: " + str(self.loader.num_workers)
        )
        best_num_workers = 0
        min_cost = float("inf")
        logging.debug(
            "Tuning Range for num_workers: 0 ~ " + str(self.max_num_worker)
        )
        num_workers = 0
        while num_workers < self.max_num_worker:
            auto_tune_loader.num_workers = num_workers
            avg_cost = self.evaluate_reader_cost(auto_tune_loader)
            if min_cost * 0.75 > avg_cost:
                min_cost = avg_cost
                best_num_workers = num_workers
            else:
                update_num = self.is_best(
                    auto_tune_loader,
                    best_num_workers,
                    min_cost,
                    self.max_num_worker,
                )
                if update_num == best_num_workers:
                    break
                else:
                    best_num_workers = update_num
            logging.debug(
                "num_workers: "
                + str(num_workers)
                + " avg_cost: "
                + str(avg_cost)
            )
            num_workers += 2
        logging.info(
            "auto_tune dataLoader best_num_workers: " + str(best_num_workers)
        )
        logging.debug(
            "AutoTuning Cost for DataLoader: "
            + str(time.time() - auto_tune_start)
            + ' seconds'
        )

        # tune the default loader's num_workers
        return best_num_workers

    def need_autotune(self):
        if sys.platform == 'darwin' or sys.platform == 'win32':
            return False
        else:
            return True

    def get_sub_dataset(self, dataset, batch_size):
        num_samples = min(batch_size * TUNING_STEPS, len(dataset))
        sub_dataset = Subset(dataset, indices=list(range(num_samples)))
        return sub_dataset

    def get_autotune_loader(self):
        loader = copy.copy(self.loader)
        batch_size = self.loader.batch_sampler.batch_size
        if isinstance(
            self.loader.batch_sampler, paddle.io.DistributedBatchSampler
        ):
            dataset = self.loader.batch_sampler.dataset
            sub_dataset = self.get_sub_dataset(dataset, batch_size)
            loader.batch_sampler = paddle.io.DistributedBatchSampler(
                dataset=sub_dataset,
                batch_size=batch_size,
                num_replicas=self.loader.batch_sampler.nranks,
                rank=self.loader.batch_sampler.local_rank,
                shuffle=self.loader.batch_sampler.shuffle,
                drop_last=self.loader.batch_sampler.drop_last,
            )
        elif isinstance(self.loader.batch_sampler, paddle.io.BatchSampler):
            dataset = self.loader.batch_sampler.sampler.data_source
            sub_dataset = self.get_sub_dataset(dataset, batch_size)
            loader.batch_sampler = paddle.io.BatchSampler(
                dataset=sub_dataset,
                batch_size=batch_size,
                drop_last=self.loader.batch_sampler.drop_last,
            )
        else:
            loader = None
        return loader

    def evaluate_reader_cost(self, reader):
        costs = []
        avg_cost = 0
        start = time.time()
        for i, data in enumerate(reader):
            costs.append(time.time() - start)
            start = time.time()
        if len(costs) > 2:
            avg_cost = sum(costs[2:]) / len(costs[2:])
        else:
            avg_cost = sum(costs[0:]) / len(costs[0:])
        return avg_cost

    def is_best(self, reader, best_workers, best_time, num_work_boundary):
        step = 0
        num_workers = best_workers + 1
        boundary = 1
        while num_workers < num_work_boundary and step < 5:
            self.loader.num_workers = num_workers
            time = self.evaluate_reader_cost(reader)
            logging.debug(
                "for back num_workers: "
                + str(num_workers)
                + " avg_cost: "
                + str(time)
            )
            step += 1
            if time < best_time * 0.70 * boundary:
                return num_workers
            else:
                num_workers += 1
            boundary *= 0.80
        return best_workers


class DataLoader:
    """
    DataLoader provides an iterator which iterates given dataset
    once by the batch_sampler.

    DataLoader supports single-process and multi-process data loading,
    multi-process workers will be used to load data asynchronously if
    :attr:`num_workers` is set as a positive number.

    DataLoader supports map-style dataset and iterable-style dataset.

    For map-style dataset(can get a sample from dataset with a given
    index), please see :code:`paddle.io.Dataset`.

    For iterable-style dataset(get samples from dataset iteratively,
    like a Python iterator), please see :code:`paddle.io.IterableDataset`.

    For :code:`batch_sampler` please see :code:`paddle.io.BatchSampler`

    Notes:
        GPU tensor operation is not supported in subprocess currently,
        please don't use GPU tensor operations in pipeline which will
        be performed in subprocess, such as dataset transforms, collate_fn,
        etc. Numpy array and CPU tensor operation is supported.

    **Disable automatic batching**

    In certain cases such as some NLP tasks, instead of automatic batching,
    handling batching manually in dataset is needed by users. For these
    cases, automatic batching is disabled if both :attr:`batch_size` and
    :attr:`batch_sampler` is set as None, each data got from :attr:`dataset`
    should be batched data and will be processed with function define by
    :attr:`collate_fn` or :attr:`default_collate_fn`.


    Notes:
        When automatic batching is disabled, :attr:`default_collate_fn` will
        do nothing to data from dataset.


    Args:
        dataset(Dataset): the dataset to load data from, should be an
            instance of subclass of :code:`paddle.io.Dataset` or
            :code:`paddle.io.IterableDataset`.
        feed_list (list(Tensor)|tuple(Tensor)|None, optional): feed Tensor list.
            The Tensors should be created by :code:`paddle.static.data()`.
            :attr:`feed_list` must be set if :attr:`return_list` is
            False. Default None.
        places(list(Place)|tuple(Place)|list(str)|None, optional): a list of Place,
            to put data onto, :attr:`places` can be None, if
            :attr:`places` is None, default place(CPUPlace or CUDAPlace(0))
            will be used. Default None. If ``places`` is list of string,
            the string in the list can be ``cpu``, ``gpu:x`` and ``gpu_pinned``,
            where ``x`` is the index of the GPUs.
        return_list (bool, optional): whether the return value on each device is
            presented as a list. If :attr:`return_list=False`, the return
            value on each device would be a dict of str -> Tensor, where
            the key of the dict is the name of each fed Tensors. If
            :attr:`return_list=True`, the return value on each device would
            be a list(Tensor). :attr:`return_list` can only be True
            in dynamic graph mode. Default True.
        batch_sampler(BatchSampler|None, optional): an instance of `paddle.io.BatchSampler`
            to generate batch indices to draw samples from :attr:`dataset`
            and combine a batch. Default None.
        batch_size(int|None, optional): sample number in a mini-batch, a substitution
            parameter for :attr:`batch_sampler`, if :attr:`batch_sampler`
            is not set, a default `paddle.io.BatchSampler` will be used
            and initialize by :attr:`batch_size`, :attr:`shuffle` and
            :attr:`drop_last`. Default 1.
        shuffle(bool, optional): whether to shuffle indices order before generate
            batch indices, a substitution parameter for :attr:`batch_sampler`
            see :attr:`batch_size`. Default False.
        drop_last(bool, optional): whether drop the last incomplete batch dataset size
            is not divisible by the batch size, a substitution parameter
            for :attr:`batch_sampler`, see :attr:`batch_size`. Default False
        collate_fn(Callable|None, optional): function to generate mini-batch data by merging
            the sample list, None for only stack each fields of sample in axis
            0(same as :attr::`np.stack(..., axis=0)`). Default None
        num_workers(int, optional): the number of subprocess to load data, 0 for no
            subprocess used and loading data in main process. Default 0
        use_buffer_reader (bool, optional): whether to use buffered reader.
            If use_buffer_reader=True, the DataLoader would prefetch
            batch data asynchronously, so it would speed up data feeding
            and occupies a little more CPU or GPU memory, i.e., the memory
            of one batch input data. Default True.
        prefetch_factor (int, optional): Number of batch data the DataLoader would prefetch
            if use_buffer_reader=True. Default 2.
        use_shared_memory (bool, optional): whether to use shared memory to speed up
            putting data into inter-process queue, set :attr:`use_shared_memory`
            as True only when the shared memory space on your machine(e.g.
            space of '/dev/shm' on Linux operating system) is large enough.
            Shared memory will only be enabled in multi-process mode(num_workers
            > 0). Default True.
        timeout(int, optional): the timeout value for getting data form output queue
            of subprocesses. Default 0.
        worker_init_fn(Callable|None, optional): init function which will be called with
            worker id on each subprocess starting if not set as None. Default
            None.
        persistent_workers(bool, optional): whether to keep the workers in the DataLoader. Default False.

    Returns:
        DataLoader: an iterable object for data iterating, each element of the generated data is a Tensor.

    Examples:

        .. code-block:: python

            >>> # doctest: +SOLO('can not use multiprocessing testing `paddle.io.DataLoader`')
            >>> import numpy as np

            >>> import paddle
            >>> import paddle.nn as nn
            >>> import paddle.nn.functional as F
            >>> from paddle.io import Dataset, BatchSampler, DataLoader

            >>> BATCH_NUM = 20
            >>> BATCH_SIZE = 16
            >>> EPOCH_NUM = 4

            >>> IMAGE_SIZE = 784
            >>> CLASS_NUM = 10

            >>> # define a random dataset
            >>> class RandomDataset(Dataset):  # type: ignore[type-arg]
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __getitem__(self, idx):
            ...         image = np.random.random([IMAGE_SIZE]).astype('float32')
            ...         label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
            ...         return image, label
            ...
            ...     def __len__(self):
            ...         return self.num_samples
            ...
            >>> dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)

            >>> class SimpleNet(nn.Layer):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.fc = nn.Linear(IMAGE_SIZE, CLASS_NUM)
            ...
            ...     def forward(self, image, label=None):
            ...         return self.fc(image)
            ...
            >>> simple_net = SimpleNet()
            >>> opt = paddle.optimizer.SGD(learning_rate=1e-3,
            ...                             parameters=simple_net.parameters())
            ...
            >>> loader = DataLoader(dataset,
            ...                     batch_size=BATCH_SIZE,
            ...                     shuffle=True,
            ...                     drop_last=True,
            ...                     num_workers=2)
            ...
            >>> for e in range(EPOCH_NUM):
            ...     for i, (image, label) in enumerate(loader()):
            ...         out = simple_net(image)
            ...         loss = F.cross_entropy(out, label)
            ...         avg_loss = paddle.mean(loss)
            ...         avg_loss.backward()
            ...         opt.minimize(avg_loss)
            ...         simple_net.clear_gradients()
            ...         print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))

    Notes:
        For reading iterable dataset with multiprocess Dataloader,
        please see :code:`paddle.io.IterableDataset`
    """

    return_list: bool
    collate_fn: _CollateFn | None
    use_buffer_reader: bool
    prefetch_factor: int
    worker_init_fn: Callable[[int], None] | None
    dataset: Dataset[Any]
    feed_list: Sequence[Tensor] | None
    places: list[_Place]
    num_workers: int
    dataset_kind: _DatasetKind
    use_shared_memory: bool
    timeout: int
    batch_sampler: BatchSampler | _InfiniteIterableSampler | None
    drop_last: bool
    auto_collate_batch: bool

    def __init__(
        self,
        dataset: Dataset[Any],
        feed_list: Sequence[Tensor] | None = None,
        places: PlaceLike | Sequence[PlaceLike] | None = None,
        return_list: bool = True,
        batch_sampler: BatchSampler | None = None,
        batch_size: int = 1,
        shuffle: bool = False,
        drop_last: bool = False,
        collate_fn: _CollateFn | None = None,
        num_workers: int = 0,
        use_buffer_reader: bool = True,
        prefetch_factor: int = 2,
        use_shared_memory: bool = True,
        timeout: int = 0,
        worker_init_fn: Callable[[int], None] | None = None,
        persistent_workers: bool = False,
    ) -> None:
        self.return_list = return_list
        self.collate_fn = collate_fn
        self.use_buffer_reader = use_buffer_reader
        self.prefetch_factor = prefetch_factor
        self.worker_init_fn = worker_init_fn

        self.dataset = dataset

        if not return_list and not in_dynamic_mode():
            assert feed_list is not None, (
                "feed_list should be set when return_list=False"
            )
        self.feed_list = feed_list

        if places is None:
            places = _current_expected_place()
        if isinstance(places, (list, tuple)):
            places = _get_paddle_place_list(places)
        else:
            places = _get_paddle_place(places)
        self.places = _convert_places(places)

        assert num_workers >= 0, "num_workers should be a non-negative value"
        if num_workers > 0 and (
            sys.platform == 'darwin' or sys.platform == 'win32'
        ):
            warnings.warn(
                "DataLoader with multi-process mode is not supported on MacOs and Windows currently."
                " Please use single-process mode with num_workers = 0 instead"
            )
            num_workers = 0
        self.num_workers = num_workers

        assert prefetch_factor > 0, "prefetch_factor should be a positive value"

        self.use_shared_memory = use_shared_memory
        if use_shared_memory and num_workers == 0:
            self.use_shared_memory = False

        assert timeout >= 0, "timeout should be a non-negative value"
        self.timeout = timeout

        if isinstance(dataset, IterableDataset):
            self.dataset_kind = _DatasetKind.ITER
            if shuffle:
                raise ValueError(
                    f"IterableDataset not support shuffle, but got shuffle={shuffle}"
                )
            if batch_sampler is not None:
                raise ValueError(
                    "IterableDataset expect unspecified batch_sampler"
                )
        else:
            self.dataset_kind = _DatasetKind.MAP

        if batch_sampler is not None:
            assert batch_size == 1 and not shuffle and not drop_last, (
                "batch_size/shuffle/drop_last should not be set when "
                "batch_sampler is given"
            )
            self.batch_sampler = batch_sampler
            self.batch_size = None
        elif batch_size is None:
            self.batch_sampler = None
            self.batch_size = None
        else:
            assert batch_size > 0, (
                "batch_size should be None or a positive value when "
                "batch_sampler is not given"
            )
            self.batch_size = batch_size
            if isinstance(dataset, IterableDataset):
                self.batch_sampler = _InfiniteIterableSampler(
                    dataset, batch_size
                )
            else:
                self.batch_sampler = BatchSampler(
                    dataset=dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    drop_last=drop_last,
                )

        # Note(luchang): In auto DP mode, we use a distributed batch sampler to
        # ensure that each DP rank receives different data.
        if paddle.distributed.auto_parallel.auto_dp_utils.in_auto_dp_mode():
            mesh = paddle.distributed.fleet.auto.get_mesh()
            if mesh is None:
                word_size = paddle.distributed.get_world_size()
                mesh = paddle.distributed.ProcessMesh(
                    list(range(0, word_size)), dim_names=["dp"]
                )

            if "dp" not in mesh.dim_names:
                raise ValueError(
                    "Auto-DP mode requires the mesh to include a 'dp' dimension."
                )

            dp_rank = mesh.get_rank_by_dim_and_process_id(
                "dp", paddle.distributed.get_rank()
            )
            dp_world_size = mesh.get_dim_size("dp")

            self.batch_size = int(self.batch_sampler.batch_size / dp_world_size)
            if isinstance(self.batch_sampler, _InfiniteIterableSampler):
                shuffle = False
                drop_last = False
            else:
                shuffle = self.batch_sampler.shuffle
                drop_last = self.batch_sampler.drop_last
            self.batch_sampler = DistributedBatchSampler(
                dataset=dataset,
                batch_size=self.batch_size,
                num_replicas=dp_world_size,
                rank=dp_rank,
                shuffle=shuffle,
                drop_last=drop_last,
            )

        self.drop_last = drop_last
        self.auto_collate_batch = self.batch_sampler is not None

        self.pin_memory = False
        if in_dynamic_mode():
            self.pin_memory = (
                True if use_pinned_memory() is None else use_pinned_memory()
            )

        self._persistent_workers = persistent_workers
        self._iterator = None
        self.num_workers = AuToTune(self).__call__()

    def __len__(self) -> int:
        if self.dataset_kind == _DatasetKind.ITER:
            raise ValueError("length of IterableDataset not supported")
        else:
            if self.auto_collate_batch:
                return len(self.batch_sampler)
            else:
                return len(self.dataset)

    def __iter__(self) -> _DataLoaderIterBase:
        if self.num_workers == 0:
            return _DataLoaderIterSingleProcess(self)
        elif self._persistent_workers:
            if self._iterator is None:
                self._iterator = _DataLoaderIterMultiProcess(self)
            else:
                self._iterator._reset()
            return self._iterator
        else:
            return _DataLoaderIterMultiProcess(self)

    def __call__(self) -> _DataLoaderIterBase:
        return self.__iter__()
