import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class RotaryEmbedding2D(nn.Module):
    """2D rotary position embedding (RoPE) for image grids.

    The head dimension is split in half: the first half is rotated by a
    phase derived from the row (height) index, the second half by a phase
    derived from the column (width) index.
    """

    def __init__(
        self,
        dim: int,
        max_res: Tuple[int, int] = (128, 256),
        temperature: float = 10000.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_h, self.max_w = max_res
        self.temperature = temperature

        # Each axis gets half of `dim`, and each half is paired into
        # (cos, sin) channels, so `dim` must be divisible by 4.
        assert dim % 4 == 0, "Embedding dimension must be divisible by 4 for 2D RoPE"

        # Split the embedding dimension evenly between the two axes.
        dim_h = dim // 2
        dim_w = dim // 2

        # Standard RoPE inverse-frequency schedule, one set per axis:
        # inv_freq[j] = temperature ** (-2j / dim_axis).
        inv_freq_h = 1.0 / (temperature ** (torch.arange(0, dim_h, 2).float() / dim_h))
        inv_freq_w = 1.0 / (temperature ** (torch.arange(0, dim_w, 2).float() / dim_w))

        self.register_buffer("inv_freq_h", inv_freq_h)
        self.register_buffer("inv_freq_w", inv_freq_w)

        # Lazily built (cos, sin) lookup tables, one per axis.
        self.cached_cos_sin_h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.cached_cos_sin_w: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    def _update_cache(self, h: int, w: int, device: torch.device, dtype: torch.dtype):
        """Build (or grow) the per-axis cos/sin lookup tables.

        Tables are rebuilt whenever the requested grid exceeds the cached
        extent or the cache lives on the wrong device or dtype.
        """
        cache_h = self.cached_cos_sin_h
        if (
            cache_h is None
            or cache_h[0].shape[0] < h
            or cache_h[0].device != device
            or cache_h[0].dtype != dtype
        ):
            # Outer product of row positions and inverse frequencies, computed
            # in the buffer's dtype (float32) for accuracy, then cast to the
            # query dtype.
            t_h = torch.arange(h, device=device, dtype=self.inv_freq_h.dtype)
            freqs_h = torch.einsum("i,j->ij", t_h, self.inv_freq_h)
            emb_h = torch.cat((freqs_h, freqs_h), dim=-1)
            self.cached_cos_sin_h = (emb_h.cos().to(dtype), emb_h.sin().to(dtype))

        cache_w = self.cached_cos_sin_w
        if (
            cache_w is None
            or cache_w[0].shape[0] < w
            or cache_w[0].device != device
            or cache_w[0].dtype != dtype
        ):
            t_w = torch.arange(w, device=device, dtype=self.inv_freq_w.dtype)
            freqs_w = torch.einsum("i,j->ij", t_w, self.inv_freq_w)
            emb_w = torch.cat((freqs_w, freqs_w), dim=-1)
            self.cached_cos_sin_w = (emb_w.cos().to(dtype), emb_w.sin().to(dtype))

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        pos_ids: torch.Tensor,
        grid_size: Tuple[int, int],
    ):
        """Rotate queries and keys by their 2D grid positions.

        Args:
            q, k: (B, num_heads, N, D) queries and keys.
            pos_ids: (N,) or (B, N) flattened row-major grid positions.
            grid_size: (H, W) extent of the grid the positions index into.
        """
        B, num_heads, N, D = q.shape
        H_grid, W_grid = grid_size

        # Decompose flat row-major positions into (row, column) indices:
        # pos = row * W + col.
        h_idx = pos_ids.div(W_grid, rounding_mode="floor")
        w_idx = pos_ids % W_grid

        self._update_cache(H_grid, W_grid, q.device, q.dtype)

        # Broadcast shared positions across the batch so F.embedding yields
        # per-token (B, N, D/2) cos/sin tables.
        if h_idx.ndim == 1:
            h_idx = h_idx.unsqueeze(0).expand(B, -1)
            w_idx = w_idx.unsqueeze(0).expand(B, -1)

        cos_h = F.embedding(h_idx, self.cached_cos_sin_h[0])
        sin_h = F.embedding(h_idx, self.cached_cos_sin_h[1])
        cos_w = F.embedding(w_idx, self.cached_cos_sin_w[0])
        sin_w = F.embedding(w_idx, self.cached_cos_sin_w[1])

        # Split the head dimension: the first half carries the row rotation,
        # the second half the column rotation.
        dim_half = D // 2
        q_h, q_w = q.split(dim_half, dim=-1)
        k_h, k_w = k.split(dim_half, dim=-1)

        # Insert a head axis so the (B, N, D/2) tables broadcast over heads.
        cos_h = cos_h.unsqueeze(1)
        sin_h = sin_h.unsqueeze(1)
        cos_w = cos_w.unsqueeze(1)
        sin_w = sin_w.unsqueeze(1)

        q_h_rot = self._apply_rotary(q_h, cos_h, sin_h)
        k_h_rot = self._apply_rotary(k_h, cos_h, sin_h)

        q_w_rot = self._apply_rotary(q_w, cos_w, sin_w)
        k_w_rot = self._apply_rotary(k_w, cos_w, sin_w)

        q_rot = torch.cat((q_h_rot, q_w_rot), dim=-1)
        k_rot = torch.cat((k_h_rot, k_w_rot), dim=-1)

        return q_rot, k_rot

    def _apply_rotary(
        self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Apply the rotary transform x -> x * cos + rotate_half(x) * sin.

        Pairing channel i with channel i + D/2, this is the real-valued form
        of multiplying the complex pair (x_i + 1j * x_{i + D/2}) by
        exp(1j * theta), where theta is the per-position, per-frequency phase
        encoded in `cos` and `sin`.
        """
        return (x * cos) + (self._rotate_half(x) * sin)

    def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        # Map the pair (x1, x2) to (-x2, x1), a 90-degree rotation.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
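

# A minimal sketch of how `pos_ids` can be built for a full H x W grid of
# tokens in row-major order. `grid_pos_ids` is a hypothetical helper, not
# part of the modules here; any subset of these ids (e.g. after token
# dropping) is equally valid input to RotaryEmbedding2D.forward.
def grid_pos_ids(h: int, w: int, device: Optional[torch.device] = None) -> torch.Tensor:
    rows = torch.arange(h, device=device)
    cols = torch.arange(w, device=device)
    # pos = row * w + col, matching the h_idx / w_idx decomposition above.
    return (rows[:, None] * w + cols[None, :]).reshape(-1)

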
class RoPEAttention(nn.Module):
    """Multi-head self-attention with optional 2D rotary position embedding."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        rope: Optional[RotaryEmbedding2D] = None,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.rope = rope

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: torch.Tensor,
        pos_ids: Optional[torch.Tensor] = None,
        grid_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        B, N, C = x.shape
        # Project to q, k, v and reshape to (3, B, num_heads, N, head_dim).
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Rotate q and k in place of additive position embeddings.
        if self.rope is not None and pos_ids is not None and grid_size is not None:
            q, k = self.rope(q, k, pos_ids, grid_size)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
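

# A minimal smoke test, assuming a 14 x 14 patch grid with 8 heads of 64
# channels each; the sizes are illustrative, not prescriptive.
if __name__ == "__main__":
    H, W = 14, 14
    dim, num_heads = 512, 8

    # RoPE operates per head, so it is sized by head_dim, not the model dim.
    rope = RotaryEmbedding2D(dim=dim // num_heads, max_res=(H, W))
    attn = RoPEAttention(dim=dim, num_heads=num_heads, rope=rope)

    x = torch.randn(2, H * W, dim)
    pos_ids = grid_pos_ids(H, W)

    out = attn(x, pos_ids=pos_ids, grid_size=(H, W))
    print(out.shape)  # torch.Size([2, 196, 512])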
|
|