import math

import torch
from torch import nn


class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Post-norm ``nn.TransformerEncoderLayer`` that also returns attention.

    Identical math to the parent class (``norm_first=False`` layout), except
    ``forward`` returns ``(output, attn_weights)`` with one attention map per
    head, so callers can record the attention flow layer by layer.
    """

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        """Run one encoder layer.

        Args:
            src: Input of shape ``(batch, seq, d_model)`` (``batch_first=True``).
            src_mask: Optional additive/boolean attention mask.
            src_key_padding_mask: Optional per-token padding mask.
            is_causal: Accepted for compatibility with newer
                ``nn.TransformerEncoder`` call signatures but intentionally
                ignored; causal masking, if needed, must come via ``src_mask``.

        Returns:
            Tuple ``(output, attn_weights)`` where ``attn_weights`` has shape
            ``(batch, n_heads, seq, seq)`` (``average_attn_weights=False``).
        """
        # Self-attention with per-head weights exposed.
        src2, attn_weights = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            average_attn_weights=False,  # keep one map per head
            need_weights=True,
        )
        # Post-norm residual blocks, mirroring the parent implementation.
        src = self.norm1(src + self.dropout1(src2))
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))
        return src, attn_weights


class SingleTransformer(nn.Module):
    """Transformer-based model for a single modality ('RNA', 'ATAC' or 'Flux').

    Args:
        model_type (str): One of 'RNA', 'ATAC', 'Flux'. 'RNA' uses fixed
            sinusoidal count embeddings indexed by integer counts; the others
            project each scalar value with a linear layer.
        vocab_size (int): Vocabulary size (set 1 when projection is used).
        seq_len (int): Maximum sequence length (size of the ID embedding table).
        n_encoder_layers (int): Number of transformer encoder layers.
        n_heads (int): Number of attention heads (must divide ``d_model``).
        n_batches (int): Number of experimental batches (batch-effect embedding).
        d_model (int): Dimension of the token embeddings.
        d_ff (int): Dimension of the feedforward layer.
        dropout_rate (float, optional): Dropout rate. Defaults to 0.0.

    Attributes:
        count_embedding_fix / count_embedding_proj: Count embedding (fixed
            sinusoidal table for RNA, linear projection otherwise).
        id_embeddings (nn.Parameter): Learned per-position ID embeddings.
        batch_embedding (nn.Embedding): Batch-effect embeddings.
        cls_token (nn.Parameter): CLS token used by the pooling head.
        encoder (nn.TransformerEncoder): Stack of CustomTransformerEncoderLayer.
        mask_output_layer (nn.Linear): Head for masked-LM pretraining.
        cls_attention / cls_norm1 / cls_norm2 / cls_ffn / cls_output_layer:
            CLS pooling head producing a sigmoid prediction.
        pretrained (bool): Set by (un)freeze_pretrained_weights.
    """

    def __init__(self, model_type, vocab_size, seq_len, n_encoder_layers,
                 n_heads, n_batches, d_model, d_ff, dropout_rate=0.0):
        super().__init__()
        if model_type not in ['RNA', 'ATAC', 'Flux']:
            raise ValueError("model_type must be one of 'RNA', 'ATAC', 'Flux'")
        self.model_type = model_type

        # RNA inputs are integer counts -> fixed sinusoidal lookup table.
        # Other modalities carry continuous values -> scalar linear projection.
        if self.model_type == 'RNA':
            self.count_embedding_fix = self.create_count_embeddings(vocab_size, d_model)
        else:
            self.count_embedding_proj = nn.Linear(1, d_model)

        self.id_embeddings = nn.Parameter(torch.zeros(1, seq_len, d_model))
        nn.init.normal_(self.id_embeddings, mean=0.0, std=0.02)
        self.batch_embedding = nn.Embedding(n_batches, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.token_layer_norm = nn.LayerNorm(d_model)
        self.batch_layer_norm = nn.LayerNorm(d_model)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_token, mean=0.0, std=0.02)

        encoder_layer = CustomTransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout_rate, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_encoder_layers)

        # Masked-LM head (per-token logits over the count vocabulary).
        self.mask_output_layer = nn.Linear(d_model, vocab_size)

        # CLS pooling head: cross-attention from a learned CLS query.
        self.cls_attention = nn.MultiheadAttention(embed_dim=d_model,
                                                   num_heads=n_heads,
                                                   batch_first=True)
        self.cls_norm1 = nn.LayerNorm(d_model)
        self.cls_norm2 = nn.LayerNorm(d_model)
        self.cls_ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(d_ff, d_model),
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.cls_output_layer = nn.Linear(d_model, 1)

    def forward(self, x, batch_indices, masked_lm=False, return_attention=False,
                return_embeddings=False, return_flow_attention=False):
        """Forward pass.

        Args:
            x: ``(batch, seq)`` counts (integer-valued for RNA, float otherwise).
            batch_indices: ``(batch,)`` integer batch-effect indices.
            masked_lm: Return per-token vocabulary logits (batch token excluded).
            return_attention: Also return CLS attention weights.
            return_embeddings: Return ``(encoder_tokens, attention_flow)``
                before the CLS head.
            return_flow_attention: Collect per-layer encoder attention maps.

        Returns:
            Depending on the flags: per-token logits, encoder embeddings, or
            ``(preds, cls_output[, attention_weights[, attention_flow]])``
            where ``preds`` is a sigmoid score of shape ``(batch, 1)``.
        """
        if self.model_type == 'RNA':
            # Plain tensor attribute (not a buffer), so move it manually.
            self.count_embedding_fix = self.count_embedding_fix.to(x.device)
            x = self.count_embedding_fix[x.long()]
        else:
            x = self.count_embedding_proj(x.unsqueeze(-1).float())

        # Add positional (feature-ID) embeddings, then append the batch token.
        x = x + self.id_embeddings[:, :x.size(1), :]
        batch_embeddings = self.batch_embedding(batch_indices).unsqueeze(1)
        x = torch.cat((x, batch_embeddings), dim=1)

        # Run the encoder layer by layer so attention maps can be collected.
        attention_flow = []
        for layer in self.encoder.layers:
            x, attn_weights = layer(x)
            if return_flow_attention:
                attention_flow.append(attn_weights)
        other_tokens = x

        if return_embeddings:
            return other_tokens, attention_flow

        if masked_lm:
            # Exclude the trailing batch token from masked-LM prediction.
            return self.mask_output_layer(other_tokens[:, :-1, :])

        # CLS pooling: one learned query attends over all encoder tokens.
        cls_token = self.cls_token.expand(x.size(0), -1, -1)
        attended_cls, attention_weights = self.cls_attention(
            cls_token, other_tokens, other_tokens,
            need_weights=True, average_attn_weights=False,
        )
        attended_cls = attended_cls.squeeze(1)
        cls_output = self.cls_norm1(cls_token.squeeze(1) + self.dropout(attended_cls))
        cls_output = self.cls_norm2(cls_output + self.dropout(self.cls_ffn(cls_output)))
        preds = torch.sigmoid(self.cls_output_layer(cls_output))

        if return_flow_attention:
            return preds, cls_output, attention_weights, attention_flow
        elif return_attention:
            return preds, cls_output, attention_weights
        else:
            return preds, cls_output

    def freeze_pretrained_weights(self):
        """Freeze everything except the CLS pooling head."""
        for name, param in self.named_parameters():
            if not any(x in name for x in ['cls_attention', 'cls_norm', 'cls_ffn',
                                           'cls_token', 'cls_output_layer']):
                param.requires_grad = False
        self.pretrained = True

    def unfreeze_pretrained_weights(self):
        """Make all parameters trainable again."""
        for param in self.parameters():
            param.requires_grad = True
        self.pretrained = False

    def create_count_embeddings(self, max_count, embed_size):
        """Create fixed sinusoidal embeddings for counts 0..max_count.

        Vectorized form of the standard positional-encoding formula:
        even dims get sin, odd dims get cos, sharing frequency per dim pair.

        Returns:
            torch.Tensor of shape ``(max_count + 1, embed_size)``.
        """
        position = torch.arange(max_count + 1, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2, dtype=torch.float32)
            * (-math.log(10000.0) / embed_size)
        )
        embeddings = torch.zeros(max_count + 1, embed_size)
        embeddings[:, 0::2] = torch.sin(position * div_term)
        # Slice div_term for odd embed_size, where cos has one fewer column.
        embeddings[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        return embeddings

    def get_latent_space(self, inputs, batch_indices, batch_size=32):
        """Get the latent-space representation and predictions.

        Args:
            inputs (torch.Tensor): Input tensor, shape ``(n_cells, seq)``.
            batch_indices (torch.Tensor): Batch indices, shape ``(n_cells,)``.
            batch_size (int, optional): Mini-batch size. Defaults to 32.

        Returns:
            Tuple ``(latent_space, preds)`` concatenated over all mini-batches.
        """
        self.eval()
        latent_space_list, preds_list = [], []
        with torch.no_grad():
            for i in range(0, inputs.shape[0], batch_size):
                inputs_batch = inputs[i:i + batch_size].float()
                batch_indices_batch = batch_indices[i:i + batch_size].int()
                preds, reduced_dim = self(inputs_batch, batch_indices_batch)
                latent_space_list.append(reduced_dim)
                preds_list.append(preds)
        latent_space = torch.cat(latent_space_list, dim=0)
        preds = torch.cat(preds_list, dim=0)
        return latent_space, preds


class MultiModalTransformer(nn.Module):
    """Fuses three per-modality transformers (RNA, ATAC, Flux) via a CLS head.

    Each sub-model must share the same ``d_model``. Encoder tokens from all
    modalities are concatenated; a learned CLS query attends over them, with
    all-zero modality inputs masked out via ``key_padding_mask``.
    """

    def __init__(self, rna_model, atac_model, flux_model, d_model,
                 n_heads_cls, d_ff_cls, dropout_rate=0.0):
        super().__init__()
        self.rna_model = rna_model
        self.atac_model = atac_model
        self.flux_model = flux_model

        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls_token, mean=0.0, std=0.02)

        self.layer_norm = nn.LayerNorm(d_model)
        self.cls_attention = nn.MultiheadAttention(embed_dim=d_model,
                                                   num_heads=n_heads_cls,
                                                   dropout=dropout_rate,
                                                   batch_first=True)
        self.cls_norm1 = nn.LayerNorm(d_model)
        self.cls_norm2 = nn.LayerNorm(d_model)
        self.cls_ffn = nn.Sequential(
            nn.Linear(d_model, d_ff_cls),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(d_ff_cls, d_model),
        )
        self.cls_output_layer = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, batch_indices, return_attention=False,
                return_embeddings=False, return_flow_attention=False):
        """Fuse the three modalities and predict with the shared CLS head.

        Args:
            x: Triple ``(rna_input, atac_input, flux_input)``, each
                ``(batch, seq_modality)``. A sample whose modality input sums
                to zero is treated as missing and masked from CLS attention.
            batch_indices: ``(batch,)`` integer batch-effect indices.
            return_attention / return_embeddings / return_flow_attention:
                As in ``SingleTransformer.forward``.

        Returns:
            ``(preds, cls_output[, attention])`` by default; with
            ``return_flow_attention`` the attention is a dict keyed by
            'rna'/'atac'/'flux'/'cls'.
        """
        rna_input, atac_input, flux_input = x[0], x[1], x[2]

        # Per-modality encoder tokens (attention flows are empty lists unless
        # return_flow_attention is set).
        rna_tokens, rna_attention = self.rna_model(
            rna_input, batch_indices, return_embeddings=True,
            return_flow_attention=return_flow_attention)
        atac_tokens, atac_attention = self.atac_model(
            atac_input, batch_indices, return_embeddings=True,
            return_flow_attention=return_flow_attention)
        flux_tokens, flux_attention = self.flux_model(
            flux_input, batch_indices, return_embeddings=True,
            return_flow_attention=return_flow_attention)

        other_tokens = torch.cat((rna_tokens, atac_tokens, flux_tokens), dim=-2)
        if return_embeddings:
            return other_tokens

        # Missing-modality masks: 1.0 where the modality has any signal.
        rna_mask = (rna_input.sum(dim=1) != 0).float()
        atac_mask = (atac_input.sum(dim=1) != 0).float()
        flux_mask = (flux_input.sum(dim=1) != 0).float()
        rna_mask = rna_mask.unsqueeze(-1).expand(-1, rna_tokens.size(1))
        atac_mask = atac_mask.unsqueeze(-1).expand(-1, atac_tokens.size(1))
        flux_mask = flux_mask.unsqueeze(-1).expand(-1, flux_tokens.size(1))
        other_tokens_mask = torch.cat((rna_mask, atac_mask, flux_mask), dim=1)

        other_tokens = self.layer_norm(other_tokens)
        cls_token = self.cls_token.expand(other_tokens.size(0), -1, -1)
        # key_padding_mask is True where tokens must be IGNORED.
        attended_cls, attention_weights = self.cls_attention(
            cls_token, other_tokens, other_tokens,
            key_padding_mask=(1 - other_tokens_mask).bool(),
            need_weights=True, average_attn_weights=False,
        )
        attended_cls = attended_cls.squeeze(1)
        cls_output = self.cls_norm1(cls_token.squeeze(1) + self.dropout(attended_cls))
        cls_output = self.cls_norm2(cls_output + self.dropout(self.cls_ffn(cls_output)))
        preds = torch.sigmoid(self.cls_output_layer(cls_output))

        if return_flow_attention:
            return preds, cls_output, {
                'rna': rna_attention,
                'atac': atac_attention,
                'flux': flux_attention,
                'cls': attention_weights,
            }
        elif return_attention:
            return preds, cls_output, attention_weights
        else:
            return preds, cls_output

    def freeze_pretrained_weights(self):
        """Freeze sub-models and everything outside the fusion CLS head.

        NOTE(review): the name filter also matches the sub-models' own
        'cls_*' parameters (e.g. 'rna_model.cls_attention...'), so those
        stay trainable — confirm this is intended.
        """
        self.rna_model.freeze_pretrained_weights()
        self.atac_model.freeze_pretrained_weights()
        self.flux_model.freeze_pretrained_weights()
        for name, param in self.named_parameters():
            if not any(x in name for x in ['cls_attention', 'cls_norm', 'cls_ffn',
                                           'cls_token', 'cls_output_layer']):
                param.requires_grad = False
        self.pretrained = True

    def unfreeze_pretrained_weights(self):
        """Make all parameters (sub-models included) trainable again."""
        self.rna_model.unfreeze_pretrained_weights()
        self.atac_model.unfreeze_pretrained_weights()
        self.flux_model.unfreeze_pretrained_weights()
        for param in self.parameters():
            param.requires_grad = True
        self.pretrained = False

    def get_latent_space(self, X, batch_indices, batch_size=32):
        """Batched inference returning ``(latent_space, preds)``.

        Args:
            X: Triple of modality tensors, each ``(n_cells, seq_modality)``.
            batch_indices: ``(n_cells,)`` integer batch indices.
            batch_size (int, optional): Mini-batch size. Defaults to 32.
        """
        self.eval()
        latent_space_list, preds_list = [], []
        rna_input, atac_input, flux_input = X[0], X[1], X[2]
        with torch.no_grad():
            for i in range(0, rna_input.shape[0], batch_size):
                rna_input_batch = rna_input[i:i + batch_size].float()
                atac_input_batch = atac_input[i:i + batch_size].float()
                flux_input_batch = flux_input[i:i + batch_size].float()
                batch_indices_batch = batch_indices[i:i + batch_size].int()
                preds, reduced_dim = self(
                    (rna_input_batch, atac_input_batch, flux_input_batch),
                    batch_indices_batch)
                latent_space_list.append(reduced_dim)
                preds_list.append(preds)
        latent_space = torch.cat(latent_space_list, dim=0)
        preds = torch.cat(preds_list, dim=0)
        return latent_space, preds


if __name__ == '__main__':
    # Smoke test. Original code passed nonexistent kwargs (d_tokens, d_batch)
    # and called .shape on tuple returns; both fixed here.
    model = SingleTransformer(model_type='ATAC', vocab_size=1, seq_len=883,
                              n_encoder_layers=2, n_heads=2, n_batches=3,
                              d_model=508, d_ff=128)
    x = torch.rand(32, 883)
    batch_indices = torch.randint(1, 3, (32,))
    print(model(x, batch_indices, masked_lm=True).shape)
    print(model(x, batch_indices, return_attention=True)[0].shape)
    print(model(x, batch_indices, return_embeddings=True)[0].shape)
    print(model(x, batch_indices)[0].shape)