#  Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import (
    TYPE_CHECKING,
)

import paddle
from paddle import nn

if TYPE_CHECKING:
    from paddle import Tensor

# Intentionally empty: ``from <this module> import *`` exports nothing.
# NOTE(review): presumably LeNet is re-exported by the enclosing package's
# ``__init__`` — confirm there before adding names here.
__all__: list[str] = list()


class LeNet(nn.Layer):
    """LeNet model from
    `"Gradient-based learning applied to document recognition" <https://ieeexplore.ieee.org/document/726791>`_.

    The fully-connected head is sized for single-channel 28x28 inputs: the
    two conv/pool stages reduce such an input to a ``[N, 16, 5, 5]`` feature
    map, which is flattened into the 400 features the first fc layer expects.

    Args:
        num_classes (int, optional): Output dim of the last fc layer. If
            num_classes <= 0, the whole fully-connected classifier head is not
            defined, and the model returns the (unflattened) feature map
            produced by the convolutional layers instead. Default: 10.

    Returns:
        :ref:`api_paddle_nn_Layer`. An instance of LeNet model.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> from paddle.vision.models import LeNet

            >>> model = LeNet()

            >>> x = paddle.rand([1, 1, 28, 28])
            >>> out = model(x)

            >>> print(out.shape)
            [1, 10]
    """

    num_classes: int

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.num_classes = num_classes
        # Feature extractor: for a [N, 1, 28, 28] input the spatial size goes
        # 28 -> 28 (3x3 conv, padding 1) -> 14 (pool) -> 10 (5x5 conv,
        # no padding) -> 5 (pool), giving a [N, 16, 5, 5] output.
        self.features = nn.Sequential(
            nn.Conv2D(1, 6, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2D(2, 2),
            nn.Conv2D(6, 16, 5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2D(2, 2),
        )

        # The classifier head exists only when a positive class count is
        # requested; forward() checks the same condition before using it.
        if num_classes > 0:
            # NOTE(review): there is no activation between these Linear
            # layers, so the head composes to a single affine map. It is kept
            # as-is because inserting layers would shift the Sequential
            # sublayer indices (fc.0 / fc.1 / fc.2) and break saved state dicts.
            self.fc = nn.Sequential(
                nn.Linear(400, 120),  # 400 == 16 channels * 5 * 5
                nn.Linear(120, 84),
                nn.Linear(84, num_classes),
            )

    def forward(self, inputs: Tensor) -> Tensor:
        """Run the network on a batch of images.

        Args:
            inputs (Tensor): Input batch. The classifier head assumes a
                ``[N, 1, 28, 28]`` shape (see the class docstring).

        Returns:
            Tensor: ``[N, num_classes]`` scores from the last fc layer when
            ``num_classes > 0``; otherwise the convolutional feature map
            returned by ``self.features`` without flattening.
        """
        x = self.features(inputs)

        if self.num_classes > 0:
            # Collapse everything after the batch axis into one feature axis.
            x = paddle.flatten(x, 1)
            x = self.fc(x)
        return x
