#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
from collections.abc import Mapping, Sequence

import numpy as np

import paddle

FIELD_PREFIX = "_paddle_field_"


def _flatten_batch(batch):
    """
    Split a batch into a flat list of array/tensor fields and a structure
    that records where each of them came from.

    lod_blocking_queue only receives a tensor array, so every
    ``numpy.ndarray`` / ``paddle.Tensor`` field is pulled out of the batch
    (in traversal order) and replaced in the structure by the placeholder
    string ``f'{FIELD_PREFIX}{i}'``, where ``i`` is the position of the
    field in the returned flat list. Fields of other types (str, bytes,
    numbers, and anything unrecognized) are kept in the structure as they
    are; nested sequences become lists and nested mappings become dicts.

    Args:
        batch: a single field, or an arbitrarily nested sequence/mapping
            of fields.

    Returns:
        tuple: ``(flat_batch, structure)`` where ``flat_batch`` is the list
        of extracted array/tensor fields and ``structure`` mirrors
        ``batch`` with those fields replaced by placeholders.
    """
    tensor_types = (np.ndarray, paddle.Tensor, paddle.base.core.eager.Tensor)
    plain_types = (str, bytes, numbers.Number)
    flat_batch = []

    def _flatten_field(field):
        # Return the structure entry for one field, appending any
        # array/tensor found inside it to flat_batch.
        if isinstance(field, tensor_types):
            # The next free slot of flat_batch is the index of this field.
            placeholder = f'{FIELD_PREFIX}{len(flat_batch)}'
            flat_batch.append(field)
            return placeholder
        # str and bytes are Sequences too, so they must be recognized as
        # plain values before the generic Sequence check; otherwise they
        # would be split into characters / integers.
        if isinstance(field, plain_types):
            return field
        if isinstance(field, Sequence):
            return [_flatten_field(item) for item in field]
        if isinstance(field, Mapping):
            return {key: _flatten_field(item) for key, item in field.items()}
        # Unknown types are stored in the structure untouched.
        return field

    structure = _flatten_field(batch)
    return flat_batch, structure


def _restore_batch(flat_batch, structure):
    """
    Rebuild the batch data structure from the flat list read back from
    lod_blocking_queue.

    Every placeholder string ``f'{FIELD_PREFIX}{i}'`` found in
    ``structure`` is replaced by ``flat_batch[i]``. ``structure`` is
    modified in place and returned. A string that merely starts with
    ``FIELD_PREFIX`` but is not followed by a decimal index is treated as
    ordinary data and left untouched.

    Args:
        flat_batch: sequence (list or tuple) of the extracted fields. When
            it supports item assignment, each consumed slot is set to
            ``None``.
        structure: the structure holding the placeholders; either a single
            placeholder string or a nested list/dict.

    Returns:
        The restored batch: ``structure`` itself when ``flat_batch`` is
        empty, ``flat_batch[0]`` when ``structure`` is a single
        placeholder, otherwise ``structure`` with all placeholders
        substituted.

    Raises:
        AssertionError: if ``flat_batch`` is not a sequence, if a
            placeholder is referenced twice, or if not every element of
            ``flat_batch`` was consumed.
        TypeError: if ``structure`` is neither a string, a sequence nor a
            mapping.
    """
    assert isinstance(flat_batch, Sequence), "flat_batch is not a list or tuple"

    # no np.array in dataset, no output tensor from blocking queue
    # simply return structure
    if len(flat_batch) == 0:
        return structure

    # sample only contains a single field
    if isinstance(structure, (str, bytes)):
        assert structure == f'{FIELD_PREFIX}{0}', (
            f"invalid structure: {structure}"
        )
        return flat_batch[0]

    # Indices already substituted. Tracking them separately (instead of
    # relying on None markers inside flat_batch) lets immutable inputs such
    # as tuples work and makes the completeness check exact.
    consumed = set()
    can_clear = hasattr(flat_batch, '__setitem__')

    def _placeholder_index(field):
        # Return the flat_batch index encoded in a placeholder, or None
        # when the field is not a well-formed placeholder.
        if not (isinstance(field, str) and field.startswith(FIELD_PREFIX)):
            return None
        suffix = field[len(FIELD_PREFIX):]
        return int(suffix) if suffix.isdecimal() else None

    def _restore(node):
        if isinstance(node, Sequence):
            keys = range(len(node))
        elif isinstance(node, Mapping):
            # Snapshot the keys: values are reassigned while looping.
            keys = list(node)
        else:
            raise TypeError(f"wrong flat data type: {type(node)}")

        for key in keys:
            field = node[key]
            idx = _placeholder_index(field)
            if idx is not None:
                assert idx not in consumed, (
                    f"flat_batch[{idx}] parsed repeatedly"
                )
                consumed.add(idx)
                node[key] = flat_batch[idx]
                if can_clear:
                    flat_batch[idx] = None
            elif isinstance(field, (str, bytes, numbers.Number)):
                continue
            elif isinstance(field, (Sequence, Mapping)):
                _restore(field)

    _restore(structure)
    assert len(consumed) == len(flat_batch), "Tensor parse incomplete"
    return structure
