import torch
import torch.nn as nn
from timm.layers import Mlp, build_sincos2d_pos_embed, DropPath
from src.models.components.rope import RoPEAttention, RotaryEmbedding2D
from typing import Optional, Tuple
class RoPEBlock(nn.Module):
"""
Transformer Block with RoPE support.
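    Args:
        dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        qkv_bias (bool): Enable bias for QKV projections.
        proj_drop (float): Dropout rate after attention/MLP projections.
        attn_drop (float): Attention dropout rate.
        drop_path (float): Stochastic depth rate for this block.
        act_layer (nn.Module): Activation layer.
        norm_layer (nn.Module): Normalization layer.
        rope (Optional[RotaryEmbedding2D]): Shared rotary embedding module; None disables RoPE.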
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
proj_drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.LayerNorm,
rope: Optional[RotaryEmbedding2D] = None,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = RoPEAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
rope=rope,
)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(
self,
x: torch.Tensor,
pos_ids: Optional[torch.Tensor] = None,
grid_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
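        # Pre-norm residual block: attention (RoPE applied via pos_ids / grid_size when enabled), then MLP.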
x = x + self.drop_path(
self.attn(self.norm1(x), pos_ids=pos_ids, grid_size=grid_size)
)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class ViT(nn.Module):
"""
Vision Transformer with support for RoPE and 2D positional embeddings.
Args:
embed_dim (int): Embedding dimension.
depth (int): Number of transformer blocks.
num_heads (int): Number of attention heads.
mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
qkv_bias (bool): Enable bias for QKV projections.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
num_patches (int): Total number of patches (used for learnable/sincos pos embed).
img_size (tuple[int, int]): Input image size (H, W).
patch_size (tuple[int, int]): Patch size (H, W).
pos_embed_type (str): Type of positional embedding ("rope", "sincos", "learnable").
"""
def __init__(
self,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
act_layer: nn.Module = nn.GELU,
num_patches: int = 128,
        img_size: Tuple[int, int] = (128, 256),
        patch_size: Tuple[int, int] = (16, 16),
pos_embed_type: str = "rope",
):
super().__init__()
self.embed_dim = embed_dim
self.num_patches = num_patches
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.pos_embed_type = pos_embed_type
# Positional Embeddings
if pos_embed_type == "rope":
head_dim = embed_dim // num_heads
self.rope = RotaryEmbedding2D(dim=head_dim, max_res=self.grid_size)
self.pos_embed = None
elif pos_embed_type == "sincos":
self.rope = None
# build_sincos2d_pos_embed(feat_shape, dim, ...)
# We assume grid_size matches num_patches
pos_embed = build_sincos2d_pos_embed(self.grid_size, embed_dim)
self.register_buffer("pos_embed", pos_embed.unsqueeze(0)) # [1, N, D]
elif pos_embed_type == "learnable":
self.rope = None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
nn.init.trunc_normal_(self.pos_embed, std=0.02)
else:
raise ValueError(f"Unknown pos_embed_type: {pos_embed_type}")
# Stochastic Depth
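        # Per-block drop-path rates increase linearly from 0 to drop_path_rate.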
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList(
[
RoPEBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
rope=self.rope,
)
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(
self,
x: torch.Tensor,
pos_ids: Optional[torch.Tensor] = None,
add_pos_embed: bool = True,
grid_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
"""
Forward pass.
Args:
x (torch.Tensor): Input tensor [B, N, D].
pos_ids (Optional[torch.Tensor]): Positional indices [B, N] or [N].
add_pos_embed (bool): Whether to add positional embeddings (for non-RoPE).
grid_size (Optional[Tuple[int, int]]): Grid size for RoPE/PosEmbed.
Returns:
torch.Tensor: Output tensor [B, N, D].
"""
# Determine grid size
if grid_size is None:
if pos_ids is None:
                # Infer the grid from the sequence length, keeping the default
                # number of rows (assumes N is divisible by it)
                B, N, D = x.shape
                H_grid = self.grid_size[0]
                W_grid = N // H_grid
                current_grid_size = (H_grid, W_grid)
else:
# Cannot infer, use default (might be wrong if variable length)
current_grid_size = self.grid_size
else:
current_grid_size = grid_size
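        # Additive positional embeddings (sincos / learnable); RoPE is instead applied inside attention.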
if self.pos_embed_type != "rope" and add_pos_embed:
if pos_ids is not None:
# Select positional embeddings
if pos_ids.ndim == 1:
# Shared pos_ids across batch
pos_embed = self.pos_embed[:, pos_ids, :] # [1, N_subset, D]
else:
# Different pos_ids per sample
pos_embed = self.pos_embed.expand(x.shape[0], -1, -1)
pos_embed = torch.gather(
pos_embed,
1,
pos_ids.unsqueeze(-1).expand(-1, -1, self.embed_dim),
)
x = x + pos_embed
else:
# Assume full sequence
if x.shape[1] == self.num_patches:
x = x + self.pos_embed
elif (
self.pos_embed is not None and x.shape[1] <= self.pos_embed.shape[1]
):
x = x + self.pos_embed[:, : x.shape[1], :]
        # RoPE consumes positional indices rather than additive embeddings.
        # If none were provided, default to the full sequence 0..N-1.
        if self.pos_embed_type == "rope" and pos_ids is None:
            B, N, D = x.shape
            pos_ids = torch.arange(N, device=x.device)
for block in self.blocks:
x = block(x, pos_ids=pos_ids, grid_size=current_grid_size)
x = self.norm(x)
return x
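

# Minimal usage sketch (illustrative, not part of the model definition). It assumes
# the RoPEAttention / RotaryEmbedding2D components imported above behave as used in
# this file; the "learnable" positional-embedding path is chosen so the example only
# exercises the standard attention code path with rope=None.
if __name__ == "__main__":
    model = ViT(
        embed_dim=192,
        depth=2,
        num_heads=3,
        num_patches=128,
        img_size=(128, 256),
        patch_size=(16, 16),
        pos_embed_type="learnable",
    )
    tokens = torch.randn(2, 128, 192)  # [B, N, D] pre-embedded patch tokens
    out = model(tokens)
    print(out.shape)  # expected: torch.Size([2, 128, 192])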