from typing import Optional, Tuple

import torch
import torch.nn as nn
from timm.layers import DropPath, Mlp, build_sincos2d_pos_embed

from src.models.components.rope import RoPEAttention, RotaryEmbedding2D


class RoPEBlock(nn.Module):
    """
    Pre-norm Transformer block with optional rotary position embedding (RoPE)
    support in its attention layer. See the commented usage sketch below the
    class definition.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        rope: Optional[RotaryEmbedding2D] = None,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = RoPEAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            rope=rope,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        pos_ids: Optional[torch.Tensor] = None,
        grid_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        # Pre-norm attention and MLP sub-layers, each with a residual connection
        # and (optional) stochastic depth.
        x = x + self.drop_path(
            self.attn(self.norm1(x), pos_ids=pos_ids, grid_size=grid_size)
        )
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
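
# Illustrative usage of RoPEBlock (a sketch only; RotaryEmbedding2D's constructor
# arguments are assumed from how ViT builds it below, and RoPEAttention comes from
# the project's src.models.components.rope module):
#
#   rope = RotaryEmbedding2D(dim=768 // 12, max_res=(8, 16))
#   block = RoPEBlock(dim=768, num_heads=12, rope=rope)
#   x = torch.randn(2, 8 * 16, 768)                   # [B, N, D] patch tokens
#   pos_ids = torch.arange(8 * 16)                    # one index per patch
#   y = block(x, pos_ids=pos_ids, grid_size=(8, 16))  # -> [2, 128, 768]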


class ViT(nn.Module):
    """
    Vision Transformer with support for RoPE, 2D sin-cos, or learnable positional
    embeddings. Operates on pre-embedded patch tokens of shape [B, N, D].

    Args:
        embed_dim (int): Embedding dimension.
        depth (int): Number of transformer blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        qkv_bias (bool): Enable bias for QKV projections.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        drop_path_rate (float): Stochastic depth rate.
        norm_layer (nn.Module): Normalization layer.
        act_layer (nn.Module): Activation layer.
        num_patches (int): Total number of patches, i.e. the full sequence length
            (used for the learnable pos embed and the full-sequence check in forward).
        img_size (Tuple[int, int]): Input image size (H, W).
        patch_size (Tuple[int, int]): Patch size (H, W).
        pos_embed_type (str): Type of positional embedding ("rope", "sincos", "learnable").
    """

    def __init__(
        self,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        act_layer: nn.Module = nn.GELU,
        num_patches: int = 128,
        img_size: Tuple[int, int] = (128, 256),
        patch_size: Tuple[int, int] = (16, 16),
        pos_embed_type: str = "rope",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_patches = num_patches
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.pos_embed_type = pos_embed_type

        # Positional embeddings
        if pos_embed_type == "rope":
            # Rotary embeddings are applied inside attention; no additive table.
            head_dim = embed_dim // num_heads
            self.rope = RotaryEmbedding2D(dim=head_dim, max_res=self.grid_size)
            self.pos_embed = None
        elif pos_embed_type == "sincos":
            self.rope = None
            # build_sincos2d_pos_embed(feat_shape, dim, ...) returns [N, D];
            # we assume grid_size matches num_patches.
            pos_embed = build_sincos2d_pos_embed(self.grid_size, embed_dim)
            self.register_buffer("pos_embed", pos_embed.unsqueeze(0))  # [1, N, D]
        elif pos_embed_type == "learnable":
            self.rope = None
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            raise ValueError(f"Unknown pos_embed_type: {pos_embed_type}")

        # Stochastic depth decay rule: rate grows linearly over the blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList(
            [
                RoPEBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    rope=self.rope,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        x: torch.Tensor,
        pos_ids: Optional[torch.Tensor] = None,
        add_pos_embed: bool = True,
        grid_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor [B, N, D].
            pos_ids (Optional[torch.Tensor]): Positional indices [B, N] or [N].
            add_pos_embed (bool): Whether to add positional embeddings (for non-RoPE).
            grid_size (Optional[Tuple[int, int]]): Grid size for RoPE/PosEmbed.

        Returns:
            torch.Tensor: Output tensor [B, N, D].
        """
        # Determine grid size
        if grid_size is None:
            if pos_ids is None:
                # Infer from x, assuming a full sequence with fixed grid height.
                N = x.shape[1]
                H_grid = self.grid_size[0]
                W_grid = N // H_grid
                current_grid_size = (H_grid, W_grid)
            else:
                # Cannot infer; fall back to the default (may be wrong for
                # variable-length inputs).
                current_grid_size = self.grid_size
        else:
            current_grid_size = grid_size

        # Additive positional embeddings (sincos / learnable).
        if self.pos_embed_type != "rope" and add_pos_embed:
            if pos_ids is not None:
                # Select positional embeddings for the kept positions.
                if pos_ids.ndim == 1:
                    # Shared pos_ids across the batch.
                    pos_embed = self.pos_embed[:, pos_ids, :]  # [1, N_subset, D]
                else:
                    # Different pos_ids per sample.
                    pos_embed = self.pos_embed.expand(x.shape[0], -1, -1)
                    pos_embed = torch.gather(
                        pos_embed,
                        1,
                        pos_ids.unsqueeze(-1).expand(-1, -1, self.embed_dim),
                    )
                x = x + pos_embed
            else:
                # Assume a full sequence (or a prefix of it).
                if x.shape[1] == self.num_patches:
                    x = x + self.pos_embed
                elif (
                    self.pos_embed is not None and x.shape[1] <= self.pos_embed.shape[1]
                ):
                    x = x + self.pos_embed[:, : x.shape[1], :]

        # RoPE needs pos_ids; if none are provided, default to 0..N-1 over the
        # inferred grid.
        if self.pos_embed_type == "rope" and pos_ids is None:
            pos_ids = torch.arange(x.shape[1], device=x.device)

        for block in self.blocks:
            x = block(x, pos_ids=pos_ids, grid_size=current_grid_size)

        x = self.norm(x)
        return x
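

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the model itself. It assumes the
# project's src.models.components.rope module (imported at the top of this
# file) is available. There is no patch-embedding layer here, so random patch
# tokens of shape [B, N, D] stand in for embedded spectrogram patches.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, N, D = 2, 128, 768  # batch, tokens (8 x 16 grid), embedding dim

    for pe_type in ("rope", "sincos", "learnable"):
        model = ViT(
            embed_dim=D,
            depth=2,  # small depth keeps the smoke test fast
            num_heads=12,
            num_patches=N,
            img_size=(128, 256),
            patch_size=(16, 16),
            pos_embed_type=pe_type,
        )
        out = model(torch.randn(B, N, D))
        print(pe_type, tuple(out.shape))  # expected: (2, 128, 768) for each type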