import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class RotaryEmbedding2D(nn.Module):
    def __init__(
        self,
        dim: int,
        max_res: Tuple[int, int] = (128, 256),
        temperature: float = 10000.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_h, self.max_w = max_res
        self.temperature = temperature
        # dim must be divisible by 4: the head dim is split in half for the H
        # and W axes, and each half is rotated in 2D (complex) pairs.
        assert dim % 4 == 0, "Embedding dimension must be divisible by 4 for 2D RoPE"
        dim_h = dim // 2
        dim_w = dim // 2
        # Inverse frequencies for the H and W axes.
        # inv_freq_h: [dim_h // 2]
        inv_freq_h = 1.0 / (temperature ** (torch.arange(0, dim_h, 2).float() / dim_h))
        inv_freq_w = 1.0 / (temperature ** (torch.arange(0, dim_w, 2).float() / dim_w))
        self.register_buffer("inv_freq_h", inv_freq_h)
        self.register_buffer("inv_freq_w", inv_freq_w)
        # Lazily built cos/sin lookup tables, grown on demand in _update_cache.
        self.cached_cos_sin_h = None
        self.cached_cos_sin_w = None

    def _update_cache(self, h: int, w: int, device: torch.device, dtype: torch.dtype):
        # Precompute cos/sin tables for every row index < h and column index
        # < w; forward() then gathers rows from these tables by position. The
        # cache is rebuilt only when a larger grid is requested or the
        # device/dtype has changed since the last call.
        def _stale(cache, size):
            return (
                cache is None
                or cache[0].shape[0] < size
                or cache[0].device != device
                or cache[0].dtype != dtype
            )

        if _stale(self.cached_cos_sin_h, h):
            # Angles are computed in float32 for precision, then cast to the
            # working dtype (matters under autocast/half precision).
            t_h = torch.arange(h, device=device, dtype=torch.float32)
            freqs_h = torch.einsum("i,j->ij", t_h, self.inv_freq_h)  # [H, dim_h/2]
            emb_h = torch.cat((freqs_h, freqs_h), dim=-1)  # [H, dim_h]
            self.cached_cos_sin_h = (emb_h.cos().to(dtype), emb_h.sin().to(dtype))
        if _stale(self.cached_cos_sin_w, w):
            t_w = torch.arange(w, device=device, dtype=torch.float32)
            freqs_w = torch.einsum("i,j->ij", t_w, self.inv_freq_w)  # [W, dim_w/2]
            emb_w = torch.cat((freqs_w, freqs_w), dim=-1)  # [W, dim_w]
            self.cached_cos_sin_w = (emb_w.cos().to(dtype), emb_w.sin().to(dtype))

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        pos_ids: torch.Tensor,
        grid_size: Tuple[int, int],
    ):
        # q, k: [B, num_heads, N, head_dim]
        # pos_ids: [B, N] or [N] -- indices of patches in the flattened grid
        # grid_size: (H, W) -- original grid size used to decode pos_ids
        B, num_heads, N, D = q.shape
        H_grid, W_grid = grid_size
        # Decode flattened indices in [0, H*W - 1] into (row, col):
        # h = pos_ids // W_grid, w = pos_ids % W_grid.
        h_idx = pos_ids.div(W_grid, rounding_mode="floor")  # [B, N]
        w_idx = pos_ids % W_grid  # [B, N]
        # Make sure the cached cos/sin tables cover the requested grid.
        self._update_cache(H_grid, W_grid, q.device, q.dtype)
        # If pos_ids are shared across the batch ([N]), broadcast them.
        if h_idx.ndim == 1:
            h_idx = h_idx.unsqueeze(0).expand(B, -1)
            w_idx = w_idx.unsqueeze(0).expand(B, -1)
        # Gather per-position cos/sin from the cached [max_h, dim_h] tables.
        cos_h = F.embedding(h_idx, self.cached_cos_sin_h[0])  # [B, N, dim_h]
        sin_h = F.embedding(h_idx, self.cached_cos_sin_h[1])
        cos_w = F.embedding(w_idx, self.cached_cos_sin_w[0])  # [B, N, dim_w]
        sin_w = F.embedding(w_idx, self.cached_cos_sin_w[1])
        # Split the head dim in half: the first half encodes the row (H)
        # position, the second half the column (W) position.
        dim_half = D // 2
        q_h, q_w = q.split(dim_half, dim=-1)
        k_h, k_w = k.split(dim_half, dim=-1)
        # Add a heads axis so cos/sin broadcast over num_heads:
        # [B, N, dim_h] -> [B, 1, N, dim_h]
        cos_h = cos_h.unsqueeze(1)
        sin_h = sin_h.unsqueeze(1)
        cos_w = cos_w.unsqueeze(1)
        sin_w = sin_w.unsqueeze(1)
        q_h_rot = self._apply_rotary(q_h, cos_h, sin_h)
        k_h_rot = self._apply_rotary(k_h, cos_h, sin_h)
        q_w_rot = self._apply_rotary(q_w, cos_w, sin_w)
        k_w_rot = self._apply_rotary(k_w, cos_w, sin_w)
        q_rot = torch.cat((q_h_rot, q_w_rot), dim=-1)
        k_rot = torch.cat((k_h_rot, k_w_rot), dim=-1)
        return q_rot, k_rot

    def _apply_rotary(
        self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        # x: [B, num_heads, N, dim_half]
        # cos, sin: [B, 1, N, dim_half]
        # Standard RoPE rotation of each 2D pair (x1, x2):
        #   out = [x1*cos - x2*sin, x2*cos + x1*sin]
        # Because the cos/sin tables were built as cat(freqs, freqs), feature i
        # is paired with feature i + dim_half/2: the two halves of x form the
        # pairs, and _rotate_half supplies the rotated partner [-x2, x1].
        return (x * cos) + (self._rotate_half(x) * sin)

    def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
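
# Minimal usage sketch (illustrative, not part of the original module): how
# RotaryEmbedding2D would be driven for a 14x14 patch grid with row-major
# position ids. All sizes below are assumed toy values.
#
#   rope = RotaryEmbedding2D(dim=64)                 # per-head dimension
#   q = k = torch.randn(2, 8, 14 * 14, 64)           # [B, heads, N, head_dim]
#   pos_ids = torch.arange(14 * 14)                  # shared across the batch
#   q_rot, k_rot = rope(q, k, pos_ids, grid_size=(14, 14))
#   assert q_rot.shape == q.shape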


class RoPEAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        rope: Optional[RotaryEmbedding2D] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.rope = rope
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: torch.Tensor,
        pos_ids: Optional[torch.Tensor] = None,
        grid_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # each [B, num_heads, N, head_dim]
        # Rotate q and k in place of additive position embeddings.
        if self.rope is not None and pos_ids is not None and grid_size is not None:
            q, k = self.rope(q, k, pos_ids, grid_size)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
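

if __name__ == "__main__":
    # Smoke test (an illustrative sketch, not part of the original module):
    # forward a 14x14 grid of 256-dim tokens through RoPEAttention and check
    # that shapes are preserved. All sizes here are assumed toy values.
    torch.manual_seed(0)
    B, H, W, dim, heads = 2, 14, 14, 256, 8
    rope = RotaryEmbedding2D(dim=dim // heads)  # RoPE acts on the per-head dim
    attn = RoPEAttention(dim=dim, num_heads=heads, rope=rope)
    x = torch.randn(B, H * W, dim)
    pos_ids = torch.arange(H * W)  # row-major flattened patch indices
    out = attn(x, pos_ids=pos_ids, grid_size=(H, W))
    assert out.shape == (B, H * W, dim)
    print("ok:", tuple(out.shape))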