# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import time
from multiprocessing import Manager, Process

# deprecated module import
# (TODO: GhostScreaming) It will be removed later.
from paddle.base import core
from paddle.distributed.fleet.base.private_helper_function import (
    wait_server_ready,
)

# Nothing is exported via ``from ... import *``; the docstring examples below
# reach these helpers as ``paddle.distributed.<name>``.
__all__ = []

# Process-wide gloo context. Set by gloo_init_parallel_env(), read by
# gloo_barrier() and gloo_release(); None until initialization.
_global_gloo_ctx = None


def _start_kv_server(port, http_server_d, size):
    """Host a KVServer until it may stop and rank 0 no longer needs it.

    Used as the ``Process`` target that rank 0 starts in
    ``gloo_init_parallel_env``.

    Args:
        port: port for the KVServer; converted with ``int()``.
        http_server_d: shared dict (a Manager dict proxy in this file). The
            server keeps polling while ``http_server_d["running"]`` is truthy.
        size: per-scope worker counts forwarded to ``KVServer``
            (``{'_worker': rank_num}`` in this file).
    """
    # Imported inside the function; presumably so the HTTP server module is
    # only loaded in the process that actually hosts it.
    from paddle.distributed.fleet.utils.http_server import KVServer

    http_server = KVServer(int(port), size=size)
    http_server.start()
    wait_seconds = 3
    try:
        # Keep serving while rank 0 still flags the server as running, or
        # while the server itself does not yet report that it may stop.
        while http_server_d.get("running", False) or not http_server.should_stop():
            time.sleep(wait_seconds)
    finally:
        # Always shut the server down. This also covers the case where polling
        # the shared dict raises (e.g. the Manager process has exited) or the
        # sleep is interrupted; otherwise stop() would be skipped.
        http_server.stop()


def _init_global_gloo_ctx(
    rank_id: int,
    rank_num: int,
    server_endpoint: str,
    ip_address: str,
    ip_port: int,
) -> None:
    """Wait for the KV server, then create and init the module gloo context."""
    # all processes in this parallel environment should wait until server is ready
    wait_server_ready([server_endpoint])

    gloo_strategy = core.GlooParallelStrategy()
    gloo_strategy.rank = rank_id
    gloo_strategy.rank_num = rank_num
    gloo_strategy.ip_address = ip_address
    gloo_strategy.ip_port = ip_port
    # default_init_timeout_seconds
    gloo_strategy.init_seconds = 3600
    # default_run_timeout_seconds
    gloo_strategy.run_seconds = 9999999

    global _global_gloo_ctx
    _global_gloo_ctx = core.GlooParallelContext(gloo_strategy)
    _global_gloo_ctx.init()


def gloo_init_parallel_env(
    rank_id: int, rank_num: int, server_endpoint: str
) -> None:
    """
    Initialize parallel environment with gloo for cpu only.

    Rank 0 additionally starts a KV server process on the port of
    ``server_endpoint``. All ranks then wait for that server and initialize
    the process-wide gloo context used by ``gloo_barrier``/``gloo_release``.

    Args:
        - rank_id (int, required) - the index of current rank;
        - rank_num (int, required) - the number of ranks in this parallel env;
        - server_endpoint (str, required) - endpoint of server to init gloo context in ip:port format;

    Returns:
        None

    Raises:
        AssertionError: if ``rank_num`` is less than 2.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import multiprocessing
            >>> from contextlib import closing
            >>> import socket

            >>> port_set = set() # type: ignore

            >>> def find_free_port():
            ...     def _free_port():
            ...         with closing(socket.socket(socket.AF_INET,
            ...             socket.SOCK_STREAM)) as s:
            ...             s.bind(('', 0))
            ...             return s.getsockname()[1]
            ...     while True:
            ...         port = _free_port()
            ...         if port not in port_set:
            ...             port_set.add(port)
            ...             return port

            >>> def test_gloo_init(id, rank_num, server_endpoint):
            ...     paddle.distributed.gloo_init_parallel_env(
            ...         id, rank_num, server_endpoint)

            >>> def test_gloo_init_with_multiprocess(num_of_ranks):
            ...     jobs = []
            ...     server_endpoint = "127.0.0.1:%s" % (find_free_port())
            ...     for id in range(num_of_ranks):
            ...         p = multiprocessing.Process(
            ...             target=test_gloo_init,
            ...             args=(id, num_of_ranks, server_endpoint))
            ...         jobs.append(p)
            ...         p.start()
            ...     for proc in jobs:
            ...         proc.join()

            >>> if __name__ == '__main__':
            ...     # Arg: number of ranks (processes)
            ...     test_gloo_init_with_multiprocess(2)
    """

    # Compare directly instead of ``(rank_num < 2) is False``. The identity
    # check wrongly rejects integer-like types whose comparison returns a
    # non-singleton false value (e.g. numpy integers).
    assert rank_num >= 2, (
        "rank_num should greater than or equal to 2 for parallel environment initialization."
    )

    # Endpoint is "ip:port"; parse it once.
    endpoint_parts = server_endpoint.split(":")
    ip_address = endpoint_parts[0]
    ip_port = int(endpoint_parts[1])

    if rank_id != 0:
        _init_global_gloo_ctx(
            rank_id, rank_num, server_endpoint, ip_address, ip_port
        )
        return

    # Only rank 0 hosts the KV server, so only rank 0 needs a Manager (which
    # runs its own server process). The ``with`` shuts it down when done.
    with Manager() as manager:
        # global dict to store status
        http_server_status = manager.dict()
        # The scope for worker used by http server is '_worker'
        size = {'_worker': rank_num}
        http_server_proc = Process(
            target=_start_kv_server,
            args=(ip_port, http_server_status, size),
        )
        http_server_proc.daemon = True
        # Set before start() so the server loop sees it from its first poll.
        http_server_status["running"] = True
        http_server_proc.start()

        _init_global_gloo_ctx(
            rank_id, rank_num, server_endpoint, ip_address, ip_port
        )

        # Clear the flag so the server process can exit once the server
        # reports it may stop, then wait for it.
        http_server_status["running"] = False
        http_server_proc.join()


