#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
    import numpy.typing as npt

    from paddle._typing.dtype_like import _DTypeLiteral
    from paddle.vision.transforms.transforms import _Transform

    from ..image import _ImageBackend, _ImageDataType

    _DatasetMode = Literal["train", "test"]

import gzip
import struct

import numpy as np
from PIL import Image

import paddle
from paddle.dataset.common import _check_exists_and_download
from paddle.io import Dataset

# An empty ``__all__`` means ``from <this module> import *`` exports no names.
__all__ = []


class MNIST(Dataset[tuple["_ImageDataType", "npt.NDArray[np.int64]"]]):
    """
    Implementation of `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset.

    The gzip-compressed IDX image and label files are decoded completely into
    memory when the dataset is constructed. Each sample is an ``(image, label)``
    pair, where ``label`` is an ``int64`` array of shape ``[1]``.

    Args:
        image_path (str|None, optional): Path to image file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/mnist.
        label_path (str|None, optional): Path to label file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/mnist.
        mode (str, optional): Either train or test mode (matched
            case-insensitively). Default 'train'.
        transform (Callable|None, optional): Transform to perform on image, None for no transform. Default: None.
        download (bool, optional): Whether to download the dataset automatically
            when :attr:`image_path` or :attr:`label_path` is not set. Default: True.
        backend (str|None, optional): Specifies which type of image to be returned:
            PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
            If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_paddle_vision_get_image_backend>`,
            default backend is 'pil'. Default: None.

    Returns:
        :ref:`api_paddle_io_Dataset`. An instance of MNIST dataset.

    Raises:
        AssertionError: If :attr:`mode` is neither 'train' nor 'test', or if a
            path is not set while :attr:`download` is False.
        ValueError: If :attr:`backend` is not one of {'pil', 'cv2'}, or if a
            data file holds fewer records than its header declares.

    Examples:

        .. code-block:: python

            >>> import itertools
            >>> import paddle.vision.transforms as T
            >>> from paddle.vision.datasets import MNIST


            >>> mnist = MNIST()
            >>> print(len(mnist))
            60000

            >>> for i in range(5):  # only show first 5 images
            ...     img, label = mnist[i]
            ...     # do something with img and label
            ...     print(type(img), img.size, label)
            ...     # <class 'PIL.Image.Image'> (28, 28) [5]


            >>> transform = T.Compose(
            ...     [
            ...         T.ToTensor(),
            ...         T.Normalize(
            ...             mean=[127.5],
            ...             std=[127.5],
            ...         ),
            ...     ]
            ... )

            >>> mnist_test = MNIST(
            ...     mode="test",
            ...     transform=transform,  # apply transform to every image
            ...     backend="cv2",  # use OpenCV as image transform backend
            ... )
            >>> print(len(mnist_test))
            10000

            >>> for img, label in itertools.islice(iter(mnist_test), 5):  # only show first 5 images
            ...     # do something with img and label
            ...     print(type(img), img.shape, label)  # type: ignore
            ...     # <class 'paddle.Tensor'> [1, 28, 28] [7]
    """

    NAME = 'mnist'
    URL_PREFIX = 'https://dataset.bj.bcebos.com/mnist/'
    TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
    TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
    TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
    TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
    TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
    TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
    TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
    TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'

    mode: _DatasetMode
    image_path: str | None
    label_path: str | None
    transform: _Transform[Any, Any] | None
    backend: _ImageBackend
    dtype: _DTypeLiteral
    labels: list
    images: list
    # (rows, cols) of a single image. ``_parse_dataset`` overwrites this with
    # the values read from the image file header; the class-level default keeps
    # ``__getitem__`` working for subclasses that replace ``_parse_dataset``.
    _image_shape: tuple[int, int] = (28, 28)

    def __init__(
        self,
        image_path: str | None = None,
        label_path: str | None = None,
        mode: _DatasetMode = 'train',
        transform: _Transform[Any, Any] | None = None,
        download: bool = True,
        backend: _ImageBackend | None = None,
    ) -> None:
        assert mode.lower() in [
            'train',
            'test',
        ], f"mode should be 'train' or 'test', but got {mode}"

        if backend is None:
            backend = paddle.vision.get_image_backend()
        if backend not in ['pil', 'cv2']:
            raise ValueError(
                f"Expected backend are one of ['pil', 'cv2'], but got {backend}"
            )
        self.backend = backend

        # Always branch on the normalized mode. Comparing the raw argument
        # would make e.g. mode='Train' pass validation above and then select
        # the *test* images while the labels below come from the *train* split.
        self.mode = mode.lower()
        is_train = self.mode == 'train'

        self.image_path = image_path
        if self.image_path is None:
            assert download, (
                "image_path is not set and downloading automatically is disabled"
            )
            image_url = (
                self.TRAIN_IMAGE_URL if is_train else self.TEST_IMAGE_URL
            )
            image_md5 = (
                self.TRAIN_IMAGE_MD5 if is_train else self.TEST_IMAGE_MD5
            )
            self.image_path = _check_exists_and_download(
                image_path, image_url, image_md5, self.NAME, download
            )

        self.label_path = label_path
        if self.label_path is None:
            assert download, (
                "label_path is not set and downloading automatically is disabled"
            )
            label_url = (
                self.TRAIN_LABEL_URL if is_train else self.TEST_LABEL_URL
            )
            label_md5 = (
                self.TRAIN_LABEL_MD5 if is_train else self.TEST_LABEL_MD5
            )
            self.label_path = _check_exists_and_download(
                label_path, label_url, label_md5, self.NAME, download
            )

        self.transform = transform

        # read dataset into memory
        self._parse_dataset()

        self.dtype = paddle.get_default_dtype()

    def _parse_dataset(self, buffer_size: int = 100) -> None:
        """
        Decode the gzip-compressed IDX files into ``self.images`` / ``self.labels``.

        After this call ``self.images`` is a list of flat ``float32`` arrays of
        length ``rows * cols`` and ``self.labels`` is a list of ``int64`` arrays
        of shape ``[1]``. The number of samples loaded is the count declared in
        the label file header.

        Args:
            buffer_size (int, optional): Retained for backward compatibility
                only. The payload is decoded in a single pass, so this value no
                longer influences the result and the record count does not need
                to be a multiple of it. Default: 100.

        Raises:
            ValueError: If either file holds fewer bytes than the label header's
                record count requires.
        """
        with gzip.GzipFile(self.image_path, 'rb') as image_file:
            img_buf = image_file.read()
        with gzip.GzipFile(self.label_path, 'rb') as label_file:
            lab_buf = label_file.read()

        # Headers are big-endian unsigned 32-bit integers.
        # image file: magic, image count, rows, cols (16 bytes)
        img_header = '>IIII'
        _, _, rows, cols = struct.unpack_from(img_header, img_buf, 0)
        img_offset = struct.calcsize(img_header)

        # label file: magic, label count (8 bytes)
        lab_header = '>II'
        _, label_num = struct.unpack_from(lab_header, lab_buf, 0)
        lab_offset = struct.calcsize(lab_header)

        pixels = rows * cols
        if len(lab_buf) - lab_offset < label_num:
            raise ValueError(
                f"label file {self.label_path} is truncated: header declares "
                f"{label_num} labels but only {len(lab_buf) - lab_offset} "
                "bytes follow"
            )
        if len(img_buf) - img_offset < label_num * pixels:
            raise ValueError(
                f"image file {self.image_path} holds fewer than the "
                f"{label_num} images of {rows}x{cols} pixels required by the "
                "label file"
            )

        # Payloads are raw unsigned bytes: one byte per pixel / per label.
        images = (
            np.frombuffer(
                img_buf,
                dtype=np.uint8,
                count=label_num * pixels,
                offset=img_offset,
            )
            .reshape(label_num, pixels)
            .astype('float32')
        )
        labels = (
            np.frombuffer(
                lab_buf, dtype=np.uint8, count=label_num, offset=lab_offset
            )
            .astype('int64')
            .reshape(label_num, 1)
        )

        self._image_shape = (rows, cols)
        # Iterating a 2-D array yields its rows, giving one array per sample.
        self.images = list(images)
        self.labels = list(labels)

    def __getitem__(
        self, idx: int
    ) -> tuple[_ImageDataType, npt.NDArray[np.int64]]:
        """
        Return the ``(image, label)`` pair stored at position ``idx``.

        With the 'pil' backend the image is converted to an 8-bit grayscale
        ``PIL.Image`` before the transform runs and is returned as produced.
        With the 'cv2' backend the (optionally transformed) image is cast to
        ``self.dtype`` via its ``astype`` method before being returned.
        """
        image, label = self.images[idx], self.labels[idx]
        # Samples are stored flat; restore the 2-D layout from the file header.
        image = np.reshape(image, self._image_shape)

        if self.backend == 'pil':
            image = Image.fromarray(image.astype('uint8'), mode='L')

        if self.transform is not None:
            image = self.transform(image)

        if self.backend == 'pil':
            return image, label.astype('int64')

        return image.astype(self.dtype), label.astype('int64')

    def __len__(self) -> int:
        """Return the number of samples loaded into memory."""
        return len(self.labels)


