""" Sin-cos, fourier, rotary position embedding modules and functions

Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import List, Tuple, Optional, Union

import torch
from torch import nn as nn

from ._fx import register_notrace_function
from .grid import ndgrid
from .trace_utils import _assert


def pixel_freq_bands(
        num_bands: int,
        max_freq: float = 224.,
        linear_bands: bool = True,
        device: Optional[torch.device] = None,
):
    """ Build pi-scaled frequency bands for pixel-space Fourier features.

    Args:
        num_bands: Number of bands to generate.
        max_freq: Upper frequency bound, the largest band is ``max_freq / 2`` (before pi scaling).
        linear_bands: Space bands linearly over [1, max_freq / 2] if True, otherwise use powers
            of two with exponents evenly spaced over [0, log2(max_freq) - 1].
        device: Device the bands are created on.

    Returns:
        float32 tensor with ``num_bands`` entries, each multiplied by pi.
    """
    if not linear_bands:
        # log2-spaced: 2 ** [0 .. log2(max_freq) - 1]
        top_exponent = math.log(max_freq, 2) - 1
        exponents = torch.linspace(0, top_exponent, num_bands, dtype=torch.float32, device=device)
        return (2 ** exponents) * torch.pi

    top_freq = max_freq / 2
    linear = torch.linspace(1.0, top_freq, num_bands, dtype=torch.float32, device=device)
    return linear * torch.pi


def freq_bands(
        num_bands: int,
        temperature: float = 10000.,
        step: int = 2,
        device: Optional[torch.device] = None,
) -> torch.Tensor:
    """ Build inverse-frequency bands ``1 / temperature ** (i / num_bands)``.

    Args:
        num_bands: Normaliser for the exponent and exclusive upper bound of the index range.
        temperature: Base of the exponential.
        step: Stride of the index range ``0, step, 2 * step, ... < num_bands``.
        device: Device the bands are created on.

    Returns:
        float32 tensor with one entry per index in the range.
    """
    # indices are generated as int64 and then cast, matching the rest of this file
    indices = torch.arange(0, num_bands, step, dtype=torch.int64, device=device)
    fraction = indices.to(torch.float32) / num_bands
    return 1. / (temperature ** fraction)


def build_sincos2d_pos_embed(
        feat_shape: List[int],
        dim: int = 64,
        temperature: float = 10000.,
        reverse_coord: bool = False,
        interleave_sin_cos: bool = False,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """ Build a flattened sin-cos position embedding over a feature grid.

    Args:
        feat_shape: Size of each grid axis.
        dim: Embedding dimension, must be divisible by 4.
        temperature: Temperature passed to `freq_bands`.
        reverse_coord: stack grid order W, H instead of H, W
        interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
        device: Device for the intermediate and output tensors.
        dtype: Output dtype (the embedding is computed in float32 and cast at the end).

    Returns:
        Tensor with one row per grid position; all remaining dims are flattened together.
    """
    assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
    bands = freq_bands(dim // 4, temperature=temperature, step=1, device=device)

    # stack W, H instead of H, W when requested
    axis_sizes = feat_shape[::-1] if reverse_coord else feat_shape
    axes = [
        torch.arange(size, device=device, dtype=torch.int64).to(torch.float32)
        for size in axis_sizes
    ]
    # one row of coordinates per grid position
    coords = torch.stack(ndgrid(axes)).flatten(1).transpose(0, 1)
    angles = coords.unsqueeze(-1) * bands.unsqueeze(0)
    # FIXME add support for unflattened spatial dim?

    # stacking on dim 2 interleaves sin / cos, stacking on dim 1 keeps them grouped
    if interleave_sin_cos:
        stack_dim = 2
    else:
        stack_dim = 1
    embed = torch.stack([torch.sin(angles), torch.cos(angles)], dim=stack_dim)
    return embed.flatten(1).to(dtype=dtype)


def swap_shape_xy(seq: List[int]) -> List[int]:
    """ Swap the first two entries of a shape, leaving any trailing entries in place.

    Sequences with fewer than two entries are returned as-is (same object).
    """
    if len(seq) >= 2:
        head_swapped = [seq[1], seq[0]]
        return head_swapped + list(seq[2:])
    return seq


def build_fourier_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        num_bands: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        include_grid: bool = False,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        grid_offset: float = 0.,
        grid_indexing: str = 'ij',
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
) -> List[torch.Tensor]:
    """ Build sin / cos Fourier features over a coordinate grid.

    Args:
        feat_shape: Size of each grid axis.
        bands: Pre-computed frequency bands. When given, `num_bands`, `max_res`, `temperature`
            and `linear_bands` are not used, and `device` / `dtype` fall back to the bands' own.
        num_bands: Number of bands to generate when `bands` is None.
        max_res: Max resolution handed to `pixel_freq_bands` in pixel mode.
        temperature: Temperature handed to `freq_bands` in non-pixel mode.
        linear_bands: Linear band spacing in pixel mode.
        include_grid: Prepend the coordinate grid itself to the returned list.
        in_pixels: Pixel mode uses coordinates linearly spaced in [-1, 1]; otherwise integer
            positions (plus `grid_offset`) are used.
        ref_feat_shape: Reference shape, coordinates are rescaled by ``ref / feat`` per axis.
        grid_offset: Constant added to coordinates in non-pixel mode.
        grid_indexing: Indexing mode for meshgrid ('ij' or 'xy').
        device: Device for created tensors.
        dtype: Dtype of the returned sin / cos tensors (the grid stays float32).

    Returns:
        ``[sin, cos]`` or ``[grid, sin, cos]``. sin / cos have shape
        (*grid_shape, num_axes, num_bands), the grid has shape (*grid_shape, num_axes, 1).
    """
    if bands is not None:
        # bands supplied by the caller, inherit placement from them where unspecified
        if device is None:
            device = bands.device
        if dtype is None:
            dtype = bands.dtype
    elif in_pixels:
        bands = pixel_freq_bands(
            num_bands,
            float(max_res),
            linear_bands=linear_bands,
            device=device,
        )
    else:
        bands = freq_bands(
            num_bands,
            temperature=temperature,
            step=1,
            device=device,
        )

    if grid_indexing == 'xy':
        # 'xy' meshgrid swaps the first two output dims, so swap the shapes going in
        feat_shape = swap_shape_xy(feat_shape)
        if ref_feat_shape is not None:
            ref_feat_shape = swap_shape_xy(ref_feat_shape)

    if in_pixels:
        axes = [
            torch.linspace(-1., 1., steps=size, device=device, dtype=torch.float32)
            for size in feat_shape
        ]
    else:
        axes = [
            torch.arange(size, device=device, dtype=torch.int64).to(torch.float32) + grid_offset
            for size in feat_shape
        ]

    if ref_feat_shape is not None:
        # eva's scheme for resizing rope embeddings (ref shape = pretrain)
        axes = [axis / cur * ref for axis, cur, ref in zip(axes, feat_shape, ref_feat_shape)]

    # trailing singleton dim broadcasts every coordinate against all bands
    grid = torch.stack(torch.meshgrid(axes, indexing=grid_indexing), dim=-1).unsqueeze(-1)
    angles = grid * bands

    pos_sin = angles.sin().to(dtype=dtype)
    pos_cos = angles.cos().to(dtype=dtype)
    if include_grid:
        return [grid, pos_sin, pos_cos]
    return [pos_sin, pos_cos]


class FourierEmbed(nn.Module):
    """ Concatenate pixel-frequency Fourier position features onto a channel-first input.

    The frequency bands are computed once at construction and stored in a non-persistent
    `bands` buffer; the sin / cos features are rebuilt from them on every forward call for
    the spatial shape of the input.
    """

    def __init__(
            self,
            max_res: int = 224,
            num_bands: int = 64,
            concat_grid=True,
            keep_spatial=False,
            device=None,
            dtype=None,
    ):
        """
        Args:
            max_res: Maximum resolution, used as the max frequency of the bands.
            num_bands: Number of frequency bands.
            concat_grid: Include the coordinate grid itself in the position features.
            keep_spatial: Return channel-first output with spatial dims intact instead of a
                flattened (batch, positions, channels) tensor.
            device: Device for the bands buffer.
            dtype: Dtype for the bands buffer.
        """
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        # pixel_freq_bands() takes (num_bands, max_freq) in that order; keywords are used so
        # the band count and the max frequency cannot be transposed.
        bands = pixel_freq_bands(num_bands=num_bands, max_freq=float(max_res))
        self.register_buffer(
            'bands',
            bands.to(device=device, dtype=dtype),
            persistent=False,
        )

    def forward(self, x):
        """ Append Fourier position features to `x`.

        Args:
            x: Channel-first input, indexed as (batch, channels, *spatial). The permutes below
                are written for exactly two spatial dims.

        Returns:
            `x` with the position features concatenated on the channel dim. Channel-first with
            spatial dims kept if `keep_spatial`, otherwise reshaped to (batch, positions, channels).
        """
        B, C = x.shape[:2]
        feat_shape = x.shape[2:]
        emb = build_fourier_pos_embed(
            feat_shape,
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device,
        )
        # each entry is (*spatial, num_axes, K); join along K then fold (K, num_axes) into channels
        emb = torch.cat(emb, dim=-1)
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)

        # FIXME support nD
        if self.keep_spatial:
            x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
        else:
            x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
            x = x.reshape(B, feat_shape.numel(), -1)

        return x


def rot(x):
    """ Rotate adjacent (even, odd) pairs along the last dim: (a, b) -> (-b, a). """
    # x:   [ x0  x1  x2  x3  x4  x5]
    # out: [-x1  x0 -x3  x2 -x5  x4]
    evens = x[..., ::2]
    odds = x[..., 1::2]
    # pair up (-odd, even) on a new trailing dim, then fold back to the input shape
    paired = torch.stack([-odds, evens], -1)
    return paired.reshape(x.shape)


def rope_rotate_half(x: torch.Tensor) -> torch.Tensor:
    """ Split the last dim in two halves (a, b) and return (-b, a). """
    # x:   [ x0  x1  x2  x3  x4  x5]
    # out: [-x3 -x4 -x5  x0  x1  x2]
    halves = x.chunk(2, dim=-1)
    lower, upper = halves[0], halves[1]
    return torch.cat([-upper, lower], dim=-1)


def apply_rot_embed(
        x: torch.Tensor,
        sin_emb: torch.Tensor,
        cos_emb: torch.Tensor,
        half: bool = False,
) -> torch.Tensor:
    """ Apply a rotary embedding: ``x * cos + rotated(x) * sin``.

    Args:
        x: Input, [..., D], eg [x0, x1, x2, x3, x4, x5]
        sin_emb: Sin table broadcastable against `x`.
        cos_emb: Cos table broadcastable against `x`.
        half: Selects the rotation layout.
            True:  rope_rotate_half(x), eg [-x3, -x4, -x5, x0, x1, x2], with tables laid out
                   as [s0, s1, s2, s0, s1, s2].
            False: rot(x), eg [-x1, x0, -x3, x2, -x5, x4], with tables laid out
                   as [s0, s0, s1, s1, s2, s2].
    """
    rotated = rope_rotate_half(x) if half else rot(x)
    return x * cos_emb + rotated * sin_emb


def apply_rot_embed_list(
        x: List[torch.Tensor],
        sin_emb: torch.Tensor,
        cos_emb: torch.Tensor,
        half: bool = False
) -> List[torch.Tensor]:
    """ Apply the same rotary embedding to every tensor in a list.

    A bare tensor is accepted and treated as a one-element list. Each tensor `t` becomes
    ``t * cos + rotated(t) * sin``.

    Args:
        x: Tensors of shape [..., D], eg [x0, x1, x2, x3, x4, x5]
        sin_emb: Sin table broadcastable against each tensor.
        cos_emb: Cos table broadcastable against each tensor.
        half: Selects the rotation layout.
            True:  rope_rotate_half(t), eg [-x3, -x4, -x5, x0, x1, x2], with tables laid out
                   as [s0, s1, s2, s0, s1, s2].
            False: rot(t), eg [-x1, x0, -x3, x2, -x5, x4], with tables laid out
                   as [s0, s0, s1, s1, s2, s2].
    """
    if isinstance(x, torch.Tensor):
        x = [x]
    rotated_out = []
    for t in x:
        turned = rope_rotate_half(t) if half else rot(t)
        rotated_out.append(t * cos_emb + turned * sin_emb)
    return rotated_out


def apply_rot_embed_cat(
        x: torch.Tensor,
        emb: torch.Tensor,
        half: bool = False
) -> torch.Tensor:
    """ Apply a rotary embedding whose sin and cos tables are concatenated on the last dim.

    `emb` is split in two equal chunks along the last dim, sin first and cos second, and the
    result is ``x * cos + rotated(x) * sin``.

    Args:
        x: Input, [..., D], eg [x0, x1, x2, x3, x4, x5]
        emb: Concatenated [sin, cos] table.
        half: Selects the rotation layout.
            True:  rope_rotate_half(x), eg [-x3, -x4, -x5, x0, x1, x2], with tables laid out
                   as [s0, s1, s2, s0, s1, s2].
            False: rot(x), eg [-x1, x0, -x3, x2, -x5, x4], with tables laid out
                   as [s0, s0, s1, s1, s2, s2].
    """
    tables = emb.chunk(2, -1)
    sin_emb, cos_emb = tables[0], tables[1]
    rotated = rope_rotate_half(x) if half else rot(x)
    return x * cos_emb + rotated * sin_emb


def apply_keep_indices_nlc(
        x: torch.Tensor,
        pos_embed: torch.Tensor,
        keep_indices: torch.Tensor,
        pos_embed_has_batch: bool = False,
) -> torch.Tensor:
    """ Select kept sequence positions from ROPE embeddings of various ranks.

    Supported pos_embed layouts (the sequence dim is always second-last):
    * [seq_len, pos_embed_dim]
    * [num_heads, seq_len, pos_embed_dim]
    * [depth, num_heads, seq_len, pos_embed_dim]

    or any of the above with a leading batch dim when `pos_embed_has_batch == True`. Without
    a batch dim one is added and expanded to `x.shape[0]`.

    Args:
        x: Only `x.shape[0]` is read, as the batch size to expand to.
        pos_embed: Position embedding in one of the layouts above.
        keep_indices: Indices to keep, indexed as [batch, num_keep].
        pos_embed_has_batch: Whether `pos_embed` already carries the batch dim.

    Returns:
        `pos_embed` (with batch dim) gathered along the sequence dim, which then has
        `num_keep` entries.
    """
    if not pos_embed_has_batch:
        # Needs at least [seq_len, pos_embed_dim]; prepend batch dim and expand to batch size
        _assert(pos_embed.ndim >= 2, 'Incorrect number of dimensions')
        expand_shape = (x.shape[0],) + (-1,) * pos_embed.ndim
        pos_embed = pos_embed.unsqueeze(0).expand(expand_shape)
    else:
        # Needs at least [batch, seq_len, pos_embed_dim]
        _assert(pos_embed.ndim >= 3, 'Incorrect number of dimensions')

    # Insert singleton dims between batch and sequence, plus a trailing one for the feature dim
    num_middle_dims = pos_embed.ndim - 3
    index_shape = (keep_indices.shape[0],) + (1,) * num_middle_dims + (keep_indices.shape[1], 1)
    index = keep_indices.view(index_shape)

    # Broadcast the index to the embedding shape everywhere except the gather dim (second-last)
    target_shape = list(pos_embed.shape)
    target_shape[-2] = -1
    index = index.expand(target_shape)

    return pos_embed.gather(-2, index)


def build_rotary_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        dim: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        grid_offset: float = 0.,
        grid_indexing: str = 'ij',
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
):
    """ Build flattened sin / cos tables for rotary embedding.

    Thin wrapper over `build_fourier_pos_embed`: asks for ``dim // 4`` bands, flattens the
    result to one row per grid position and repeats every value twice along the last dim
    (s0, s0, s1, s1, ...).

    Args:
        feat_shape: Spatial shape of the target tensor for embedding.
        bands: Optional pre-generated frequency bands
        dim: Output dimension of embedding tensor.
        max_res: Maximum resolution for pixel mode.
        temperature: Temperature (inv freq) for non-pixel mode
        linear_bands: Linearly (instead of log) spaced bands for pixel mode
        in_pixels: Pixel vs language (inv freq) mode.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        grid_offset: Constant offset to add to grid for non-pixel freq.
        grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
        device: Output device.
        dtype: Output dtype.

    Returns:
        Tuple ``(sin_emb, cos_emb)``, each with ``prod(feat_shape)`` rows.
    """
    fourier = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 4,
        max_res=max_res,
        temperature=temperature,
        linear_bands=linear_bands,
        in_pixels=in_pixels,
        ref_feat_shape=ref_feat_shape,
        grid_offset=grid_offset,
        grid_indexing=grid_indexing,
        device=device,
        dtype=dtype,
    )
    # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
    num_positions = 1
    for size in feat_shape:
        num_positions = num_positions * size

    sin_emb = fourier[0].reshape(num_positions, -1).repeat_interleave(2, -1)
    cos_emb = fourier[1].reshape(num_positions, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb


class RotaryEmbedding(nn.Module):
    """ Rotary position embedding with separate sin / cos tables

    The module runs in one of two modes, chosen at construction:
    * `feat_shape` given: the full sin / cos tables are built once and cached as buffers.
    * `feat_shape` omitted: only the frequency bands are cached and the tables are rebuilt
      for whatever shape is requested.

    NOTE: Author's note carried over from the original: this was an initial attempt at rotary
    embedding for spatial use, has not been well tested, is likely to change and is meant to
    move to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim,
            max_res=224,
            temperature=10000,
            in_pixels=True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
            grid_offset: float = 0.,
            grid_indexing: str = 'ij',
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.linear_bands = linear_bands
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape
        self.grid_offset = grid_offset
        self.grid_indexing = grid_indexing

        if feat_shape is not None:
            # shape known up front, cache the full sin / cos tables and skip the bands
            sin_table, cos_table = self._get_pos_embed_values(feat_shape, device=device, dtype=dtype)
            self.bands = None
            self.register_buffer('pos_embed_sin', sin_table, persistent=False)
            self.register_buffer('pos_embed_cos', cos_table, persistent=False)
        else:
            # shape unknown, only the bands can be cached
            num_bands = dim // 4
            if in_pixels:
                bands = pixel_freq_bands(
                    num_bands,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    num_bands,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer('bands', bands.to(device=device, dtype=dtype), persistent=False)
            self.pos_embed_sin = None
            self.pos_embed_cos = None

    def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
        """ Build the (sin, cos) tables for `feat_shape` from this module's settings. """
        sin_table, cos_table = build_rotary_pos_embed(
            feat_shape=feat_shape,
            dim=self.dim,
            max_res=self.max_res,
            temperature=self.temperature,
            linear_bands=self.linear_bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=self.ref_feat_shape,
            grid_offset=self.grid_offset,
            grid_indexing=self.grid_indexing,
            device=device,
            dtype=dtype,
        )
        return sin_table, cos_table

    def update_feat_shape(self, feat_shape: List[int]):
        """ Rebuild the cached tables for a new shape.

        Does nothing when the module was built without `feat_shape` or the shape is unchanged.
        """
        if self.feat_shape is None or feat_shape == self.feat_shape:
            return
        assert self.pos_embed_sin is not None
        assert self.pos_embed_cos is not None
        # keep the new tables on the device / dtype of the ones being replaced
        self.pos_embed_sin, self.pos_embed_cos = self._get_pos_embed_values(
            feat_shape,
            device=self.pos_embed_sin.device,
            dtype=self.pos_embed_sin.dtype,
        )
        self.feat_shape = feat_shape

    def get_embed(self, shape: Optional[List[int]] = None):
        """ Return the (sin, cos) tables.

        With cached bands and a `shape`, the tables are rebuilt on every call. Otherwise the
        cached tables are returned, and `shape` is ignored.
        """
        if shape is not None and self.bands is not None:
            # rebuild embeddings every call, use if target shape changes
            return build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
                ref_feat_shape=self.ref_feat_shape,
                grid_offset=self.grid_offset,
                grid_indexing=self.grid_indexing,
            )
        elif self.pos_embed_sin is not None and self.pos_embed_cos is not None:
            return self.pos_embed_sin, self.pos_embed_cos
        else:
            assert False, "get_embed() requires pre-computed pos embeds or valid shape w/ pre-computed bands"

    def forward(self, x):
        """ Apply the rotary embedding for the spatial shape of `x`. """
        # assuming channel-first tensor where spatial dim are >= 2
        tables = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, tables[0], tables[1])


