| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.nn.functional import binary_cross_entropy_with_logits |
| | import math |
| | from transformers import PreTrainedModel |
| | from .configuration_flowformer import FlowformerConfig |
| |
|
| |
|
class MAB(nn.Module):
    """
    Multihead attention Block (MAB) from https://arxiv.org/abs/1810.00825.

    Projects the query set ``Q`` and key/value set ``K`` to ``dim_V`` features,
    runs multi-head attention from ``Q`` to ``K``, and applies
    ``H = Q' + Attn(Q', K', V')`` followed by ``O = H + relu(fc_o(H))``.
    Optional LayerNorms follow each residual connection.

    Args:
        dim_Q: Feature size of the query set.
        dim_K: Feature size of the key/value set.
        dim_V: Output feature size; each head gets ``dim_V // num_heads`` features.
        num_heads: Number of attention heads.
        ln: If true, apply LayerNorm after each residual connection.
    """

    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, ln: bool = False):
        super().__init__()

        self.dim_V = dim_V
        self.num_heads = num_heads
        # NOTE(review): dim_V is assumed divisible by num_heads; otherwise the
        # split in forward() yields a different number of chunks than num_heads.
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)

        # The LayerNorms only exist when requested; forward() checks for them.
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
        """
        Attend from ``Q`` to ``K``.

        Args:
            Q: Query set of shape ``(batch, n_q, dim_Q)``.
            K: Key/value set of shape ``(batch, n_k, dim_K)``.

        Returns:
            Tensor of shape ``(batch, n_q, dim_V)``.
        """
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Split the feature dimension into heads and stack the heads along the
        # batch dimension, so one bmm computes the attention for every head.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
        K_ = torch.cat(K.split(dim_split, 2), dim=0)
        V_ = torch.cat(V.split(dim_split, 2), dim=0)

        # NOTE: scores are scaled by sqrt(dim_V), not sqrt(dim_split). Keep as-is:
        # changing it would alter the outputs of already-trained weights.
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        # Undo the head stacking: split back into per-head chunks along the batch
        # dimension and concatenate them along the feature dimension.
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)

        return O
| |
|
| |
|
class ISAB(nn.Module):
    """
    The Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.

    A learned set of ``num_inds`` inducing points first attends to the input set
    (``mab0``). The input set then attends to that summary (``mab1``).

    Args:
        dim_in: Feature size of the input set.
        dim_out: Feature size of the inducing points and of the output.
        num_heads: Number of attention heads used by both MABs.
        num_inds: Number of learned inducing points.
        ln: If true, the MABs apply LayerNorm after their residual connections.
    """

    def __init__(self, dim_in: int, dim_out: int, num_heads: int, num_inds: int, ln: bool = False):
        super().__init__()

        # Learned inducing points: allocate uninitialized storage with torch.empty
        # (instead of the legacy torch.Tensor(*sizes) constructor), then fill it.
        self.I = nn.Parameter(torch.empty(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Args:
            X: Input set of shape ``(batch, n, dim_in)``.

        Returns:
            Tensor of shape ``(batch, n, dim_out)``.
        """
        # Tile the inducing points over the batch; H is (batch, num_inds, dim_out).
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)

        return self.mab1(X, H)
| | |
| |
|
class Flowformer(PreTrainedModel):
    r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config (`FlowformerConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """
    config_class = FlowformerConfig

    def __init__(self, config: FlowformerConfig):
        super().__init__(config)

        dim_input = config.dim_input
        dim_hidden = config.dim_hidden
        num_heads = config.num_heads
        num_inds = config.num_inds
        hidden_layers = config.hidden_layers
        layer_norm = config.layer_norm
        dim_output = 1
        self._markers = config.markers

        # Encoder: dim_input -> dim_hidden, then (hidden_layers - 1) blocks of
        # dim_hidden -> dim_hidden, then a single-head block back to dim_input.
        enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
        for _ in range(1, hidden_layers):
            enc_layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=layer_norm))
        enc_layers.append(ISAB(dim_hidden, dim_input, 1, num_inds, ln=layer_norm))
        self.enc = nn.Sequential(*enc_layers)

        # Decoder: a linear layer producing one logit per set element.
        dec_layers = [nn.Linear(dim_input, dim_output)]
        self.dec = nn.Sequential(*dec_layers)

    def markers(self):
        """Return the marker list the model was configured with (``config.markers``)."""
        return self._markers

    def _align_to_model_markers(self, tensor: torch.Tensor, markers: list) -> torch.Tensor:
        """Rearrange ``tensor``'s last dimension from ``markers`` order into ``self.markers()`` order.

        Input columns whose marker is unknown to the model are dropped. Model markers
        absent from ``markers`` are left as zeros.
        """
        B, L, _ = tensor.shape
        model_markers = self.markers()

        # First-occurrence position of each model marker, matching list.index(),
        # built once instead of rebuilding a set per input marker.
        model_index = {}
        for pos, name in enumerate(model_markers):
            model_index.setdefault(name, pos)

        # Pair every known input column (src) with its slot in the model's order (dst).
        src, dst = [], []
        for col, name in enumerate(markers):
            if name in model_index:
                src.append(col)
                dst.append(model_index[name])

        aligned = torch.zeros((B, L, len(model_markers)), device=tensor.device)
        src_idx = torch.tensor(src, dtype=torch.long, device=tensor.device)
        dst_idx = torch.tensor(dst, dtype=torch.long, device=tensor.device)
        # Copy only the matched columns, so unknown input markers no longer
        # cause a shape mismatch on assignment.
        aligned[:, :, dst_idx] = tensor[:, :, src_idx]
        return aligned

    def forward(self, tensor: torch.Tensor, labels: torch.Tensor = None, markers: list = None):
        r"""
        Args:
            tensor (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_markers)`):
                The sample used as a basis for the prediction.
            labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Optional ground truth labels for computing the loss. Non-floating-point labels
                are cast to the logits' dtype.
            markers (`list` of length `num_markers`, *optional*):
                The list of markers in the same order as the last dimension of the input tensor.
                When given, the input is rearranged into the model's marker order: unknown markers
                are ignored and model markers missing from the input are filled with zeros. When
                omitted, `tensor` is used as-is.

        Returns:
            `dict` with `logits` of shape `(batch_size, sequence_length)` and `prediction`
            (1 where the logit is positive, else 0). When `labels` is given, it also holds
            `loss`, the binary cross-entropy computed from the logits.
        """
        _, _, num_markers = tensor.shape
        if markers is not None:
            assert len(markers) == num_markers, "last dimension of input must be equal to number of markers"
            tensor = self._align_to_model_markers(tensor, markers)

        enc_out = self.enc(tensor)
        output = self.dec(enc_out)[:, :, 0]

        # Build the result once; 'loss' (if any) stays the first key as before.
        result = {}
        if labels is not None:
            if not labels.is_floating_point():
                # binary_cross_entropy_with_logits needs floating-point targets.
                labels = labels.to(output.dtype)
            result['loss'] = binary_cross_entropy_with_logits(output, labels)
        result['logits'] = output
        result['prediction'] = torch.where(output > 0, 1, 0)
        return result
| | |