# Reference: https://github.com/sooftware/conformer import contextlib import math from collections import defaultdict from typing import Dict, List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from torch.nn.attention import SDPBackend, sdpa_kernel def lengths_to_padding_mask( lengths: torch.Tensor, max_len: Optional[int] = None ) -> torch.Tensor: """Create padding mask from lengths. Args: lengths: A 1-D tensor of shape (B,). max_len: An integer. It will be automatically set to the max value of lengths if not given. Returns: A bool tensor of shape (B, max_len), where padded positions are indicated by True. """ batch_size = lengths.size(0) max_len = lengths.max().item() if max_len is None else max_len seq_range = torch.arange(0, max_len, dtype=torch.long, device=lengths.device) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) lengths_expand = lengths.unsqueeze(1).expand_as(seq_range_expand) padding_mask = seq_range_expand >= lengths_expand return padding_mask class SamePad(nn.Module): def __init__(self, kernel_size, causal=False): super().__init__() if causal: self.remove = kernel_size - 1 else: self.remove = 1 if kernel_size % 2 == 0 else 0 def forward(self, x): if self.remove > 0: x = x[:, :, : -self.remove] return x class TransposeLast(nn.Module): def __init__(self, deconstruct_idx=None, tranpose_dim=-2): super().__init__() self.deconstruct_idx = deconstruct_idx self.tranpose_dim = tranpose_dim def forward(self, x): if self.deconstruct_idx is not None: x = x[self.deconstruct_idx] return x.transpose(self.tranpose_dim, -1) class Swish(nn.Module): def __init__(self): super(Swish, self).__init__() def forward(self, inputs: torch.Tensor) -> torch.Tensor: return inputs * inputs.sigmoid() class GLU(nn.Module): def __init__(self, dim: int) -> None: super(GLU, self).__init__() self.dim = dim def forward(self, inputs: torch.Tensor) -> torch.Tensor: outputs, gate = inputs.chunk(2, dim=self.dim) return outputs * gate.sigmoid() class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight class ResidualConnectionModule(nn.Module): def __init__( self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0 ): super(ResidualConnectionModule, self).__init__() self.module = module self.module_factor = module_factor self.input_factor = input_factor def forward(self, inputs: torch.Tensor) -> torch.Tensor: return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor) class Linear(nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: super(Linear, self).__init__() self.linear = nn.Linear(in_features, out_features, bias=bias) nn.init.xavier_uniform_(self.linear.weight) if bias: nn.init.zeros_(self.linear.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) class View(nn.Module): def __init__(self, shape: tuple, contiguous: bool = False): super(View, self).__init__() self.shape = shape self.contiguous = contiguous def forward(self, x: torch.Tensor) -> torch.Tensor: if self.contiguous: x = x.contiguous() return x.view(*self.shape) class Transpose(nn.Module): def __init__(self, shape: tuple): super(Transpose, self).__init__() self.shape = shape def forward(self, x: torch.Tensor) -> torch.Tensor: return x.transpose(*self.shape) class FeedForwardModule(nn.Module): def __init__( self, encoder_dim: int = 512, expansion_factor: int = 4, dropout_p: float = 0.1, rms_norm: bool = False, ) -> None: super(FeedForwardModule, self).__init__() self.sequential = nn.Sequential( nn.LayerNorm(encoder_dim) if not rms_norm else RMSNorm(encoder_dim), Linear(encoder_dim, encoder_dim * expansion_factor, bias=True), Swish(), nn.Dropout(p=dropout_p), Linear(encoder_dim * expansion_factor, encoder_dim, bias=True), nn.Dropout(p=dropout_p), ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: return self.sequential(inputs) class DepthwiseConv1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False, ) -> None: super(DepthwiseConv1d, self).__init__() assert ( out_channels % in_channels == 0 ), "out_channels should be constant multiple of in_channels" self.conv = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=in_channels, stride=stride, padding=padding, bias=bias, ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: return self.conv(inputs) class PointwiseConv1d(nn.Module): def __init__( self, in_channels: int, out_channels: int, stride: int = 1, padding: int = 0, bias: bool = True, ) -> None: super(PointwiseConv1d, self).__init__() self.conv = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding, bias=bias, ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: return self.conv(inputs) class ConformerConvModule(nn.Module): def __init__( self, in_channels: int, kernel_size: int = 31, expansion_factor: int = 2, dropout_p: float = 0.1, rms_norm: bool = False, ) -> None: super(ConformerConvModule, self).__init__() assert ( kernel_size - 1 ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding" assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2" self.sequential = nn.Sequential( nn.LayerNorm(in_channels) if not rms_norm else RMSNorm(in_channels), Transpose(shape=(1, 2)), PointwiseConv1d( in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True, ), GLU(dim=1), DepthwiseConv1d( in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ), nn.BatchNorm1d(in_channels), Swish(), PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True), nn.Dropout(p=dropout_p), ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: return self.sequential(inputs).transpose(1, 2) class FramewiseConv2dSubampling(nn.Module): def __init__(self, out_channels: int, subsample_rate: int = 2) -> None: super(FramewiseConv2dSubampling, self).__init__() assert subsample_rate in {2, 4}, "subsample_rate should be 2 or 4" self.subsample_rate = subsample_rate self.cnn = nn.Sequential( nn.Conv2d(1, out_channels, kernel_size=3, stride=2), nn.ReLU(), nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=(2 if subsample_rate == 4 else 1, 2), padding=(0 if subsample_rate == 4 else 1, 0), ), nn.ReLU(), ) def forward( self, inputs: torch.Tensor, input_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: # inputs: (B, T, C) -> (B, 1, T, C) if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0: inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0) if self.subsample_rate == 4 and inputs.shape[1] % 4 < 3: inputs = F.pad(inputs, (0, 0, 0, 3 - inputs.shape[1] % 4), "constant", 0) outputs = self.cnn(inputs.unsqueeze(1)) batch_size, channels, subsampled_lengths, sumsampled_dim = outputs.size() outputs = outputs.permute(0, 2, 1, 3) outputs = outputs.contiguous().view( batch_size, subsampled_lengths, channels * sumsampled_dim ) if self.subsample_rate == 4: output_lengths = input_lengths >> 2 else: output_lengths = input_lengths >> 1 return outputs, output_lengths def get_out_dim(self, input_dim: int) -> int: # dummy input to get the output dimension with torch.no_grad(): device = next(self.parameters()).device inputs = torch.zeros(1, 16, input_dim, device=device) input_lengths = torch.tensor([16], device=device) outputs, _ = self.forward(inputs, input_lengths) return outputs.size(-1) class PatchwiseConv2dSubampling(nn.Module): def __init__( self, mel_dim: int, out_channels: int, patch_size_time: int = 16, patch_size_freq: int = 16, ) -> None: super(PatchwiseConv2dSubampling, self).__init__() self.mel_dim = mel_dim self.patch_size_time = patch_size_time self.patch_size_freq = patch_size_freq self.proj = nn.Conv2d( 1, out_channels, kernel_size=(patch_size_time, patch_size_freq), stride=(patch_size_time, patch_size_freq), padding=0, ) self.cnn = nn.Sequential( nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(), ) @property def subsample_rate(self) -> int: return self.patch_size_time * self.patch_size_freq // self.mel_dim def forward( self, inputs: torch.Tensor, input_lengths: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: assert ( inputs.shape[2] == self.mel_dim ), "inputs.shape[2] should be equal to mel_dim" # inputs: (B, Time, Freq) -> (B, 1, Time, Freq) outputs = self.proj(inputs.unsqueeze(1)) outputs = self.cnn(outputs) # (B, channels, Time // patch_size_time, Freq // patch_size_freq) outputs = outputs.flatten(2, 3).transpose(1, 2) # (B, (Time // patch_size_time) * (Freq // patch_size_freq), channels) output_lengths = ( input_lengths // self.patch_size_time * (self.mel_dim // self.patch_size_freq) ) return outputs, output_lengths class RelPositionalEncoding(nn.Module): def __init__(self, d_model: int) -> None: super(RelPositionalEncoding, self).__init__() self.d_model = d_model self.pe = None def extend_pe(self, x: torch.Tensor) -> None: if self.pe is not None: if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.dtype != x.dtype or self.pe.device != x.device: self.pe = self.pe.to(dtype=x.dtype, device=x.device) return length = x.size(1) pe_positive = torch.zeros(length, self.d_model, device="cpu") pe_negative = torch.zeros(length, self.d_model, device="cpu") position = torch.arange(0, length, dtype=torch.float32, device="cpu").unsqueeze( 1 ) div_term = torch.exp( torch.arange(0, self.d_model, 2, dtype=torch.float32, device="cpu") * -(math.log(10000.0) / self.d_model) ) pe_positive[:, 0::2] = torch.sin(position * div_term) pe_positive[:, 1::2] = torch.cos(position * div_term) pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) pe_negative = pe_negative[1:].unsqueeze(0) pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, T, C) self.extend_pe(x) assert self.pe is not None pos_emb = self.pe[ :, self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1), ] return pos_emb class RelativeMultiHeadAttention(nn.Module): def __init__( self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1, ): super(RelativeMultiHeadAttention, self).__init__() assert d_model % num_heads == 0, "d_model % num_heads should be zero." self.d_model = d_model self.d_head = int(d_model / num_heads) self.num_heads = num_heads self.sqrt_dim = math.sqrt(self.d_head) self.query_proj = Linear(d_model, d_model) self.key_proj = Linear(d_model, d_model) self.value_proj = Linear(d_model, d_model) self.pos_proj = Linear(d_model, d_model, bias=False) self.dropout = nn.Dropout(p=dropout_p) self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) torch.nn.init.xavier_uniform_(self.u_bias) torch.nn.init.xavier_uniform_(self.v_bias) self.out_proj = Linear(d_model, d_model) @staticmethod def _relative_shift(pos_score: torch.Tensor) -> torch.Tensor: # pos_score: (B, H, T, 2T-1) B, H, T, L = pos_score.size() # Pad on the left of the last dimension: (B, H, T, 2T) pos_score = F.pad(pos_score, (1, 0)) # Reshape to (B, H, 2T, T) pos_score = pos_score.view(B, H, L + 1, T) # Slice and reshape back to (B, H, T, 2T-1) pos_score = pos_score[:, :, 1:].view(B, H, T, L) # Keep only first T positions => (B, H, T, T) return pos_score[:, :, :, : (L // 2 + 1)] def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, pos_embedding: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, *, need_weights: bool = False, use_sdpa: Optional[bool] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ - If need_weights=True: returns (output, attn) like your original code. - If need_weights=False: returns (output, None) and uses SDPA in eval for speed/memory. """ B, Tq, _ = query.size() _, Tk, _ = key.size() # Project q = self.query_proj(query) # (B, Tq, C) k = self.key_proj(key) # (B, Tk, C) v = self.value_proj(value) # (B, Tk, C) # Reshape to (B, H, T, Dh) q = q.view(B, Tq, self.num_heads, self.d_head).transpose(1, 2) # (B,H,Tq,Dh) k = k.view(B, Tk, self.num_heads, self.d_head).transpose(1, 2) # (B,H,Tk,Dh) v = v.view(B, Tk, self.num_heads, self.d_head).transpose(1, 2) # (B,H,Tk,Dh) # Positional projection. # IMPORTANT: allow pos_embedding to be (1, 2T-1, C) and broadcast across batch. # pos_embedding expected length: 2Tq - 1 for self-attn. pB = pos_embedding.size(0) p = self.pos_proj(pos_embedding) # (pB, L, C) p = p.view(pB, -1, self.num_heads, self.d_head).transpose(1, 2) # (pB,H,L,Dh) # Compute position-based bias (scaled) to feed SDPA or add to scores # q_pos: (B,H,Tq,Dh), p^T: (pB,H,Dh,L) -> broadcast on pB if pB==1 q_pos = q + self.v_bias.unsqueeze(0).unsqueeze(2) # (B,H,Tq,Dh) pos_score = torch.matmul(q_pos, p.transpose(-2, -1)) # (B,H,Tq,L) pos_bias = self._relative_shift(pos_score) # (B,H,Tq,Tq) for self-attn pos_bias = pos_bias.mul(1.0 / self.sqrt_dim) # scale matches SDPA scaling if padding_mask is not None: # padding_mask: (B, T) -> (B, 1, 1, T) to broadcast with pos_bias (B, H, Tq, Tk) # This masks out key positions that are padded across all heads and queries if padding_mask.dtype != torch.bool: padding_mask = padding_mask.to(torch.bool) pos_bias = pos_bias.masked_fill(padding_mask[:, None, None, :], -1e9) if use_sdpa is None: use_sdpa = (not self.training) and (not need_weights) # ---- Fast inference path: no attention matrix materialized ---- if use_sdpa: # Content term uses u_bias q_content = q + self.u_bias.unsqueeze(0).unsqueeze(2) # (B,H,Tq,Dh) with sdpa_kernel( [ SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH, ] ): out = F.scaled_dot_product_attention( q_content, # (B,H,Tq,Dh) k, # (B,H,Tk,Dh) v, # (B,H,Tk,Dh) attn_mask=pos_bias, # (B,H,Tq,Tk) additive bias dropout_p=0.0, # dropout disabled in inference is_causal=False, ) # (BH, Tq, Dh) out = out.transpose(1, 2).contiguous().view(B, Tq, self.d_model) return self.out_proj(out), None # ---- Reference path (training / if you need attn weights): matches your math ---- q_content = q + self.u_bias.unsqueeze(0).unsqueeze(2) # (B,H,Tq,Dh) content_score = torch.matmul(q_content, k.transpose(-2, -1)) # (B,H,Tq,Tk) content_score = content_score.mul(1.0 / self.sqrt_dim) score = content_score + pos_bias # already scaled attn = F.softmax(score, dim=-1) attn = self.dropout(attn) context = torch.matmul(attn, v) # (B,H,Tq,Dh) context = context.transpose(1, 2).contiguous().view(B, Tq, self.d_model) return self.out_proj(context), attn class MultiHeadedSelfAttentionModule(nn.Module): def __init__( self, d_model: int, num_heads: int, dropout_p: float = 0.1, rms_norm: bool = False, ): super(MultiHeadedSelfAttentionModule, self).__init__() self.positional_encoding = RelPositionalEncoding(d_model) self.layer_norm = nn.LayerNorm(d_model) if not rms_norm else RMSNorm(d_model) self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p) self.dropout = nn.Dropout(p=dropout_p) def forward( self, inputs: torch.Tensor, padding_mask: Optional[torch.Tensor] = None, pos_embedding: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if pos_embedding is None: pos_embedding = self.positional_encoding(inputs) inputs = self.layer_norm(inputs) outputs, attn = self.attention( inputs, inputs, inputs, pos_embedding=pos_embedding, padding_mask=padding_mask, ) return self.dropout(outputs), attn, pos_embedding class ConformerBlock(nn.Module): def __init__( self, encoder_dim: int = 512, attention_type: str = "mhsa", num_attention_heads: int = 8, feed_forward_expansion_factor: int = 4, conv_expansion_factor: int = 2, feed_forward_dropout_p: float = 0.1, attention_dropout_p: float = 0.1, conv_dropout_p: float = 0.1, conv_kernel_size: int = 31, half_step_residual: bool = True, transformer_style: bool = False, usad_v2: bool = False, pre_norm: bool = False, rms_norm: bool = False, ): super(ConformerBlock, self).__init__() self.transformer_style = transformer_style self.attention_type = attention_type self.usad_v2 = usad_v2 self.pre_norm = pre_norm if half_step_residual and not transformer_style: self.feed_forward_residual_factor = 0.5 else: self.feed_forward_residual_factor = 1 assert ( attention_type == "mhsa" ), "Only 'mhsa' attention is supported in this implementation." attention = MultiHeadedSelfAttentionModule( d_model=encoder_dim, num_heads=num_attention_heads, dropout_p=attention_dropout_p, rms_norm=rms_norm, ) self.ffn_1 = FeedForwardModule( encoder_dim=encoder_dim, expansion_factor=feed_forward_expansion_factor, dropout_p=feed_forward_dropout_p, rms_norm=rms_norm, ) self.attention = attention if not transformer_style: self.conv = ConformerConvModule( in_channels=encoder_dim, kernel_size=conv_kernel_size, expansion_factor=conv_expansion_factor, dropout_p=conv_dropout_p, rms_norm=rms_norm, ) self.ffn_2 = FeedForwardModule( encoder_dim=encoder_dim, expansion_factor=feed_forward_expansion_factor, dropout_p=feed_forward_dropout_p, rms_norm=rms_norm, ) self.layernorm = ( (nn.LayerNorm(encoder_dim) if not rms_norm else RMSNorm(encoder_dim)) if not pre_norm else nn.Identity() ) def forward_attention( self, x: torch.Tensor, pos_embedding: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: attn_out, attn, pos_embedding = self.attention( x, pos_embedding=pos_embedding, padding_mask=padding_mask ) return attn_out, attn, pos_embedding def forward_legacy( self, x: torch.Tensor, pos_embedding: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]: # FFN 1 ffn_1_out = self.ffn_1(x) x = ffn_1_out * self.feed_forward_residual_factor + x # Attention attn_out, attn, pos_embedding = self.forward_attention( x, pos_embedding, padding_mask ) x = attn_out + x if self.transformer_style: x = self.layernorm(x) return x, {"ffn_1": ffn_1_out, "attn": attn, "conv": None, "ffn_2": None} # Convolution conv_out = self.conv(x) x = conv_out + x # FFN 2 ffn_2_out = self.ffn_2(x) x = ffn_2_out * self.feed_forward_residual_factor + x x = self.layernorm(x) other = { "ffn_1": ffn_1_out, "attn": attn, "conv": conv_out, "ffn_2": ffn_2_out, "pos_embedding": pos_embedding, } return x, other def forward_transformer( self, x: torch.Tensor, pos_embedding: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]: # Attention attn_out, attn, pos_embedding = self.forward_attention( x, pos_embedding, padding_mask ) x = attn_out + x # FFN ffn_out = self.ffn_1(x) x = ffn_out * self.feed_forward_residual_factor + x x = self.layernorm(x) return x, { "ffn_1": ffn_out, "attn": attn, "conv": None, "ffn_2": None, "pos_embedding": pos_embedding, } def forward_conformer( self, x: torch.Tensor, pos_embedding: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]: # FFN 1 ffn_1_out = self.ffn_1(x) x = ffn_1_out * self.feed_forward_residual_factor + x # Attention attn_out, attn, pos_embedding = self.forward_attention( x, pos_embedding, padding_mask ) x = attn_out + x # Convolution conv_out = self.conv(x) x = conv_out + x # FFN 2 ffn_2_out = self.ffn_2(x) x = ffn_2_out * self.feed_forward_residual_factor + x x = self.layernorm(x) other = { "ffn_1": ffn_1_out, "attn": attn, "conv": conv_out, "ffn_2": ffn_2_out, "pos_embedding": pos_embedding, } return x, other def forward( self, x: torch.Tensor, pos_embedding: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, Optional[torch.Tensor]]]: if not self.usad_v2: return self.forward_legacy(x, pos_embedding, padding_mask) if self.transformer_style: return self.forward_transformer(x, pos_embedding, padding_mask) return self.forward_conformer(x, pos_embedding, padding_mask) class ConformerEncoder(nn.Module): def __init__(self, cfg): super(ConformerEncoder, self).__init__() self.cfg = cfg self.framewise_subsample = None self.patchwise_subsample = None self.framewise_in_proj = None self.patchwise_in_proj = None assert ( cfg.use_framewise_subsample or cfg.use_patchwise_subsample ), "At least one subsampling method should be used" if cfg.use_framewise_subsample: self.framewise_subsample = FramewiseConv2dSubampling( out_channels=cfg.conv_subsample_channels, subsample_rate=cfg.conv_subsample_rate, ) self.framewise_in_proj = nn.Sequential( Linear( self.framewise_subsample.get_out_dim(cfg.input_dim), cfg.encoder_dim, ), nn.Dropout(p=cfg.input_dropout_p), ) if cfg.use_patchwise_subsample: self.patchwise_subsample = PatchwiseConv2dSubampling( mel_dim=cfg.input_dim, out_channels=cfg.conv_subsample_channels, patch_size_time=cfg.patch_size_time, patch_size_freq=cfg.patch_size_freq, ) self.patchwise_in_proj = nn.Sequential( Linear( cfg.conv_subsample_channels, cfg.encoder_dim, ), nn.Dropout(p=cfg.input_dropout_p), ) assert not cfg.use_framewise_subsample or ( cfg.conv_subsample_rate == self.patchwise_subsample.subsample_rate ), ( f"conv_subsample_rate ({cfg.conv_subsample_rate}) != patchwise_subsample.subsample_rate" f"({self.patchwise_subsample.subsample_rate})" ) self.framewise_norm, self.patchwise_norm = None, None if getattr(cfg, "subsample_normalization", False): if cfg.use_framewise_subsample: self.framewise_norm = ( nn.LayerNorm(cfg.encoder_dim) if not getattr(cfg, "rms_norm", False) else RMSNorm(cfg.encoder_dim) ) if cfg.use_patchwise_subsample: self.patchwise_norm = ( nn.LayerNorm(cfg.encoder_dim) if not getattr(cfg, "rms_norm", False) else RMSNorm(cfg.encoder_dim) ) self.conv_pos = None self.conv_pos_post_ln = None if cfg.conv_pos: num_pos_layers = cfg.conv_pos_depth k = max(3, cfg.conv_pos_width // num_pos_layers) self.conv_pos = nn.Sequential( TransposeLast(), *[ nn.Sequential( nn.Conv1d( cfg.encoder_dim, cfg.encoder_dim, kernel_size=k, padding=k // 2, groups=cfg.conv_pos_groups, ), SamePad(k), TransposeLast(), nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False), TransposeLast(), nn.GELU(), ) for _ in range(num_pos_layers) ], TransposeLast(), ) self.conv_pos_post_ln = ( ( nn.LayerNorm(cfg.encoder_dim) if not getattr(cfg, "rms_norm", False) else RMSNorm(cfg.encoder_dim) ) if not getattr(cfg, "pre_norm", False) else nn.Identity() ) self.layers = nn.ModuleList( [ ConformerBlock( encoder_dim=cfg.encoder_dim, attention_type=cfg.attention_type, num_attention_heads=cfg.num_attention_heads, feed_forward_expansion_factor=cfg.feed_forward_expansion_factor, conv_expansion_factor=cfg.conv_expansion_factor, feed_forward_dropout_p=cfg.feed_forward_dropout_p, attention_dropout_p=cfg.attention_dropout_p, conv_dropout_p=cfg.conv_dropout_p, conv_kernel_size=cfg.conv_kernel_size, half_step_residual=cfg.half_step_residual, transformer_style=getattr(cfg, "transformer_style", False), usad_v2=getattr(cfg, "usad_v2", False), pre_norm=getattr(cfg, "pre_norm", False), rms_norm=getattr(cfg, "rms_norm", False), ) for _ in range(cfg.num_layers) ] ) self.layerdrop_p = getattr(cfg, "layerdrop_p", 0.0) if cfg.attention_type == "mhsa" and len(self.layers) > 0: # Share positional encoding across layers shared_pos = None for layer in self.layers: if isinstance(layer.attention, MultiHeadedSelfAttentionModule): if shared_pos is None: shared_pos = layer.attention.positional_encoding else: layer.attention.positional_encoding = shared_pos if shared_pos is not None: # precompute positional encodings # expecting most mel inputs to be fewer than 2000 frames (20 seconds) max_len = 2000 // cfg.conv_subsample_rate shared_pos.extend_pe(torch.tensor(0.0).expand(1, max_len)) def count_parameters(self) -> int: """Count parameters of encoder""" return sum([p.numel() for p in self.parameters() if p.requires_grad]) def update_dropout(self, dropout_p: float) -> None: """Update dropout probability of encoder""" for name, child in self.named_children(): if isinstance(child, nn.Dropout): child.p = dropout_p def forward( self, inputs: torch.Tensor, input_lengths: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, *, return_hidden: bool = False, freeze_input_layers: bool = False, target_layer: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, List[torch.Tensor]]]: if input_lengths is None: input_lengths = torch.full( (inputs.size(0),), inputs.size(1), dtype=torch.long, device=inputs.device, ) with torch.no_grad() if freeze_input_layers else contextlib.ExitStack(): frame_feat, patch_feat = None, None frame_lengths, patch_lengths = None, None if self.framewise_subsample is not None: assert self.framewise_in_proj is not None frame_feat, frame_lengths = self.framewise_subsample( inputs, input_lengths ) frame_feat = self.framewise_in_proj(frame_feat) if self.framewise_norm is not None: frame_feat = self.framewise_norm(frame_feat) if self.patchwise_subsample is not None: assert self.patchwise_in_proj is not None patch_feat, patch_lengths = self.patchwise_subsample( inputs, input_lengths ) patch_feat = self.patchwise_in_proj(patch_feat) if self.patchwise_norm is not None: patch_feat = self.patchwise_norm(patch_feat) assert frame_feat is not None or patch_feat is not None assert frame_lengths is not None or patch_lengths is not None if frame_feat is not None and patch_feat is not None: assert frame_lengths is not None and patch_lengths is not None min_len = min(frame_feat.size(1), patch_feat.size(1)) frame_feat = frame_feat[:, :min_len] patch_feat = patch_feat[:, :min_len] features = frame_feat + patch_feat output_lengths = ( frame_lengths if frame_lengths.max().item() < patch_lengths.max().item() else patch_lengths ) elif frame_feat is not None: features = frame_feat output_lengths = frame_lengths else: features = patch_feat output_lengths = patch_lengths assert features is not None assert output_lengths is not None # Positional encoding with convolutional layers if self.conv_pos is not None and self.conv_pos_post_ln is not None: pos = self.conv_pos(features) if not self.training: features = features.add_(pos) else: features = features + pos features = self.conv_pos_post_ln(features) # Create padding mask for attention if padding_mask is not None: # downsample to match features length input_len = padding_mask.size(1) feat_len = features.size(1) factor = input_len / feat_len indices = ( torch.arange(feat_len, device=padding_mask.device) * factor ).long() padding_mask = padding_mask.index_select(1, indices) else: # create from output_lengths padding_mask = lengths_to_padding_mask( output_lengths, max_len=features.size(1) ) layer_results = defaultdict(list) outputs = features other = {} for i, layer in enumerate(self.layers): if ( self.training and self.layerdrop_p > 0 and torch.rand(1).item() < self.layerdrop_p ): continue outputs, other = layer( outputs, pos_embedding=other.get("pos_embedding"), padding_mask=padding_mask, ) if return_hidden: layer_results["hidden_states"].append(outputs) for k, v in other.items(): layer_results[k].append(v) if target_layer is not None and i + 1 == target_layer: break return outputs, output_lengths, layer_results