class RotaryEmbeddingCat(nn.Module):
    """ Rotary position embedding w/ concatenated sin & cos

    Same scheme as the separate sin / cos variant, but the two tables are concatenated on the
    last dim (sin first, cos second) so a single tensor is cached and passed around.

    The module runs in one of two modes, chosen at construction:
    * `feat_shape` given: the full table is built once and cached in the `pos_embed` buffer.
    * `feat_shape` omitted: only the frequency bands are cached and the table is rebuilt for
      whatever shape is requested.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim: int,
            max_res: int = 224,
            temperature: float = 10000,
            in_pixels: bool = True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
            grid_offset: float = 0.,
            grid_indexing: str = 'ij',
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.linear_bands = linear_bands
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape
        self.grid_offset = grid_offset
        self.grid_indexing = grid_indexing

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands.to(device=device, dtype=dtype),
                persistent=False,
            )
            self.pos_embed = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            self.bands = None
            self.register_buffer(
                'pos_embed',
                self._get_pos_embed_values(feat_shape=feat_shape, device=device, dtype=dtype),
                persistent=False,
            )

    def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
        """ Build the concatenated [sin, cos] table for `feat_shape` from this module's settings. """
        embeds = build_rotary_pos_embed(
            feat_shape=feat_shape,
            dim=self.dim,
            max_res=self.max_res,
            temperature=self.temperature,
            linear_bands=self.linear_bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=self.ref_feat_shape,
            grid_offset=self.grid_offset,
            grid_indexing=self.grid_indexing,
            device=device,
            dtype=dtype,
        )
        return torch.cat(embeds, -1)

    def update_feat_shape(self, feat_shape: List[int]):
        """ Rebuild the cached table for a new shape.

        Does nothing when the module was built without `feat_shape` or the shape is unchanged.
        """
        if self.feat_shape is not None and feat_shape != self.feat_shape:
            # only update if feat_shape was set and different from previous value
            assert self.pos_embed is not None
            self.pos_embed = self._get_pos_embed_values(
                feat_shape,
                device=self.pos_embed.device,
                dtype=self.pos_embed.dtype,
            )
            self.feat_shape = feat_shape

    def get_embed(self, shape: Optional[List[int]] = None):
        """ Return the concatenated [sin, cos] table.

        With cached bands and a `shape`, the table is rebuilt on every call. Otherwise the
        cached table is returned, and `shape` is ignored.
        """
        if shape is not None and self.bands is not None:
            # rebuild embeddings from cached bands every call, use if target shape changes
            embeds = build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
                ref_feat_shape=self.ref_feat_shape,
                grid_offset=self.grid_offset,
                grid_indexing=self.grid_indexing,
            )
            return torch.cat(embeds, -1)
        elif self.pos_embed is not None:
            return self.pos_embed
        else:
            assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands"

    def get_batch_embeds(
            self,
            shapes: List[Tuple[int, int]],
            seq_len: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Generate ROPE embeddings for multiple grid shapes efficiently.

        Computes embeddings for the maximum grid size once, then extracts
        and flattens the relevant portions for each requested shape.

        Args:
            shapes: List of (H, W) tuples representing different grid sizes
            seq_len: If provided, return one zero-padded tensor with this many positions per
                shape. Otherwise return a list with one tensor per shape.

        Returns:
            If seq_len is provided: tensor of shape (len(shapes), seq_len, C), where entry i
            holds the H*W embeddings of shape i followed by zeros.
            Otherwise: list of tensors of shape (H*W, C), one per shape.
            C is the size of the sin and cos tables concatenated on the last dim.
            An empty list is returned when `shapes` is empty.

        Raises:
            RuntimeError: If the module caches fixed-shape embeddings instead of bands.
        """
        if not shapes:
            return []

        # Check if we have pre-computed bands
        if self.bands is None:
            # If we have pre-computed pos_embed for a fixed shape, we can't do batch generation
            raise RuntimeError("Batch embedding generation requires cached bands, not pre-computed embeddings")

        # Find max dimensions across all shapes
        max_h = max(h for h, w in shapes)
        max_w = max(w for h, w in shapes)

        # Generate embeddings for max size ONCE
        sin_emb, cos_emb = build_rotary_pos_embed(
            feat_shape=(max_h, max_w),
            bands=self.bands,
            in_pixels=self.in_pixels,
            ref_feat_shape=self.ref_feat_shape,
            grid_offset=self.grid_offset,
            grid_indexing=self.grid_indexing,
        )

        # sin_emb and cos_emb each have max_h * max_w rows;
        # concat on the last dim and reshape to a 2D grid for slicing
        rope_embed_2d = torch.cat([sin_emb, cos_emb], dim=-1).view(max_h, max_w, -1)

        if seq_len is not None:
            # allocate straight on the target device / dtype rather than converting afterwards
            flat_embeds = torch.zeros(
                len(shapes),
                seq_len,
                rope_embed_2d.shape[-1],
                device=rope_embed_2d.device,
                dtype=rope_embed_2d.dtype,
            )
            for i, (h, w) in enumerate(shapes):
                src_len = h * w
                flat_embeds[i, :src_len] = rope_embed_2d[:h, :w].reshape(src_len, -1)
            return flat_embeds
        else:
            flat_embeds_list = [rope_embed_2d[:h, :w].reshape(h * w, -1) for h, w in shapes]
            return flat_embeds_list

    def forward(self, x):
        """ Apply the rotary embedding for the spatial shape of `x`. """
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)


def init_random_2d_freqs(
        head_dim: int,
        depth: int,
        num_heads: int,
        temperature: float = 10.0,
        rotate: bool = True,
        *,
        device=None,
        dtype=torch.float32,
) -> torch.Tensor:
    """ Vectorised 2D ROPE frequencies with random rotation for mixed mode ROPE.

    Each (block, head) pair gets one angle, drawn uniformly from [0, 2*pi) when `rotate` is
    set and zero otherwise. The x / y frequencies are the base magnitudes projected onto that
    angle and onto the angle plus a quarter turn.

    Args:
        head_dim: Per-head dimension, magnitudes are generated for indices 0, 4, 8, ... < head_dim.
        depth: Number of blocks.
        num_heads: Number of heads.
        temperature: Base of the magnitude decay.
        rotate: Draw random angles instead of using zero.
        device: Device for created tensors.
        dtype: Dtype for created tensors.

    Returns:
         Tensor (2, depth, num_heads, head_dim//2), index 0 holds x and index 1 holds y frequencies.
    """
    # base magnitudes (head_dim//4,), with two leading singleton dims to broadcast over
    # depth and heads -> (1, 1, L)
    exponents = torch.arange(0, head_dim, 4, device=device, dtype=dtype) / head_dim
    mag = (1.0 / (temperature ** exponents)).unsqueeze(0).unsqueeze(0)

    # one angle per head *and* per block
    if not rotate:
        angles = torch.zeros(depth, num_heads, 1, device=device, dtype=dtype)
    else:
        angles = torch.rand(depth, num_heads, 1, device=device, dtype=dtype) * 2 * torch.pi
    quarter_turn = angles + torch.pi / 2

    # (depth, num_heads, 2·L) == head_dim//2 on the last axis
    freqs_x = torch.cat([mag * torch.cos(angles), mag * torch.cos(quarter_turn)], dim=-1)
    freqs_y = torch.cat([mag * torch.sin(angles), mag * torch.sin(quarter_turn)], dim=-1)

    # (2, depth, num_heads, head_dim//2)
    return torch.stack([freqs_x, freqs_y], dim=0)


@torch.fx.wrap
@register_notrace_function
def get_mixed_grid(
        shape: List[int],
        grid_indexing: str = 'ij',
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Build flattened integer-position coordinates for a 2D grid.

    Args:
        shape: Grid shape; only the first two entries are used. They are swapped before
            building the axes when `grid_indexing` is 'xy'.
        grid_indexing: Indexing mode for meshgrid ('ij' or 'xy').
        device: Device for created tensors.
        dtype: Output dtype (positions are generated in float32 and cast).

    Returns:
        Tuple of the two flattened meshgrid outputs, in meshgrid order.
    """
    if grid_indexing == 'xy':
        shape = swap_shape_xy(shape)
    first_axis = torch.arange(shape[0], device=device, dtype=torch.float32)
    second_axis = torch.arange(shape[1], device=device, dtype=torch.float32)
    grids = torch.meshgrid(first_axis, second_axis, indexing=grid_indexing)
    return grids[0].to(dtype).flatten(), grids[1].to(dtype).flatten()


def get_mixed_freqs(
        freqs: torch.Tensor,
        t_x: torch.Tensor,
        t_y: torch.Tensor,
) -> torch.Tensor:
    """Compute mixed (learnable) frequencies.

    Args:
        freqs: Frequencies, ``freqs[0]`` is paired with `t_x` and ``freqs[1]`` with `t_y`.
            The last dim is the per-frequency dim, leading dims are kept.
        t_x: Flattened positions (N,) for the first axis.
        t_y: Flattened positions (N,) for the second axis.

    Returns:
        Tensor of shape (*freqs.shape[1:-1], N, 4 * freqs.shape[-1]) in the dtype of `freqs`:
        sin values followed by cos values on the last dim, each value repeated twice in place.
    """
    out_dtype = freqs.dtype
    # angles are computed in float32 and cast back at the end
    freqs_f32 = freqs.float()
    # outer product of positions (N, 1) and frequencies (..., 1, F) -> (..., N, F)
    angles_x = t_x.unsqueeze(-1) @ freqs_f32[0].unsqueeze(-2)
    angles_y = t_y.unsqueeze(-1) @ freqs_f32[1].unsqueeze(-2)
    angles = angles_x + angles_y
    sin_emb = angles.sin().repeat_interleave(2, -1)
    cos_emb = angles.cos().repeat_interleave(2, -1)
    return torch.cat([sin_emb, cos_emb], dim=-1).to(out_dtype)


class RotaryEmbeddingMixed(nn.Module):
    """Rotary position embedding with depth-dependent learnable frequencies.

    Implements mixed (learnable) ROPE: the frequencies are a single `nn.Parameter` holding a
    separate set per transformer block and per head. Grid positions are optionally cached as
    non-persistent buffers when the feature shape is known at construction.

    Based on 'Rotary Position Embedding for Vision: https://arxiv.org/abs/2403.13298)'
    Compatible with original at https://github.com/naver-ai/rope-vit
    """
    def __init__(
            self,
            dim: int,
            depth: int,
            num_heads: int,
            temperature: float = 10.0,
            feat_shape: Optional[List[int]] = None,
            grid_indexing: str = 'xy',
            device=None,
            dtype=None,
    ):
        """Initialize rotary embeddings.

        Args:
            dim: Embedding dimension, ``dim // num_heads`` must be divisible by 4
            depth: Number of transformer blocks
            num_heads: Number of attention heads
            temperature: Base for frequency computation
            feat_shape: Spatial dimensions [H, W] if known in advance
            grid_indexing: How to index grid positions ('xy' or 'ij')
            device: Device for the frequency parameter
            dtype: Dtype for the frequency parameter
        """
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.num_heads = num_heads
        self.temperature = temperature
        self.feat_shape = feat_shape
        self.grid_indexing = grid_indexing

        head_dim = dim // num_heads
        assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"

        # (2, depth, num_heads, head_dim//2)
        self.freqs = nn.Parameter(init_random_2d_freqs(
            head_dim,
            depth,
            num_heads,
            temperature=temperature,
            rotate=True,
            device=device,
            dtype=dtype,
        ))

        if feat_shape is None:
            self.t_x = None
            self.t_y = None
        else:
            # cache pre-computed grid
            grid_x, grid_y = self._get_grid_values(feat_shape)
            self.register_buffer('t_x', grid_x, persistent=False)
            self.register_buffer('t_y', grid_y, persistent=False)

    def _get_grid_values(self, feat_shape: Optional[List[int]]):
        """Build flattened grid positions for `feat_shape` on the device of the frequencies."""
        return get_mixed_grid(
            feat_shape,
            grid_indexing=self.grid_indexing,
            device=self.freqs.device,
        )

    def update_feat_shape(self, feat_shape: Optional[List[int]]):
        """Rebuild the cached grid for a new shape.

        Does nothing when the module was built without `feat_shape` or the shape is unchanged.
        """
        if self.feat_shape is None or feat_shape == self.feat_shape:
            return
        assert self.t_x is not None
        assert self.t_y is not None
        grid_x, grid_y = self._get_grid_values(feat_shape)
        # keep the new grid on the device / dtype of the one being replaced
        self.t_x = grid_x.to(self.t_x.device, self.t_x.dtype)
        self.t_y = grid_y.to(self.t_y.device, self.t_y.dtype)
        self.feat_shape = feat_shape

    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
        """Generate rotary embeddings for the given spatial shape.

        Args:
            shape: Spatial dimensions [H, W]. When given, the grid is rebuilt for it even if
                a grid is cached; when omitted, the cached grid is used.

        Returns:
            Concatenated sin/cos embeddings with one entry per block, head and grid position:
            (depth, num_heads, H*W, C), sin first and cos second on the last dim.
        """
        if shape is not None:
            t_x, t_y = get_mixed_grid(
                shape,
                grid_indexing=self.grid_indexing,
                device=self.freqs.device
            )
        elif self.t_x is not None and self.t_y is not None:
            t_x, t_y = self.t_x, self.t_y
        else:
            assert False, "get_embed() requires pre-computed t_x/t_y or valid shape"

        return get_mixed_freqs(self.freqs, t_x, t_y)

    def get_batch_embeds(
            self,
            shapes: List[Tuple[int, int]],
            seq_len: Optional[int] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Generate ROPE embeddings for multiple grid shapes efficiently.

        The embedding is computed once for the largest height and width found in `shapes`;
        each requested shape is then a top-left crop of that grid, flattened.

        Args:
            shapes: List of (H, W) tuples representing different grid sizes
            seq_len: If provided, return padded tensor of this length. Otherwise return list.

        Returns:
            If seq_len is provided: zero-padded tensor of shape
            (len(shapes), depth, num_heads, seq_len, C).
            Otherwise: list of tensors with shape (depth, num_heads, H*W, C), one per shape.
            An empty list is returned when `shapes` is empty.
        """
        if not shapes:
            return []

        # Largest extent along each axis
        max_h = max(h for h, w in shapes)
        max_w = max(w for h, w in shapes)

        # Embedding for the covering grid, computed ONCE
        t_x, t_y = get_mixed_grid(
            [max_h, max_w],
            grid_indexing=self.grid_indexing,
            device=self.freqs.device
        )
        full_embed = get_mixed_freqs(self.freqs, t_x, t_y)  # (depth, num_heads, max_h*max_w, C)

        # Unflatten the position dim so shapes can be cropped by slicing
        depth, num_heads, _, channels = full_embed.shape
        full_embed_2d = full_embed.view(depth, num_heads, max_h, max_w, channels)

        if seq_len is None:
            # One cropped + flattened tensor per shape
            return [
                full_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, channels)
                for h, w in shapes
            ]

        # Zero-padded batch tensor
        padded = torch.zeros(
            len(shapes), depth, num_heads, seq_len, channels,
            device=self.freqs.device, dtype=self.freqs.dtype,
        )
        for idx, (h, w) in enumerate(shapes):
            num_pos = h * w
            cropped = full_embed_2d[:, :, :h, :w].reshape(depth, num_heads, num_pos, channels)
            padded[idx, :, :, :num_pos] = cropped
        return padded

    def forward(self, x):
        """Apply the rotary embedding for the spatial shape of `x`."""
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)

    def no_weight_decay(self):
        """Exclude frequency parameters from weight decay."""
        return {'freqs'}


@torch.fx.wrap
@register_notrace_function
def make_coords_dinov3(
        height: int,
        width: int,
        normalize_coords: str = 'separate',
        grid_indexing: str = 'ij',
        grid_offset: float = 0.,
        device: torch.device = 'cpu',
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build a flattened (h, w) coordinate grid with DINOv3-style offset and normalization.

    Args:
        height: Number of grid rows.
        width: Number of grid columns.
        normalize_coords: Denominator choice: 'separate' divides rows by `height` and
            columns by `width`; 'max' / 'min' divide both axes by max / min of the two.
        grid_indexing: 'xy' builds the mesh with xy indexing and re-stacks it in (h, w)
            order; any other value uses 'ij' indexing.
        grid_offset: Constant added to the 0.5-centered indices before normalization.
        device: Device on which the grid is created.
        dtype: Output dtype (the index/normalization math itself runs in float32).

    Returns:
        Tensor of shape (H*W, 2), last dim ordered (h, w). Values lie in [-1, 1] for
        zero offset with 'separate' or 'max' normalization.

    Raises:
        ValueError: If `normalize_coords` is not one of 'max', 'min', 'separate'.
    """
    # Pixel-center (0.5-based) positions along each axis, always computed in float32
    ys = torch.arange(0.5, height, device=device, dtype=torch.float32) + grid_offset
    xs = torch.arange(0.5, width, device=device, dtype=torch.float32) + grid_offset

    # Pick the per-axis normalization denominators
    if normalize_coords == "separate":
        h_denom, w_denom = float(height), float(width)
    elif normalize_coords == "max":
        h_denom = w_denom = float(max(height, width))
    elif normalize_coords == "min":
        h_denom = w_denom = float(min(height, width))
    else:
        raise ValueError(f"Unknown normalize_coords: {normalize_coords}")

    # Scale to the unit range, then cast to the requested dtype
    ys = (ys / h_denom).to(dtype)
    xs = (xs / w_denom).to(dtype)

    # Both branches produce an (H, W, 2) grid whose last dim is (h, w)
    if grid_indexing == "xy":
        mesh_x, mesh_y = torch.meshgrid(xs, ys, indexing="xy")
        grid = torch.stack([mesh_y, mesh_x], dim=-1)
    else:
        grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)

    # Flatten spatial dims -> (HW, 2), then remap the unit range onto [-1, 1]
    return 2.0 * grid.flatten(0, 1) - 1.0


class RotaryEmbeddingDinoV3(nn.Module):
    """RoPE for timm DinoV3 port, numerically matching original.

    Math is aligned to original DinoV3 RopePositionEmbedding at https://github.com/facebookresearch/dinov3:
      - 0.5-centered coords normalized by H/W (or min/max), mapped to [-1,1]
      - training-time augmentations (shift/jitter/rescale)
      - periods schedule equals Rope's temperature (base) or min/max period
    """

    def __init__(
            self,
            dim: int,
            temperature: Optional[float] = 100.0,
            min_period: Optional[float] = None,
            max_period: Optional[float] = None,
            feat_shape: Optional[List[int]] = None,
            normalize_coords: str = "separate",  # 'min', 'max', 'separate'
            grid_offset: float = 0.0,
            grid_indexing: str = "ij",
            rotate_half: bool = True,
            shift_coords: Optional[float] = None,
            jitter_coords: Optional[float] = None,  # interpreted as factor J >= 1
            rescale_coords: Optional[float] = None,  # interpreted as factor R >= 1
            device=None,
            dtype=None,
    ):
        """
        Args:
            dim: Embedding dim (equal to head_dim for most vit applications); `dim // 4`
                periods are generated.
            temperature: Base of the period schedule, used unless both `min_period` and
                `max_period` are given. May be None when both min/max periods are provided.
            min_period: Smallest period of the geometric min/max schedule.
            max_period: Largest period of the geometric min/max schedule.
            feat_shape: Optional (H, W); when given, a non-augmented embedding is cached.
            normalize_coords: Coordinate normalization mode, 'min', 'max' or 'separate'.
            grid_offset: Offset added to the 0.5-centered grid indices.
            grid_indexing: Grid indexing mode forwarded to the coordinate builder.
            rotate_half: True tiles angles (half layout), False interleaves them.
            shift_coords: If set, train-time per-axis uniform shift in [-s, +s].
            jitter_coords: If set, train-time per-axis log-uniform scale in [1/J, J].
            rescale_coords: If set, train-time shared log-uniform scale in [1/R, R].
            device: Device for the `periods` buffer.
            dtype: Dtype for the `periods` buffer.

        Raises:
            ValueError: If `temperature` is None and min/max periods are not both given.
        """
        super().__init__()

        # Dimensions / output format
        self.dim = dim  # equal to head_dim for most vit applications
        self.rotate_half = rotate_half

        # Period schedule parameters
        # `temperature` is Optional: keep None as-is so the min/max schedule can be used
        # without a base, and so `_compute_periods` can raise its own ValueError.
        self.temperature = float(temperature) if temperature is not None else None
        self.min_period = min_period
        self.max_period = max_period

        # Coord processing + augs
        self.normalize_coords = normalize_coords
        self.shift_coords = shift_coords
        self.jitter_coords = jitter_coords
        self.rescale_coords = rescale_coords
        self.aug_active = any([a is not None for a in [self.shift_coords, self.jitter_coords, self.rescale_coords]])

        # Grid config
        self.feat_shape = feat_shape
        self.grid_offset = grid_offset
        self.grid_indexing = grid_indexing

        # Precompute periods
        periods = self._compute_periods(device=device, dtype=dtype)
        self.register_buffer("periods", periods, persistent=False)

        if feat_shape is not None:
            self._cache_embed(feat_shape)
        else:
            # No fixed shape: embeddings are built on demand in get_embed()
            self.register_buffer("pos_embed_cached", None, persistent=False)

    def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = torch.float32) -> torch.Tensor:
        """Construct periods from either min/max or temperature.

        Returns a 1D tensor with `self.dim // 4` entries, computed in float32 on CPU and
        then moved/cast to `device` / `dtype`.

        Raises:
            ValueError: If min/max periods are not both set and `temperature` is None.
        """
        dim = self.dim // 4

        if self.min_period is not None and self.max_period is not None:
            # Geometric progression from min_period to max_period
            exponents = torch.linspace(0, 1, dim, device='cpu', dtype=torch.float32)
            periods = self.min_period * ((self.max_period / self.min_period) ** exponents)
        else:
            if self.temperature is None:
                raise ValueError("Provide either min/max periods or `temperature`.")
            exponents = 2.0 * torch.arange(dim, device='cpu', dtype=torch.float32) / (self.dim // 2)
            periods = self.temperature ** exponents

        # NOTE: The original dinov3 model weights have periods downcast to bfloat16 in persistent buffers,
        # loaded models will differ a bit vs timm as periods is not persistent and generated in float32 by default
        return periods.to(device=device, dtype=dtype)

    def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor:
        """Apply shift/jitter/rescale train time augmentations.

        Returns `coords` unchanged when the module is in eval mode or no augmentation is
        configured.

        Raises:
            ValueError: If `jitter_coords` or `rescale_coords` is <= 0.
        """
        if not self.training or not self.aug_active:
            return coords

        device = coords.device
        dtype = coords.dtype

        # Shift per-axis in [-s, +s]
        if self.shift_coords is not None:
            shift = float(self.shift_coords)
            shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, shift)
            coords = coords + shift_hw[None, :]

        # Jitter: per-axis log-uniform factor in [1/J, J]
        if self.jitter_coords is not None:
            jitter_factor = float(self.jitter_coords)
            if jitter_factor <= 0:
                raise ValueError("jitter_coords must be > 0 (interpreted as multiplicative factor).")
            jitter_max = math.log(jitter_factor)
            jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, jitter_max).exp()
            coords = coords * jitter_hw[None, :]

        # Rescale: shared scalar log-uniform factor in [1/R, R]
        if self.rescale_coords is not None:
            rescale_factor = float(self.rescale_coords)
            if rescale_factor <= 0:
                raise ValueError("rescale_coords must be > 0 (interpreted as multiplicative factor).")
            rescale_max = math.log(rescale_factor)
            rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, rescale_max).exp()
            coords = coords * rescale

        return coords

    def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return sin/cos embeddings with either 'half' or 'interleaved' layout.

        Args:
            coords: Coordinates of shape (HW, 2).

        Returns:
            Tuple (sin, cos), each of shape (HW, dim), on the device/dtype of `periods`.
        """
        # coords: (HW, 2); periods: (dim // 4)
        dim = self.dim // 4
        device = self.periods.device
        dtype = self.periods.dtype
        assert self.periods.numel() == dim

        # NOTE this is a slightly later device/dtype switch than original
        coords = coords[:, :, None].to(device=device, dtype=dtype)
        angles = 2 * math.pi * coords / self.periods[None, None, :]  # (HW, 2, dim // 4)
        angles = angles.flatten(1)  # (HW, dim // 2)

        if self.rotate_half:
            # Tile (half layout) (HW, dim // 2) -> (HW, dim)
            angles = angles.tile(2)
        else:
            # Interleaved layout (HW, dim // 2) -> (HW, dim)
            angles = angles.repeat_interleave(2, dim=-1)

        sin = torch.sin(angles)
        cos = torch.cos(angles)
        return sin, cos

    def _create_embed(
            self,
            feat_shape: List[int],
            no_aug: bool = False,
    ) -> torch.Tensor:
        """Build the concatenated [sin, cos] embedding of shape (HW, 2 * dim) for `feat_shape`.

        Coordinate augmentations are applied unless `no_aug` is True (they are also
        skipped in eval mode or when none are configured).
        """
        H, W = feat_shape
        coords = make_coords_dinov3(
            H, W,
            normalize_coords=self.normalize_coords,
            grid_indexing=self.grid_indexing,
            grid_offset=self.grid_offset,
        )  # (HW, 2)
        if not no_aug:
            coords = self._apply_coord_augs(coords)
        sin, cos = self._get_pos_embed_from_coords(coords)  # 2 * (HW, dim)
        rope_embed = torch.cat([sin, cos], dim=-1)  # (HW, 2*dim)
        return rope_embed

    def _cache_embed(self, feat_shape: List[int]):
        """Compute the non-augmented embedding for `feat_shape` and store it as the cache."""
        # create non-augmented embeds for cache
        rope_embed = self._create_embed(feat_shape, no_aug=True)
        self.register_buffer("pos_embed_cached", rope_embed, persistent=False)
        self.feat_shape = feat_shape

    def update_feat_shape(self, feat_shape: List[int]):
        """Rebuild the cached embedding if a cache exists and `feat_shape` differs from it.

        Shapes are compared element-wise, so an equal shape passed as a tuple or
        `torch.Size` does not trigger a rebuild of a cache stored for a list.
        """
        if self.feat_shape is not None and list(feat_shape) != list(self.feat_shape):
            # only update if feat_shape was set (valid cache) and different from previous value
            self._cache_embed(feat_shape)

    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
        """Generate rope_embed matching DINOv3 RopePositionEmbedding numerics.

        Args:
            shape: Optional (H, W). When given, a fresh embedding is created for it;
                otherwise the cached embedding is used, or one is created for the
                stored `feat_shape` when no cache exists or train-time augmentations
                are active.

        Returns: (HW, 2 * dim) with last dim = [sin, cos] cat.
        """
        if shape is not None:
            rope_embed = self._create_embed(shape)
        else:
            need_create = self.pos_embed_cached is None or (self.training and self.aug_active)
            if need_create:
                assert self.feat_shape is not None, 'feature shape must be cached on create'
                rope_embed = self._create_embed(self.feat_shape)
            else:
                assert self.pos_embed_cached is not None
                rope_embed = self.pos_embed_cached

        return rope_embed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Get and apply rotary embeddings to x"""
        # assuming channel-first tensor where spatial dim are >= 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed, half=self.rotate_half)