class FashionMNIST(MNIST):
    """
    Implementation of `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ dataset.

    This class only overrides the dataset name, the download URLs and their
    MD5 checksums; everything else is inherited from :class:`MNIST`.

    Args:
        image_path (str|None, optional): Path to image file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/fashion-mnist.
        label_path (str|None, optional): Path to label file, can be set None if
            :attr:`download` is True. Default: None, default data path: ~/.cache/paddle/dataset/fashion-mnist.
        mode (str, optional): Either train or test mode. Default 'train'.
        transform (Callable|None, optional): Transform to perform on image, None for no transform. Default: None.
        download (bool, optional): Whether to download the dataset automatically
            when :attr:`image_path` or :attr:`label_path` is not set. Default: True.
        backend (str|None, optional): Specifies which type of image to be returned:
            PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
            If this option is not set, will get backend from :ref:`paddle.vision.get_image_backend <api_paddle_vision_get_image_backend>`,
            default backend is 'pil'. Default: None.

    Returns:
        :ref:`api_paddle_io_Dataset`. An instance of FashionMNIST dataset.

    Examples:

        .. code-block:: python

            >>> import itertools
            >>> import paddle.vision.transforms as T
            >>> from paddle.vision.datasets import FashionMNIST


            >>> fashion_mnist = FashionMNIST()
            >>> print(len(fashion_mnist))
            60000

            >>> for i in range(5):  # only show first 5 images
            ...     img, label = fashion_mnist[i]
            ...     # do something with img and label
            ...     print(type(img), img.size, label)
            ...     # <class 'PIL.Image.Image'> (28, 28) [9]


            >>> transform = T.Compose(
            ...     [
            ...         T.ToTensor(),
            ...         T.Normalize(
            ...             mean=[127.5],
            ...             std=[127.5],
            ...         ),
            ...     ]
            ... )

            >>> fashion_mnist_test = FashionMNIST(
            ...     mode="test",
            ...     transform=transform,  # apply transform to every image
            ...     backend="cv2",  # use OpenCV as image transform backend
            ... )
            >>> print(len(fashion_mnist_test))
            10000

            >>> for img, label in itertools.islice(iter(fashion_mnist_test), 5):  # only show first 5 images
            ...     # do something with img and label
            ...     print(type(img), img.shape, label)  # type: ignore
            ...     # <class 'paddle.Tensor'> [1, 28, 28] [9]
    """

    NAME = 'fashion-mnist'
    URL_PREFIX = 'https://dataset.bj.bcebos.com/fashion_mnist/'

    # Training split: archive locations and their MD5 checksums.
    TRAIN_IMAGE_URL = f'{URL_PREFIX}train-images-idx3-ubyte.gz'
    TRAIN_IMAGE_MD5 = '8d4fb7e6c68d591d4c3dfef9ec88bf0d'
    TRAIN_LABEL_URL = f'{URL_PREFIX}train-labels-idx1-ubyte.gz'
    TRAIN_LABEL_MD5 = '25c81989df183df01b3e8a0aad5dffbe'

    # Test split: archive locations and their MD5 checksums.
    TEST_IMAGE_URL = f'{URL_PREFIX}t10k-images-idx3-ubyte.gz'
    TEST_IMAGE_MD5 = 'bef4ecab320f06d8554ea6380940ec79'
    TEST_LABEL_URL = f'{URL_PREFIX}t10k-labels-idx1-ubyte.gz'
    TEST_LABEL_MD5 = 'bb300cfdad3c16e7a12a480ee83cd310'
