#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of device workers."""

import sys

__all__ = []


class DeviceWorker:
    """
    DeviceWorker is an abstract class that generates a worker desc.
    It is an internal class in which the computation logic, such as the
    execution of a program or a graph, is implemented.
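
    A subclass is expected to override ``_gen_worker_desc``. A minimal
    illustrative sketch (``MyWorker`` is not part of this module):

    .. code-block:: python

        class MyWorker(DeviceWorker):
            def _gen_worker_desc(self, trainer_desc):
                trainer_desc.device_worker_name = "MyWorker"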
    """

    def __init__(self):
        """Init."""
        self._program = None
        self._infer = None

    def _set_infer(self, infer=False):
        """
        Set the inference flag for the current device worker.

        Args:
            infer(bool): whether to do inference
        """
        self._infer = infer

    def _set_fleet_desc(self, fleet_desc):
        """
        Set fleet desc.

        Args:
            fleet_desc(PSParameter): pslib.PSParameter object
        """
        self._fleet_desc = fleet_desc

    def _set_program(self, program):
        """
        Set program.

        Args:
            program(Program): a Program object
        """
        self._program = program

    def _gen_worker_desc(self, trainer_desc):
        """
        Generate the worker desc.

        Args:
            trainer_desc(TrainerDesc): a TrainerDesc object
        """
        raise NotImplementedError(
            "DeviceWorker does not implement _gen_worker_desc, "
            "please use a concrete subclass such as Hogwild or DownpourSGD."
        )


class Hogwild(DeviceWorker):
    """
    Hogwild is a lock-free, asynchronous SGD algorithm.
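
    ``_gen_worker_desc`` reads the fleet configuration attached to the
    program. A minimal sketch of the keys it consumes (the values below are
    illustrative placeholders, not a working configuration):

    .. code-block:: python

        program._fleet_opt = {
            "stat_var_names": [],
            # keyed by str(id(program)); each entry lists table ids
            "program_configs": {
                str(id(program)): {
                    "push_sparse": [],
                    "push_dense": [],
                    "pull_sparse": [],
                    "pull_dense": [],
                }
            },
            # one of "program_id_to_worker" (a pslib worker desc per program)
            # or "dense_table_config" (table id -> dense variable names)
            "dense_table_config": {},
        }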

    """

    def __init__(self):
        """Init."""
        super().__init__()

    def _gen_worker_desc(self, trainer_desc):
        """
        Generate the worker desc; the device worker is HogwildWorker.

        Args:
            trainer_desc(TrainerDesc): a TrainerDesc object
        """
        trainer_desc.device_worker_name = "HogwildWorker"
        if self._infer:
            # skip feed and push/send ops when running the inference model
            trainer_desc.hogwild_param.skip_ops.extend(
                [
                    "feed",
                    "push_sparse_v2",
                    "push_dense",
                    "distributed_push_sparse",
                    "send",
                ]
            )

        dense_table_set = set()
        if self._program is None:
            print("program of current device worker is not configured")
            sys.exit(-1)
        program_id = str(id(self._program))
        print("device worker program id:", program_id)
        opt_info = self._program._fleet_opt
        # return early when opt_info is None or an empty dict
        if not opt_info:
            return
        downpour = trainer_desc.downpour_param
        hogwild = trainer_desc.hogwild_param
        if opt_info["stat_var_names"]:
            for i in opt_info["stat_var_names"]:
                hogwild.stat_var_names.extend([i])
                downpour.stat_var_names.extend([i])

        from paddle.incubate.distributed.fleet.parameter_server import version

        if (
            version.is_transpiler()
            and "fleet_desc" not in opt_info
            and "program_configs" not in opt_info
        ):
            return

        program_configs = opt_info["program_configs"]
        print("device worker program_configs:", program_configs)

        for pid in program_configs:
            print("device worker", pid, program_id)
            if pid == program_id:
                pc = downpour.program_config.add()
                pc.program_id = program_id
                print(
                    "device worker pull dense:",
                    program_configs[program_id]["pull_dense"],
                )
                for i in program_configs[program_id]["push_sparse"]:
                    pc.push_sparse_table_id.extend([i])
                for i in program_configs[program_id]["push_dense"]:
                    pc.push_dense_table_id.extend([i])
                    dense_table_set.add(i)
                for i in program_configs[program_id]["pull_sparse"]:
                    pc.pull_sparse_table_id.extend([i])
                for i in program_configs[program_id]["pull_dense"]:
                    pc.pull_dense_table_id.extend([i])
                    dense_table_set.add(i)
                break

        trainer_desc.device_worker_name = "HogwildWorker"
        pull_thread = trainer_desc.pull_dense_param
        pull_thread.device_num = trainer_desc.thread_num
        if (
            opt_info.get("program_id_to_worker") is None
            and opt_info.get("dense_table_config") is None
        ):
            raise ValueError(
                "opt_info must have program_id_to_worker or dense_table_config"
            )
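        # Dense tables can be described in two ways: a full per-program worker
        # desc from pslib ("program_id_to_worker"), or a plain mapping from
        # table id to dense variable names ("dense_table_config").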
        if opt_info.get("program_id_to_worker") is not None:
            prog_id_to_worker = opt_info["program_id_to_worker"]
            if prog_id_to_worker.get(program_id) is None:
                raise ValueError(
                    f"{program_id} not found in program_id_to_worker"
                )
            worker = opt_info["program_id_to_worker"][program_id]
            for i in worker.get_desc().dense_table:
                if i.table_id in dense_table_set:
                    dense_table = pull_thread.dense_table.add()
                    dense_table.dense_value_name.extend(i.dense_variable_name)
                    dense_table.table_id = i.table_id
            sparse_len = len(worker.get_desc().sparse_table)
            for i in range(sparse_len):
                sparse_table = downpour.sparse_table.add()
                sparse_table.table_id = (
                    worker.get_desc().sparse_table[i].table_id
                )
                sparse_table.sparse_key_name.extend(
                    worker.get_desc().sparse_table[i].slot_key
                )
                sparse_table.sparse_value_name.extend(
                    worker.get_desc().sparse_table[i].slot_value
                )
                sparse_table.sparse_grad_name.extend(
                    worker.get_desc().sparse_table[i].slot_gradient
                )
                sparse_table.fea_dim = self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
                    i
                ].accessor.fea_dim
                # emb_dim is not used here
                sparse_table.emb_dim = -1
                # do not use the hard-coded click label
                sparse_table.label_var_name = ""

            for i in worker.get_desc().dense_table:
                if i.table_id in dense_table_set:
                    dense_table = downpour.dense_table.add()
                    dense_table.table_id = i.table_id
                    dense_table.dense_value_name.extend(i.dense_variable_name)
                    dense_table.dense_grad_name.extend(
                        i.dense_gradient_variable_name
                    )
            hogwild.skip_ops.extend(worker.get_desc().skip_op)
        else:
            dense_table_config = opt_info.get("dense_table_config")
            print("device worker dense_table_config:", dense_table_config)
            for table_id, varnames in dense_table_config.items():
                dense_table = pull_thread.dense_table.add()
                dense_table.dense_value_name.extend(varnames)
                dense_table.table_id = table_id

        if self._infer:
            hogwild.skip_ops.extend(
                ["push_sparse", "push_sparse_v2", "push_dense"]
            )


class DownpourLite(DeviceWorker):
    """
    DownpourLite is a kind of distributed SGD algorithm; it uses the
    DownpourLiteWorker device worker.
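
    It consumes the same ``_fleet_opt`` keys as :class:`Hogwild`, but records
    skip ops and statistic variable names on ``downpour_param`` rather than
    ``hogwild_param``.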

    """

    def __init__(self):
        """Init."""
        super().__init__()

    def _gen_worker_desc(self, trainer_desc):
        """
        Generate the worker desc; the device worker is DownpourLiteWorker.

        Args:
            trainer_desc(TrainerDesc): a TrainerDesc object
        """
        print("create DownpourLiteWorker")
        trainer_desc.device_worker_name = "DownpourLiteWorker"
        if self._infer:
            # skip feed and push/send ops when running the inference model
            trainer_desc.downpour_param.skip_ops.extend(
                [
                    "feed",
                    "push_sparse",
                    "push_sparse_v2",
                    "push_dense",
                    "distributed_push_sparse",
                    "send",
                ]
            )

        dense_table_set = set()
        if self._program is None:
            print("program of current device worker is not configured")
            sys.exit(-1)
        program_id = str(id(self._program))
        print("device worker program id:", program_id)
        opt_info = self._program._fleet_opt
        # return early when opt_info is None or an empty dict
        if not opt_info:
            return
        downpour = trainer_desc.downpour_param
        if opt_info["stat_var_names"]:
            for i in opt_info["stat_var_names"]:
                downpour.stat_var_names.extend([i])

        from paddle.incubate.distributed.fleet.parameter_server import version

        if (
            version.is_transpiler()
            and "fleet_desc" not in opt_info
            and "program_configs" not in opt_info
        ):
            return

        program_configs = opt_info["program_configs"]
        print("device worker program_configs:", program_configs)

        for pid in program_configs:
            print("device worker", pid, program_id)
            if pid == program_id:
                pc = downpour.program_config.add()
                pc.program_id = program_id
                print(
                    "device worker pull dense:",
                    program_configs[program_id]["pull_dense"],
                )
                for i in program_configs[program_id]["push_sparse"]:
                    pc.push_sparse_table_id.extend([i])
                for i in program_configs[program_id]["push_dense"]:
                    pc.push_dense_table_id.extend([i])
                    dense_table_set.add(i)
                for i in program_configs[program_id]["pull_sparse"]:
                    pc.pull_sparse_table_id.extend([i])
                for i in program_configs[program_id]["pull_dense"]:
                    pc.pull_dense_table_id.extend([i])
                    dense_table_set.add(i)
                break

        pull_thread = trainer_desc.pull_dense_param
        pull_thread.device_num = trainer_desc.thread_num
        if (
            opt_info.get("program_id_to_worker") is None
            and opt_info.get("dense_table_config") is None
        ):
            raise ValueError(
                "opt_info must have program_id_to_worker or dense_table_config"
            )
        if opt_info.get("program_id_to_worker") is not None:
            prog_id_to_worker = opt_info["program_id_to_worker"]
            if prog_id_to_worker.get(program_id) is None:
                raise ValueError(
                    f"{program_id} not found in program_id_to_worker"
                )
            worker = opt_info["program_id_to_worker"][program_id]
            for i in worker.get_desc().dense_table:
                if i.table_id in dense_table_set:
                    dense_table = pull_thread.dense_table.add()
                    dense_table.dense_value_name.extend(i.dense_variable_name)
                    dense_table.table_id = i.table_id
            sparse_len = len(worker.get_desc().sparse_table)
            for i in range(sparse_len):
                sparse_table = downpour.sparse_table.add()
                sparse_table.table_id = (
                    worker.get_desc().sparse_table[i].table_id
                )
                sparse_table.sparse_key_name.extend(
                    worker.get_desc().sparse_table[i].slot_key
                )
                sparse_table.sparse_value_name.extend(
                    worker.get_desc().sparse_table[i].slot_value
                )
                sparse_table.sparse_grad_name.extend(
                    worker.get_desc().sparse_table[i].slot_gradient
                )
                sparse_table.fea_dim = self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
                    i
                ].accessor.fea_dim
                # emb_dim is not used here
                sparse_table.emb_dim = -1
                # do not use the hard-coded click label
                sparse_table.label_var_name = ""

            for i in worker.get_desc().dense_table:
                if i.table_id in dense_table_set:
                    dense_table = downpour.dense_table.add()
                    dense_table.table_id = i.table_id
                    dense_table.dense_value_name.extend(i.dense_variable_name)
                    dense_table.dense_grad_name.extend(
                        i.dense_gradient_variable_name
                    )
            downpour.skip_ops.extend(worker.get_desc().skip_op)
        else:
            dense_table_config = opt_info.get("dense_table_config")
            print("device worker dense_table_config:", dense_table_config)
            for table_id, varnames in dense_table_config.items():
                dense_table = pull_thread.dense_table.add()
                dense_table.dense_value_name.extend(varnames)
                dense_table.table_id = table_id

        if self._infer:
            downpour.skip_ops.extend(
                ["push_sparse", "push_sparse_v2", "push_dense"]
            )


class DownpourSGD(DeviceWorker):
    """
    DownpourSGD is a distributed SGD algorithm based on the Downpour
    parameter-server architecture.
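
    ``_gen_worker_desc`` expects ``program._fleet_opt`` to provide
    ``program_configs``, ``program_id_to_worker``, ``use_cvm`` and
    ``stat_var_names`` (``worker_class`` and ``no_cvm`` are optional), and it
    requires ``_set_fleet_desc`` to have been called so that the table
    accessor dimensions can be read from the fleet desc.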
    """

    def __init__(self):
        """
        Init.
        initialize downpourSGD device worker
        """
        super().__init__()

    def _gen_worker_desc(self, trainer_desc):
        """
        Generate the worker desc; the device worker is DownpourWorker.

        Args:
            trainer_desc(TrainerDesc): a TrainerDesc object
        """
        dense_table_set = set()
        if self._program is None:
            print("program of current device worker is not configured")
            sys.exit(-1)
        program_id = str(id(self._program))
        opt_info = self._program._fleet_opt
        program_configs = opt_info["program_configs"]
        downpour = trainer_desc.downpour_param

        for pid in program_configs:
            if pid == program_id:
                pc = downpour.program_config.add()
                pc.program_id = program_id
                for i in program_configs[program_id]["push_sparse"]:
                    pc.push_sparse_table_id.extend([i])
                for i in program_configs[program_id]["push_dense"]:
                    pc.push_dense_table_id.extend([i])
                    dense_table_set.add(i)
                for i in program_configs[program_id]["pull_sparse"]:
                    pc.pull_sparse_table_id.extend([i])
                for i in program_configs[program_id]["pull_dense"]:
                    pc.pull_dense_table_id.extend([i])
                    dense_table_set.add(i)
                # handle partial push of dense tables, e.g. for multitask training
                if "cond2denseid" in program_configs[program_id]:
                    cond2denseid = program_configs[program_id]["cond2denseid"]
                    for key, value in cond2denseid.items():
                        mc_map = pc.partial_pushdense_condtable_map.add()
                        mc_map.key = key
                        mc_map.value = value
                break

        trainer_desc.device_worker_name = opt_info.get(
            "worker_class", "DownpourWorker"
        )
        pull_thread = trainer_desc.pull_dense_param
        pull_thread.device_num = trainer_desc.thread_num
        if opt_info.get("program_id_to_worker") is None:
            raise ValueError("opt_info must have program_id_to_worker")
        prog_id_to_worker = opt_info["program_id_to_worker"]
        if prog_id_to_worker.get(program_id) is None:
            raise ValueError(f"{program_id} not found in program_id_to_worker")
        worker = opt_info["program_id_to_worker"][program_id]
        for i in worker.get_desc().dense_table:
            if i.table_id in dense_table_set:
                dense_table = pull_thread.dense_table.add()
                dense_table.dense_value_name.extend(i.dense_variable_name)
                dense_table.table_id = i.table_id
        sparse_len = len(worker.get_desc().sparse_table)
        for i in range(sparse_len):
            sparse_table = downpour.sparse_table.add()
            sparse_table.table_id = worker.get_desc().sparse_table[i].table_id
            sparse_table.sparse_key_name.extend(
                worker.get_desc().sparse_table[i].slot_key
            )
            sparse_table.sparse_value_name.extend(
                worker.get_desc().sparse_table[i].slot_value
            )
            sparse_table.sparse_grad_name.extend(
                worker.get_desc().sparse_table[i].slot_gradient
            )
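            # With CVM enabled (or explicitly disabled via "no_cvm"), the
            # embedding width equals the accessor's fea_dim; otherwise two
            # accessor columns are excluded, so emb_dim = fea_dim - 2.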
            if opt_info["use_cvm"] or (
                "no_cvm" in opt_info and opt_info["no_cvm"] is True
            ):
                sparse_table.emb_dim = self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
                    i
                ].accessor.fea_dim
                sparse_table.fea_dim = sparse_table.emb_dim
            else:
                sparse_table.emb_dim = (
                    self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
                        i
                    ].accessor.fea_dim
                    - 2
                )
                sparse_table.fea_dim = sparse_table.emb_dim + 2
            # TODO(guru4elephant): the label name is hard-coded here and needs improvement
            sparse_table.label_var_name = "click"
        if opt_info["stat_var_names"]:
            for i in opt_info["stat_var_names"]:
                downpour.stat_var_names.extend([i])

        for i in worker.get_desc().dense_table:
            if i.table_id in dense_table_set:
                dense_table = downpour.dense_table.add()
                dense_table.table_id = i.table_id
                dense_table.dense_value_name.extend(i.dense_variable_name)
                dense_table.dense_grad_name.extend(
                    i.dense_gradient_variable_name
                )
        downpour.skip_ops.extend(worker.get_desc().skip_op)
        if self._infer:
            downpour.push_dense = False
            downpour.push_sparse = False


class DownpourSGDOPT(DeviceWorker):
    """
    DownpourSGDOPT is a distributed SGD algorithm that uses the optimized
    DownpourWorkerOpt device worker.
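
    It reads the same ``_fleet_opt`` keys as :class:`DownpourSGD`, plus the
    optional ``local_tables`` and ``async_tables`` entries, which mark sparse
    tables as local or asynchronous.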
    """

    def __init__(self):
        """
        Init.
        initialize downpourSGDOPT device worker
        """
        super().__init__()

    def _gen_worker_desc(self, trainer_desc):
        """
        Generate the worker desc; the device worker is DownpourWorkerOpt.

        Args:
            trainer_desc(TrainerDesc): a TrainerDesc object
        """
        dense_table_set = set()
        if self._program is None:
            print("program of current device worker is not configured")
            sys.exit(-1)
        program_id = str(id(self._program))
        opt_info = self._program._fleet_opt
        program_configs = opt_info["program_configs"]
        downpour = trainer_desc.downpour_param

        for pid in program_configs:
            if pid == program_id:
                pc = downpour.program_config.add()
                pc.program_id = program_id
                for i in program_configs[program_id]["push_sparse"]:
                    pc.push_sparse_table_id.extend([i])
                for i in program_configs[program_id]["push_dense"]:
                    pc.push_dense_table_id.extend([i])
                    dense_table_set.add(i)
                for i in program_configs[program_id]["pull_sparse"]:
                    pc.pull_sparse_table_id.extend([i])
                for i in program_configs[program_id]["pull_dense"]:
                    pc.pull_dense_table_id.extend([i])
                    dense_table_set.add(i)
                break

        trainer_desc.device_worker_name = "DownpourWorkerOpt"
        pull_thread = trainer_desc.pull_dense_param
        pull_thread.device_num = trainer_desc.thread_num
        if opt_info.get("program_id_to_worker") is None:
            raise ValueError("opt_info must have program_id_to_worker")
        prog_id_to_worker = opt_info["program_id_to_worker"]
        if prog_id_to_worker.get(program_id) is None:
            raise ValueError(f"{program_id} not found in program_id_to_worker")
        worker = opt_info["program_id_to_worker"][program_id]
        for i in worker.get_desc().dense_table:
            if i.table_id in dense_table_set:
                dense_table = pull_thread.dense_table.add()
                dense_table.dense_value_name.extend(i.dense_variable_name)
                dense_table.table_id = i.table_id
        sparse_len = len(worker.get_desc().sparse_table)
        for i in range(sparse_len):
            sparse_table = downpour.sparse_table.add()
            sparse_table.table_id = worker.get_desc().sparse_table[i].table_id
            sparse_table.sparse_key_name.extend(
                worker.get_desc().sparse_table[i].slot_key
            )
            sparse_table.sparse_value_name.extend(
                worker.get_desc().sparse_table[i].slot_value
            )
            sparse_table.sparse_grad_name.extend(
                worker.get_desc().sparse_table[i].slot_gradient
            )
            if opt_info["use_cvm"] or (
                "no_cvm" in opt_info and opt_info["no_cvm"] is True
            ):
                sparse_table.emb_dim = self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
                    i
                ].accessor.fea_dim
                sparse_table.fea_dim = sparse_table.emb_dim
            else:
                sparse_table.emb_dim = (
                    self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
                        i
                    ].accessor.fea_dim
                    - 2
                )
                sparse_table.fea_dim = sparse_table.emb_dim + 2
            # TODO(guru4elephant): the label name is hard-coded here and needs improvement
            sparse_table.label_var_name = "click"
        if (
            "local_tables" in opt_info
            and sparse_table.table_id in opt_info["local_tables"]
        ):
            sparse_table.is_local = True
        if (
            "async_tables" in opt_info
            and sparse_table.table_id in opt_info["async_tables"]
        ):
            sparse_table.is_async = True
        if opt_info["stat_var_names"]:
            for i in opt_info["stat_var_names"]:
                downpour.stat_var_names.extend([i])

        for i in worker.get_desc().dense_table:
            if i.table_id in dense_table_set:
                dense_table = downpour.dense_table.add()
                dense_table.table_id = i.table_id
                dense_table.dense_value_name.extend(i.dense_variable_name)
                dense_table.dense_grad_name.extend(
                    i.dense_gradient_variable_name
                )
        downpour.skip_ops.extend(worker.get_desc().skip_op)
        if self._infer:
            downpour.push_dense = False
            downpour.push_sparse = False


class Section(DeviceWorker):
    """SectionWorker."""

    def __init__(self):
        """Init."""
        super().__init__()

    def _gen_worker_desc(self, trainer_desc):
        """
        Generate the worker desc; the device worker is SectionWorker.

        Args:
            trainer_desc(TrainerDesc): a TrainerDesc object
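
        The keys of ``program._pipeline_opt`` read here are
        ``num_microbatches``, ``start_cpu_core_id``, ``pipeline_stage``,
        ``num_pipeline_stages``, ``schedule_mode``, ``section_program``,
        ``place`` and ``place_id``.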
        """
        from . import core

        trainer_desc.device_worker_name = "SectionWorker"
        pipeline_opt = self._program._pipeline_opt
        section_param = trainer_desc.section_param
        section_param.num_microbatches = pipeline_opt["num_microbatches"]
        section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
        section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
        section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
        schedule_mode_str = pipeline_opt["schedule_mode"]
        # The F-then-B scheduler runs the forward phase for all microbatches,
        # then runs the backward phase for all microbatches.
        # The 1F1B scheduler alternates forward and backward phases after the
        # startup phase.
        assert schedule_mode_str in [
            "F-then-B",
            "1F1B",
        ], "The schedule mode for pipeline must be one of F-then-B or 1F1B"
        schedule_mode = 0 if schedule_mode_str == "F-then-B" else 1
        section_param.schedule_mode = schedule_mode
        cfg = section_param.section_config
        program = pipeline_opt["section_program"]
        cfg.program_desc.ParseFromString(
            program._get_desc().serialize_to_string()
        )
        # TODO: figure out why the following CopyFrom call does not work
        # cfg.program_desc.CopyFrom(program.program._get_desc())
        place = pipeline_opt["place"]
        place_id = pipeline_opt["place_id"]
        if core.is_compiled_with_cuda():
            assert isinstance(place, core.CUDAPlace)
        cfg.place = cfg.CUDAPlace
        cfg.place_id = place_id


class HeterSection(DeviceWorker):
    """HeterSectionWorker."""

    def __init__(self):
        """Init."""
        super().__init__()

    def _gen_worker_desc(self, trainer_desc):
        """
        Generate the worker desc; the device worker is HeterSectionWorker.

        Args:
            trainer_desc(TrainerDesc): a TrainerDesc object
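
        The keys of ``program._heter_pipeline_opt`` read here are
        ``num_microbatches``, ``pipeline_stage``, ``num_pipeline_stages`` and
        ``section_program``.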
        """

        trainer_desc.device_worker_name = "HeterSectionWorker"
        heter_pipeline_opt = self._program._heter_pipeline_opt
        heter_section_param = trainer_desc.heter_section_param
        heter_section_param.num_microbatches = heter_pipeline_opt[
            "num_microbatches"
        ]
        heter_section_param.pipeline_stage = heter_pipeline_opt[
            "pipeline_stage"
        ]
        heter_section_param.num_pipeline_stages = heter_pipeline_opt[
            "num_pipeline_stages"
        ]
        cfg = heter_section_param.section_config
        program = heter_pipeline_opt["section_program"]
        cfg.program_desc.ParseFromString(
            program._get_desc().serialize_to_string()
        )


class DeviceWorkerFactory:
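    """
    Create a device worker instance by its type name.

    The name is capitalized and looked up in this module's globals, so
    ``"hogwild"`` resolves to :class:`Hogwild`. A minimal usage sketch
    (``program`` and ``trainer_desc`` are assumed to already exist and to be
    a fleet-configured Program and a TrainerDesc, respectively):

    .. code-block:: python

        worker = DeviceWorkerFactory()._create_device_worker("hogwild")
        worker._set_program(program)
        worker._set_infer(False)
        worker._gen_worker_desc(trainer_desc)
    """
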
    def _create_device_worker(self, worker_type):
        classname = worker_type.capitalize()
        return globals()[classname]()
