import subprocess

from packaging import version
import torch

# Probe for the optional Triton dependency. `triton_available` is True only
# when both `triton` and `triton.language` import without an ImportError.
try:
    import triton  # noqa: F401
    import triton.language as tl  # noqa: F401
except ImportError:
    triton_available = False
else:
    # Reached only when every import in the `try` body succeeded.
    triton_available = True


# Only the cpu/xpu backends use these lookup tables for now, so they live on
# the XPU device when one is available and on the CPU otherwise.
_QUANT_TABLE_DEVICE = (
    "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
)

# 16-entry float32 lookup table for the "nf4" format; entries ascend
# monotonically from -1.0 to 1.0 and include an exact 0.0 at index 7.
_NF4_QUANT_TABLE = torch.tensor(
    [
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ],
    dtype=torch.float32,
    device=_QUANT_TABLE_DEVICE,
)

# 16-entry float32 lookup table for the "fp4" format; entries 8-15 are the
# negations of entries 0-7 (index 8 is written as a plain 0.0).
_FP4_QUANT_TABLE = torch.tensor(
    [
        0.0000,
        0.0052,
        0.6667,
        1.0000,
        0.3333,
        0.5000,
        0.1667,
        0.2500,
        0.0000,
        -0.0052,
        -0.6667,
        -1.0000,
        -0.3333,
        -0.5000,
        -0.1667,
        -0.2500,
    ],
    dtype=torch.float32,
    device=_QUANT_TABLE_DEVICE,
)

# Lookup of quantization-type name -> its table.
CODE = dict(nf4=_NF4_QUANT_TABLE, fp4=_FP4_QUANT_TABLE)


def get_gaudi_sw_version():
    """
    Returns the installed version of Gaudi SW.

    The version is read from the metadata of the ``habana-torch-plugin``
    distribution installed for the running Python interpreter.

    Returns:
        A ``packaging.version.Version`` for ``habana-torch-plugin``, or
        ``None`` when that package is not installed.
    """
    # Query this interpreter's package metadata directly instead of shelling
    # out to `pip list | grep`:
    # - the `pip` on PATH may belong to a different environment;
    # - grep matches substrings of other package names;
    # - editable installs add a location column that `split()[-1]` would
    #   pick up instead of the version, making `version.parse` raise.
    from importlib import metadata

    # Normalization of "-" vs "_" in distribution names has varied across
    # importlib.metadata versions, so try both spellings.
    for dist_name in ("habana-torch-plugin", "habana_torch_plugin"):
        try:
            raw_version = metadata.version(dist_name)
        except metadata.PackageNotFoundError:
            continue
        return version.parse(raw_version)
    return None


# Resolved once at import time; None means the Gaudi SW plugin is not installed.
GAUDI_SW_VER = get_gaudi_sw_version()
