import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional

from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import ModelOutput

from .configuration_protenrich import ProtEnrichConfig

@dataclass
class ProtEnrichModelOutput(ModelOutput):
    """Outputs of ProtEnrichModel.

    h_enrich: enriched sequence embedding, h_anchor + alpha * h_algn.
    h_anchor: anchor-branch sequence embedding.
    h_algn: alignment-branch sequence embedding.
    struct: structure features decoded from h_algn.
    dyn: dynamics features decoded from h_algn.
    """

    h_enrich: torch.FloatTensor = None
    h_anchor: Optional[torch.FloatTensor] = None
    h_algn: Optional[torch.FloatTensor] = None
    struct: Optional[torch.FloatTensor] = None
    dyn: Optional[torch.FloatTensor] = None

class MLPEncoder(nn.Module):
    """MLP with (n_layers - 1) hidden blocks of Linear -> LayerNorm -> GELU -> Dropout,
    followed by a final linear projection to out_dim."""

    def __init__(self, in_dim, out_dim, hidden_dim=1024, n_layers=2, dropout=0.1):
        super().__init__()
        layers = []
        d = in_dim
        for _ in range(n_layers - 1):
            layers += [
                nn.Linear(d, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            d = hidden_dim
        layers.append(nn.Linear(d, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class ProtEnrichModel(PreTrainedModel):
    config_class = ProtEnrichConfig
    base_model_prefix = "protenrich"

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)

        # Two sequence branches: an anchor branch and an alignment branch.
        self.seq_anchor = MLPEncoder(config.seq_dim, config.embed_dim)
        self.seq_algn = MLPEncoder(config.seq_dim, config.embed_dim)
        self.struct_encoder = MLPEncoder(config.struct_dim, config.embed_dim)
        self.dyn_encoder = MLPEncoder(config.dyn_dim, config.embed_dim)

        # The structure and dynamics encoders are frozen; they act as fixed
        # targets rather than trainable components.
        for p in self.struct_encoder.parameters():
            p.requires_grad = False
        for p in self.dyn_encoder.parameters():
            p.requires_grad = False

        # Projection heads; not called in forward (presumably consumed by a
        # training-time alignment objective).
        self.seq_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.struct_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.dyn_projector = nn.Linear(config.embed_dim, config.project_dim)

        # Decoders from the shared embedding space back to each feature space.
        self.seq_decoder = MLPEncoder(config.embed_dim, config.seq_dim)
        self.struct_decoder = MLPEncoder(config.embed_dim, config.struct_dim)
        self.dyn_decoder = MLPEncoder(config.embed_dim, config.dyn_dim)

        # Learned gate on the alignment branch: alpha = sigmoid(alpha_logit) * alpha_max.
        # The logit is initialized at -2.0 so the gate starts near zero.
        self.alpha_logit = nn.Parameter(torch.tensor(-2.0))
        self.alpha_max = config.alpha_max

        self.norm_anchor = nn.LayerNorm(config.embed_dim)
        self.norm_algn = nn.LayerNorm(config.embed_dim)

        self.post_init()

    def forward(self, seq: torch.Tensor, return_dict: Optional[bool] = None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        h_anchor = self.norm_anchor(self.seq_anchor(seq))
        h_algn = self.norm_algn(self.seq_algn(seq))

        # Decode structure and dynamics features from the alignment branch.
        struct = self.struct_decoder(h_algn)
        dyn = self.dyn_decoder(h_algn)

        # Enriched embedding: anchor plus the gated alignment contribution.
        alpha = torch.sigmoid(self.alpha_logit) * self.alpha_max
        h_enrich = h_anchor + alpha * h_algn

        if not return_dict:
            return (h_enrich, h_anchor, h_algn, struct, dyn)

        return ProtEnrichModelOutput(
            h_enrich=h_enrich,
            h_anchor=h_anchor,
            h_algn=h_algn,
            struct=struct,
            dyn=dyn,
        )

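# The projectors, decoders, and frozen struct/dyn encoders above are never
# called in ProtEnrichModel.forward, which suggests they serve a training-time
# objective. The function below is a minimal sketch of one plausible objective,
# an assumption rather than the actual recipe: cosine alignment between the
# projected sequence branch and the projected (frozen) structure/dynamics
# embeddings, plus reconstruction terms. `struct_feats` and `dyn_feats` are
# hypothetical per-protein feature tensors of shape (batch, struct_dim) and
# (batch, dyn_dim).
def example_alignment_loss(model: ProtEnrichModel, seq, struct_feats, dyn_feats):
    out = model(seq=seq)
    z_seq = F.normalize(model.seq_projector(out.h_algn), dim=-1)

    with torch.no_grad():  # the target encoders are frozen
        h_struct = model.struct_encoder(struct_feats)
        h_dyn = model.dyn_encoder(dyn_feats)
    z_struct = F.normalize(model.struct_projector(h_struct), dim=-1)
    z_dyn = F.normalize(model.dyn_projector(h_dyn), dim=-1)

    # Pull the alignment branch toward both modalities in projection space.
    align = 2.0 - (z_seq * z_struct).sum(-1).mean() - (z_seq * z_dyn).sum(-1).mean()

    # Reconstruct the raw features from the embeddings via the decoders.
    recon = (
        F.mse_loss(out.struct, struct_feats)
        + F.mse_loss(out.dyn, dyn_feats)
        + F.mse_loss(model.seq_decoder(out.h_anchor), seq)
    )
    return align + recon
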
class ProtEnrichForSequenceClassification(PreTrainedModel):
    config_class = ProtEnrichConfig

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)

        self.num_labels = config.num_labels

        self.protenrich = ProtEnrichModel(config)
        self.classifier = nn.Linear(config.embed_dim, config.num_labels)

        self.post_init()

    def forward(self, seq: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request the dict form from the base model so h_enrich can be
        # accessed by name, regardless of the caller's return_dict choice.
        outputs = self.protenrich(seq=seq, return_dict=True)
        pooled = outputs.h_enrich

        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits, pooled)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=(pooled,),
        )
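

# A minimal smoke test, not part of the model definition. It assumes that
# ProtEnrichConfig accepts the keyword arguments below (the dimensions are
# made-up placeholders). Because of the relative import at the top of this
# file, run it as a module inside its package, e.g.
# `python -m <package>.modeling_protenrich`.
if __name__ == "__main__":
    config = ProtEnrichConfig(
        seq_dim=1280,
        struct_dim=384,
        dyn_dim=64,
        embed_dim=512,
        project_dim=128,
        alpha_max=1.0,
        num_labels=2,
    )
    model = ProtEnrichForSequenceClassification(config)

    seq = torch.randn(4, config.seq_dim)  # batch of pooled sequence features
    labels = torch.randint(0, config.num_labels, (4,))

    out = model(seq=seq, labels=labels)
    print(out.logits.shape, out.loss.item())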