# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from ...base import framework

if TYPE_CHECKING:
    from types import TracebackType


__all__ = ["LazyGuard"]


class LazyInitHelper:
    """
    A Helper Context to trigger switching mode between dygraph and static graph mode,
    and holds the startup program resource.
    """

    def __init__(self):
        self._state = False
        self._tracer = None
        self._in_guard = False

    def enable(self):
        """
        Switch into lazy mode.

        NOTE(dev): This is a very low level API and not exposed for user.
        """
        if self._state:
            return
        assert framework.in_dygraph_mode(), (
            "LazyInit.enable() is only available in dygraph mode."
        )
        self._state = True

    def disable(self):
        """
        Exit from lazy mode.

        NOTE(dev): This is a very low level API and not exposed for user.
        """
        if not self._state:
            return
        self._state = False

    def __enter__(self):
        """
        Switch into lazy mode and set _dygraph_tracer_ with None to convert
        dygraph mode into static graph mode.
        """
        self.enable()
        if self._in_guard:
            return
        self._tracer = framework.global_var._dygraph_tracer_
        framework.global_var._dygraph_tracer_ = None
        self._in_guard = True

    def __exit__(self, *args, **kwargs):
        """
        Exit from lazy mode and recover _dygraph_tracer_.
        """
        self.disable()
        if not self._in_guard:
            return
        assert self._tracer is not None
        framework.global_var._dygraph_tracer_ = self._tracer
        self._tracer = None
        self._in_guard = False

    @property
    def state(self):
        return self._state


_lazy_init_helper = LazyInitHelper()


def lazy_init_helper():
    global _lazy_init_helper
    return _lazy_init_helper


class LazyGuard:
    """
    LazyGuard is a wrapper interface for nn.Layer, it forwards the construct
    process of user defined Layer. Meanwhile, it provides necessary API to
    trigger EagerParamBase Lazy Initialization and get startup Program.

    Examples:

        .. code-block:: python

            >>> from paddle import LazyGuard
            >>> from paddle.nn import Linear

            >>> with LazyGuard():
            ...     # w and b are initialized lazily and have no memory.
            ...     net = Linear(10, 10)
            ...
            >>> for param in net.parameters():
            ...     # Initialize param and allocate memory explicitly.
            ...     param.initialize()
    """

    def __enter__(self) -> None:
        """
        Construct instance from class_obj by Lazy Initializing parameters.

        Examples:

            .. code-block:: python

                >>> from paddle import LazyGuard
                >>> from paddle.nn import Linear

                >>> with LazyGuard():
                ...     fc = LazyInit(Linear)(10, 10)
                ...
                >>> for param in fc.parameters():
                ...     param.initialize()
        """
        lazy_init_helper().enable()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        lazy_init_helper().disable()