def create_rope_embed(
        rope_type: str = 'cat',
        dim: int = 768,
        num_heads: int = 12,
        **kwargs
) -> nn.Module:
    """Build a rotary position embedding module selected by name.

    Args:
        rope_type: Which RoPE variant to instantiate:
            - 'base': RotaryEmbedding
            - 'cat': RotaryEmbeddingCat (concatenated sin/cos)
            - 'mixed': RotaryEmbeddingMixed (learnable per-depth frequencies)
            - 'dinov3': RotaryEmbeddingDinoV3 (with coordinate transforms)
        dim: Total embedding dimension. Every variant except 'mixed' receives the
            per-head dim `dim // num_heads`; 'mixed' receives `dim` and `num_heads`.
        num_heads: Number of attention heads.
        **kwargs: Extra keyword arguments forwarded to the selected class.

    Returns:
        The constructed rotary embedding module.

    Raises:
        ValueError: If `rope_type` is not one of the supported names.
    """
    if rope_type not in ('base', 'cat', 'mixed', 'dinov3'):
        raise ValueError(f"Unknown RoPE type: {rope_type}")

    if rope_type == 'mixed':
        # Mixed requires depth parameter, generates differing embeddings per layer and head
        return RotaryEmbeddingMixed(dim=dim, num_heads=num_heads, **kwargs)

    # Remaining variants operate on the per-head dimension
    head_dim = dim // num_heads
    if rope_type == 'base':
        return RotaryEmbedding(dim=head_dim, **kwargs)
    if rope_type == 'cat':
        return RotaryEmbeddingCat(dim=head_dim, **kwargs)
    return RotaryEmbeddingDinoV3(dim=head_dim, **kwargs)
