File size: 956 Bytes
52007f8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 | #!/usr/bin/python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from utils.nn_utils import graph_to_batch
from .backbone import FrameBuilder
class BackboneModel(nn.Module):
    """Rectify the backbone atoms of predicted all-atom coordinates.

    The backbone atom slots (the first four: N, CA, C, O) are converted to
    frames with ``FrameBuilder.inverse`` and rebuilt with ``FrameBuilder``'s
    forward pass. All remaining atom slots are passed through unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        self.backbone_builder = FrameBuilder()

    def forward(self, X: torch.Tensor, batch_ids: torch.Tensor) -> torch.Tensor:
        """Replace the backbone atoms of ``X`` with rectified ones.

        Args:
            X: [N, 14, 3] predicted all-atom coordinates. These may contain
                many geometric invalidities. The first four atom slots are
                assumed to be N, CA, C, O.
            batch_ids: per-residue sample index. It is passed to
                ``graph_to_batch`` to group the N residues into a padded
                batch (presumably shape [N] -- confirm with callers).

        Returns:
            Coordinates gathered back from the padded batch with the
            boolean mask, i.e. one row of [14, 3] per masked position. The
            first four atom slots are rebuilt from frames; the remaining
            slots are copied from the input.
            NOTE(review): this presumably matches the input's N rows in the
            same order; that depends on ``graph_to_batch`` -- confirm.
        """
        # To batch-form representations. With mask_is_pad=False the mask is
        # presumably True at real (non-padding) positions -- confirm in
        # utils.nn_utils.graph_to_batch.
        X, mask = graph_to_batch(X, batch_ids, mask_is_pad=False)
        C = mask.long()  # integer form of the mask, consumed by FrameBuilder

        # Rectify backbones: coordinates -> frames -> coordinates.
        # The third value returned by ``inverse`` is not needed here.
        R, t, _ = self.backbone_builder.inverse(X, C)
        X_bb = self.backbone_builder(R, t, C)  # [bs, L, 4, 3]

        # Keep every non-backbone atom slot from the input.
        X = torch.cat([X_bb, X[:, :, 4:]], dim=-2)  # [bs, L, 14, 3]

        # Get back to our graph representation (drop padded positions).
        return X[mask]