# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Literal

from typing_extensions import TypeAlias

if TYPE_CHECKING:
    from collections.abc import Sequence

    from paddle._typing.dtype_like import _DTypeLiteral
    from paddle.vision.transforms.transforms import _Transform

    from ..image import _ImageDataType

    _AllowedExtensions: TypeAlias = Literal[
        '.jpg',
        '.jpeg',
        '.png',
        '.ppm',
        '.bmp',
        '.pgm',
        '.tif',
        '.tiff',
        '.webp',
    ]

import os

from PIL import Image

import paddle
from paddle.io import Dataset
from paddle.utils import try_import

__all__ = []


def has_valid_extension(filename: str, extensions: Sequence[str]) -> bool:
    """Tell whether ``filename`` ends with one of the given extensions.

    The comparison is case-insensitive: both the file name and every
    extension are lower-cased before matching.

    Args:
        filename (str): path to a file
        extensions (list[str]|tuple[str]): extensions to consider

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    assert isinstance(extensions, (list, tuple)), (
        "`extensions` must be list or tuple."
    )
    # str.endswith only accepts a tuple of candidates, never a list.
    lowered_suffixes = tuple(ext.lower() for ext in extensions)
    lowered_name = filename.lower()
    return lowered_name.endswith(lowered_suffixes)


def make_dataset(dir, class_to_idx, extensions, is_valid_file=None):
    """Collect ``(path, class_index)`` samples from a class-per-folder tree.

    Every key of ``class_to_idx`` is treated as a sub-directory of ``dir``.
    Each such directory is walked recursively (following symlinks) and every
    file accepted by the filter is recorded with the class index of the
    top-level folder it lives under. Keys that are not directories are
    skipped.

    Args:
        dir (str): Root directory path. A leading ``~`` is expanded.
        class_to_idx (dict[str, int]): Maps class folder name to class index.
        extensions (list[str]|tuple[str]|None): Allowed file extensions. When
            not None it takes precedence over ``is_valid_file``.
        is_valid_file (Callable|None, optional): A function that takes the
            path of a file and returns whether it should be kept. Only used
            when ``extensions`` is None. Default: None.

    Returns:
        list[tuple[str, int]]: Samples ordered by class name, then by walked
        directory, then by file name.

    Raises:
        ValueError: If both ``extensions`` and ``is_valid_file`` are None.
    """
    if extensions is not None:

        def predicate(path):
            return has_valid_extension(path, extensions)

    elif is_valid_file is not None:
        predicate = is_valid_file
    else:
        # Without this guard the loop below would try to call ``None``.
        raise ValueError(
            "Both `extensions` and `is_valid_file` cannot be None "
            "at the same time."
        )

    images = []
    dir = os.path.expanduser(dir)

    for target in sorted(class_to_idx):
        class_dir = os.path.join(dir, target)
        if not os.path.isdir(class_dir):
            continue
        label = class_to_idx[target]
        # Sorting the walk output and the file names makes the sample order
        # deterministic.
        for walk_root, _, fnames in sorted(
            os.walk(class_dir, followlinks=True)
        ):
            for fname in sorted(fnames):
                path = os.path.join(walk_root, fname)
                if predicate(path):
                    images.append((path, label))

    return images


class DatasetFolder(Dataset[tuple["_ImageDataType", int]]):
    """A generic data loader where the samples are arranged in this way:

    .. code-block:: text

        root/class_a/1.ext
        root/class_a/2.ext
        root/class_a/3.ext

        root/class_b/123.ext
        root/class_b/456.ext
        root/class_b/789.ext

    Args:
        root (str): Root directory path.
        loader (Callable|None, optional): A function to load a sample given its path. Default: None.
        extensions (list[str]|tuple[str]|None, optional): A list of allowed extensions.
            Both :attr:`extensions` and :attr:`is_valid_file` should not be passed.
            If this value is not set, the default is to use ('.jpg', '.jpeg', '.png',
            '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'). Default: None.
        transform (Callable|None, optional): A function/transform that takes in
            a sample and returns a transformed version. Default: None.
        is_valid_file (Callable|None, optional): A function that takes path of a file
            and check if the file is a valid file. Both :attr:`extensions` and
            :attr:`is_valid_file` should not be passed. Default: None.

    Returns:
        :ref:`api_paddle_io_Dataset`. An instance of DatasetFolder.

    Raises:
        RuntimeError: If no valid file is found under the class folders of ``root``.

    Attributes:
        classes (list[str]): List of the class names.
        class_to_idx (dict[str, int]): Dict with items (class_name, class_index).
        samples (list[tuple[str, int]]): List of (sample_path, class_index) tuples.
        targets (list[int]): The class_index value for each image in the dataset.

    Example:

        .. code-block:: python

            >>> import shutil
            >>> import tempfile
            >>> import cv2
            >>> import numpy as np
            >>> import paddle.vision.transforms as T
            >>> from pathlib import Path
            >>> from paddle.vision.datasets import DatasetFolder

            >>> def make_fake_file(img_path: str):
            ...     if img_path.endswith((".jpg", ".png", ".jpeg")):
            ...         fake_img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
            ...         cv2.imwrite(img_path, fake_img)
            ...     elif img_path.endswith(".txt"):
            ...         with open(img_path, "w") as f:
            ...             f.write("This is a fake file.")

            >>> def make_directory(root, directory_hierarchy, file_maker=make_fake_file):
            ...     root = Path(root)
            ...     root.mkdir(parents=True, exist_ok=True)
            ...     for subpath in directory_hierarchy:
            ...         if isinstance(subpath, str):
            ...             filepath = root / subpath
            ...             file_maker(str(filepath))
            ...         else:
            ...             dirname = list(subpath.keys())[0]
            ...             make_directory(root / dirname, subpath[dirname])

            >>> directory_hierarchy = [
            ...     {"class_0": [
            ...         "abc.jpg",
            ...         "def.png"]},
            ...     {"class_1": [
            ...         "ghi.jpeg",
            ...         "jkl.png",
            ...         {"mno": [
            ...             "pqr.jpeg",
            ...             "stu.jpg"]}]},
            ...     "this_will_be_ignored.txt",
            ... ]

            >>> # You can replace this with any directory to explore the structure
            >>> # of generated data. e.g. fake_data_dir = "./temp_dir"
            >>> fake_data_dir = tempfile.mkdtemp()
            >>> make_directory(fake_data_dir, directory_hierarchy)
            >>> data_folder_1 = DatasetFolder(fake_data_dir)
            >>> print(data_folder_1.classes)
            ['class_0', 'class_1']
            >>> print(data_folder_1.class_to_idx)
            {'class_0': 0, 'class_1': 1}
            >>> print(data_folder_1.samples)
            >>> # doctest: +SKIP(it's different with windows)
            [('./temp_dir/class_0/abc.jpg', 0), ('./temp_dir/class_0/def.png', 0),
             ('./temp_dir/class_1/ghi.jpeg', 1), ('./temp_dir/class_1/jkl.png', 1),
             ('./temp_dir/class_1/mno/pqr.jpeg', 1), ('./temp_dir/class_1/mno/stu.jpg', 1)]
            >>> # doctest: -SKIP
            >>> print(data_folder_1.targets)
            [0, 0, 1, 1, 1, 1]
            >>> print(len(data_folder_1))
            6

            >>> for i in range(len(data_folder_1)):
            ...     img, label = data_folder_1[i]
            ...     # do something with img and label
            ...     print(type(img), img.size, label)
            ...     # <class 'PIL.Image.Image'> (32, 32) 0


            >>> transform = T.Compose(
            ...     [
            ...         T.Resize(64),
            ...         T.ToTensor(),
            ...         T.Normalize(
            ...             mean=[0.5, 0.5, 0.5],
            ...             std=[0.5, 0.5, 0.5],
            ...             to_rgb=True,
            ...         ),
            ...     ]
            ... )

            >>> data_folder_2 = DatasetFolder(
            ...     fake_data_dir,
            ...     loader=lambda x: cv2.imread(x),  # load image with OpenCV
            ...     extensions=(".jpg",),  # only load *.jpg files
            ...     transform=transform,  # apply transform to every image
            ... )

            >>> print([img_path for img_path, label in data_folder_2.samples])
            >>> # doctest: +SKIP(it's different with windows)
            ['./temp_dir/class_0/abc.jpg', './temp_dir/class_1/mno/stu.jpg']
            >>> # doctest: -SKIP
            >>> print(len(data_folder_2))
            2

            >>> for img, label in iter(data_folder_2):
            ...     # do something with img and label
            ...     print(type(img), img.shape, label)  # type: ignore
            ...     # <class 'paddle.Tensor'> [3, 64, 64] 0

            >>> shutil.rmtree(fake_data_dir)
    """

    loader: Callable[..., _ImageDataType] | None
    extensions: Sequence[_AllowedExtensions] | None
    transform: _Transform[Any, Any] | None
    classes: list[str]
    class_to_idx: dict[str, int]
    samples: list[tuple[str, int]]
    targets: list[int]
    dtype: _DTypeLiteral

    def __init__(
        self,
        root: str,
        loader: Callable[..., _ImageDataType] | None = None,
        extensions: Sequence[_AllowedExtensions] | None = None,
        transform: _Transform[Any, Any] | None = None,
        is_valid_file: Callable[[str], bool] | None = None,
    ) -> None:
        self.root = root
        self.transform = transform
        # Fall back to the default image extensions only when the caller gave
        # no filter at all. Defaulting unconditionally would make
        # `make_dataset` ignore a user-supplied `is_valid_file`, because
        # extensions take precedence there.
        if extensions is None and is_valid_file is None:
            extensions = IMG_EXTENSIONS
        classes, class_to_idx = self._find_classes(self.root)
        samples = make_dataset(
            self.root, class_to_idx, extensions, is_valid_file
        )
        if not samples:
            message = f"Found 0 files in subfolders of: {self.root}"
            # `extensions` is None when filtering is done by `is_valid_file`.
            if extensions is not None:
                message += "\nSupported extensions are: " + ",".join(
                    extensions
                )
            raise RuntimeError(message)

        self.loader = default_loader if loader is None else loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [class_index for _, class_index in samples]

        self.dtype = paddle.get_default_dtype()

    def _find_classes(self, dir: str) -> tuple[list[str], dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path. A leading ``~`` is expanded,
                matching how ``make_dataset`` resolves the same path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir),
                    and class_to_idx is a dictionary.

        """
        # The context manager closes the scandir iterator deterministically.
        with os.scandir(os.path.expanduser(dir)) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {name: index for index, name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> tuple[_ImageDataType, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target

    def __len__(self) -> int:
        """Return the number of samples found under ``root``."""
        return len(self.samples)


# Default file extensions accepted by DatasetFolder and ImageFolder when the
# caller does not pass `extensions`. Matching against them is case-insensitive
# because has_valid_extension lower-cases both sides.
IMG_EXTENSIONS = (
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
    '.tiff',
    '.webp',
)


def pil_loader(path):
    """Load the image at ``path`` with PIL and return it converted to RGB."""
    # The conversion happens while the file is still open, so the pixel data
    # is read before the handle is closed.
    with open(path, 'rb') as fp:
        return Image.open(fp).convert('RGB')


def cv2_loader(path):
    """Load the image at ``path`` with OpenCV and return it in RGB order."""
    cv2 = try_import('cv2')
    # cv2.imread yields BGR channel order; convert it to RGB.
    bgr_image = cv2.imread(path)
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)


def default_loader(path):
    """Load the image at ``path`` with the currently configured image backend.

    Uses ``cv2_loader`` when the backend is ``'cv2'`` and ``pil_loader``
    for any other backend.
    """
    from paddle.vision import get_image_backend

    loader = cv2_loader if get_image_backend() == 'cv2' else pil_loader
    return loader(path)


class ImageFolder(Dataset[list["_ImageDataType"]]):
    """A generic data loader where the samples are arranged in this way:

    .. code-block:: text

        root/1.ext
        root/2.ext
        root/sub_dir/3.ext

    Args:
        root (str): Root directory path.
        loader (Callable|None, optional): A function to load a sample given its path. Default: None.
        extensions (list[str]|tuple[str]|None, optional): A list of allowed extensions.
            Both :attr:`extensions` and :attr:`is_valid_file` should not be passed.
            If this value is not set, the default is to use ('.jpg', '.jpeg', '.png',
            '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'). Default: None.
        transform (Callable|None, optional): A function/transform that takes in
            a sample and returns a transformed version. Default: None.
        is_valid_file (Callable|None, optional): A function that takes path of a file
            and check if the file is a valid file. Both :attr:`extensions` and
            :attr:`is_valid_file` should not be passed. Default: None.

    Returns:
        :ref:`api_paddle_io_Dataset`. An instance of ImageFolder.

    Raises:
        RuntimeError: If no valid file is found under ``root``.

    Attributes:
        samples (list[str]): List of sample path.

    Example:

        .. code-block:: python

            >>> import shutil
            >>> import tempfile
            >>> import cv2
            >>> import numpy as np
            >>> import paddle.vision.transforms as T
            >>> from pathlib import Path
            >>> from paddle.vision.datasets import ImageFolder

            >>> def make_fake_file(img_path: str):
            ...     if img_path.endswith((".jpg", ".png", ".jpeg")):
            ...         fake_img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
            ...         cv2.imwrite(img_path, fake_img)
            ...     elif img_path.endswith(".txt"):
            ...         with open(img_path, "w") as f:
            ...             f.write("This is a fake file.")

            >>> def make_directory(root, directory_hierarchy, file_maker=make_fake_file):
            ...     root = Path(root)
            ...     root.mkdir(parents=True, exist_ok=True)
            ...     for subpath in directory_hierarchy:
            ...         if isinstance(subpath, str):
            ...             filepath = root / subpath
            ...             file_maker(str(filepath))
            ...         else:
            ...             dirname = list(subpath.keys())[0]
            ...             make_directory(root / dirname, subpath[dirname])

            >>> directory_hierarchy = [
            ...     "abc.jpg",
            ...     "def.png",
            ...     {"ghi": [
            ...         "jkl.jpeg",
            ...         {"mno": [
            ...             "pqr.jpg"]}]},
            ...     "this_will_be_ignored.txt",
            ... ]

            >>> # You can replace this with any directory to explore the structure
            >>> # of generated data. e.g. fake_data_dir = "./temp_dir"
            >>> fake_data_dir = tempfile.mkdtemp()
            >>> make_directory(fake_data_dir, directory_hierarchy)
            >>> image_folder_1 = ImageFolder(fake_data_dir)
            >>> print(image_folder_1.samples)
            >>> # doctest: +SKIP(it's different with windows)
            ['./temp_dir/abc.jpg', './temp_dir/def.png',
             './temp_dir/ghi/jkl.jpeg', './temp_dir/ghi/mno/pqr.jpg']
            >>> # doctest: -SKIP
            >>> print(len(image_folder_1))
            4

            >>> for i in range(len(image_folder_1)):
            ...     (img,) = image_folder_1[i]
            ...     # do something with img
            ...     print(type(img), img.size)
            ...     # <class 'PIL.Image.Image'> (32, 32)


            >>> transform = T.Compose(
            ...     [
            ...         T.Resize(64),
            ...         T.ToTensor(),
            ...         T.Normalize(
            ...             mean=[0.5, 0.5, 0.5],
            ...             std=[0.5, 0.5, 0.5],
            ...             to_rgb=True,
            ...         ),
            ...     ]
            ... )

            >>> image_folder_2 = ImageFolder(
            ...     fake_data_dir,
            ...     loader=lambda x: cv2.imread(x),  # load image with OpenCV
            ...     extensions=(".jpg",),  # only load *.jpg files
            ...     transform=transform,  # apply transform to every image
            ... )

            >>> print(image_folder_2.samples)
            >>> # doctest: +SKIP(it's different with windows)
            ['./temp_dir/abc.jpg', './temp_dir/ghi/mno/pqr.jpg']
            >>> # doctest: -SKIP
            >>> print(len(image_folder_2))
            2

            >>> for (img,) in iter(image_folder_2):
            ...     # do something with img
            ...     print(type(img), img.shape)  # type: ignore
            ...     # <class 'paddle.Tensor'> [3, 64, 64]

            >>> shutil.rmtree(fake_data_dir)
    """

    loader: Callable[..., _ImageDataType] | None
    extensions: Sequence[_AllowedExtensions] | None
    samples: list[str]
    transform: _Transform[Any, Any] | None

    def __init__(
        self,
        root: str,
        loader: Callable[..., _ImageDataType] | None = None,
        extensions: Sequence[_AllowedExtensions] | None = None,
        transform: _Transform[Any, Any] | None = None,
        is_valid_file: Callable[[str], bool] | None = None,
    ) -> None:
        self.root = root
        # Fall back to the default image extensions only when the caller gave
        # no filter at all; defaulting unconditionally would silently discard
        # a user-supplied `is_valid_file`.
        if extensions is None and is_valid_file is None:
            extensions = IMG_EXTENSIONS

        if extensions is not None:
            # Explicit extensions take precedence over `is_valid_file`.
            def accept(file_path):
                return has_valid_extension(file_path, extensions)

        else:
            accept = is_valid_file

        # Sorting the walk output and the file names makes the sample order
        # deterministic.
        samples = []
        search_root = os.path.expanduser(root)
        for dirpath, _, fnames in sorted(
            os.walk(search_root, followlinks=True)
        ):
            for fname in sorted(fnames):
                file_path = os.path.join(dirpath, fname)
                if accept(file_path):
                    samples.append(file_path)

        if not samples:
            message = "Found 0 files in subfolders of: " + self.root
            # `extensions` is None when filtering is done by `is_valid_file`.
            if extensions is not None:
                message += "\nSupported extensions are: " + ",".join(
                    extensions
                )
            raise RuntimeError(message)

        self.loader = default_loader if loader is None else loader
        self.extensions = extensions
        self.samples = samples
        self.transform = transform

    def __getitem__(self, index: int) -> list[_ImageDataType]:
        """
        Args:
            index (int): Index

        Returns:
            sample of specific index.
        """
        path = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        # The sample is wrapped in a one-element list.
        return [sample]

    def __len__(self) -> int:
        """Return the number of samples found under ``root``."""
        return len(self.samples)