def gloo_barrier() -> None:
    """
    Synchronize ranks by calling ``barrier()`` on the initialized gloo context.

    ``gloo_init_parallel_env`` must have been called in this process first;
    otherwise an ``AssertionError`` ("gloo context is not initialized.") is
    raised.

    Args:
        None

    Returns:
        None

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import multiprocessing
            >>> from contextlib import closing
            >>> import socket

            >>> port_set = set() # type: ignore

            >>> def find_free_port():
            ...     def _free_port():
            ...         with closing(socket.socket(socket.AF_INET,
            ...             socket.SOCK_STREAM)) as s:
            ...             s.bind(('', 0))
            ...             return s.getsockname()[1]
            ...     while True:
            ...         port = _free_port()
            ...         if port not in port_set:
            ...             port_set.add(port)
            ...             return port

            >>> def test_gloo_barrier(id, rank_num, server_endpoint):
            ...     paddle.distributed.gloo_init_parallel_env(
            ...         id, rank_num, server_endpoint)
            ...     paddle.distributed.gloo_barrier()

            >>> def test_gloo_barrier_with_multiprocess(num_of_ranks):
            ...     jobs = []
            ...     server_endpoint = "127.0.0.1:%s" % (find_free_port())
            ...     for id in range(num_of_ranks):
            ...         p = multiprocessing.Process(
            ...             target=test_gloo_barrier,
            ...             args=(id, num_of_ranks, server_endpoint))
            ...         jobs.append(p)
            ...         p.start()
            ...     for proc in jobs:
            ...         proc.join()

            >>> if __name__ == '__main__':
            ...     # Arg: number of ranks (processes)
            ...     test_gloo_barrier_with_multiprocess(2)
    """
    ctx = _global_gloo_ctx
    assert ctx is not None, "gloo context is not initialized."
    ctx.barrier()


def gloo_release() -> None:
    """
    Release the parallel environment initialized by gloo.

    Does nothing if no gloo context is initialized. After a successful release
    the module-level context is cleared. A later ``gloo_barrier`` therefore
    reports "gloo context is not initialized." instead of using a released
    context, and a second ``gloo_release`` is a no-op.

    Args:
        None

    Returns:
        None

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import multiprocessing
            >>> from contextlib import closing
            >>> import socket

            >>> port_set = set() # type: ignore

            >>> def find_free_port():
            ...     def _free_port():
            ...         with closing(socket.socket(socket.AF_INET,
            ...             socket.SOCK_STREAM)) as s:
            ...             s.bind(('', 0))
            ...             return s.getsockname()[1]
            ...     while True:
            ...         port = _free_port()
            ...         if port not in port_set:
            ...             port_set.add(port)
            ...             return port

            >>> def test_gloo_release(id, rank_num, server_endpoint):
            ...     paddle.distributed.gloo_init_parallel_env(
            ...         id, rank_num, server_endpoint)
            ...     paddle.distributed.gloo_barrier()
            ...     paddle.distributed.gloo_release()

            >>> def test_gloo_release_with_multiprocess(num_of_ranks):
            ...     jobs = []
            ...     server_endpoint = "127.0.0.1:%s" % (find_free_port())
            ...     for id in range(num_of_ranks):
            ...         p = multiprocessing.Process(
            ...             target=test_gloo_release,
            ...             args=(id, num_of_ranks, server_endpoint))
            ...         jobs.append(p)
            ...         p.start()
            ...     for proc in jobs:
            ...         proc.join()

            >>> if __name__ == '__main__':
            ...     # Arg: number of ranks (processes)
            ...     test_gloo_release_with_multiprocess(2)
    """
    global _global_gloo_ctx
    ctx = _global_gloo_ctx
    if ctx is None:
        return
    ctx.release()
    # Drop the reference only after release() succeeds. Callers then cannot
    # reach a released context through this module.
    _global_gloo_ctx = None
