|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from utils.nn_utils import graph_to_batch |
|
|
|
|
|
from .backbone import FrameBuilder |
|
|
|
|
|
|
|
|
class BackboneModel(nn.Module):
    """Replace the first four atoms of each residue with a rebuilt backbone.

    The module owns a single ``FrameBuilder`` and uses it in both directions:
    ``inverse`` to turn coordinates into frames, and the forward call to turn
    those frames back into backbone coordinates.
    """

    def __init__(self) -> None:
        super().__init__()
        self.backbone_builder = FrameBuilder()

    def forward(self, X, batch_ids):
        '''
        Rebuild the backbone atoms of predicted all-atom coordinates.

        Args:
            X: [N, 14, 3], predicted all-atom coordinates (obviously with a
                lot of invalidities); assume the first 4 atoms are
                N, CA, C, O.
            batch_ids: passed to ``graph_to_batch`` together with ``X``;
                presumably a per-residue sample index of shape [N]
                (TODO confirm against ``graph_to_batch``).

        Returns:
            The entries of the batched coordinates selected by the mask that
            ``graph_to_batch`` returned, with atoms 0-3 replaced by the
            rebuilt backbone. Presumably [N, 14, 3] in the input order
            (TODO confirm).
        '''
        # Batched view of the coordinates plus a mask; with
        # mask_is_pad=False the mask presumably marks valid (non-padding)
        # positions -- verify against graph_to_batch.
        X, mask = graph_to_batch(X, batch_ids, mask_is_pad=False)
        C = mask.long()

        # Round trip through the frame builder. The third value returned by
        # ``inverse`` is not needed here, so it is discarded.
        R, t, _ = self.backbone_builder.inverse(X, C)
        X_bb = self.backbone_builder(R, t, C)

        # Keep atoms 4.. from the prediction and prepend the rebuilt
        # backbone along the atom axis (dim=-2).
        X = torch.cat([X_bb, X[:, :, 4:]], dim=-2)

        # Boolean-mask indexing returns only the masked-in entries.
        return X[mask]