Irwiny123's picture
添加PepGLAD初始代码
52007f8
#!/usr/bin/python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from utils.nn_utils import graph_to_batch
from .backbone import FrameBuilder
class BackboneModel(nn.Module):
    """Rebuild the backbone atoms of predicted all-atom coordinates.

    The backbone slots of every residue are passed through
    ``FrameBuilder.inverse`` and then regenerated by calling the same
    ``FrameBuilder``; the remaining atom slots are passed through unchanged.
    """

    def __init__(self) -> None:
        super().__init__()
        self.backbone_builder = FrameBuilder()

    def forward(self, X, batch_ids):
        """Return ``X`` with its first four atom slots rebuilt.

        Args:
            X: [N, 14, 3], predicted all-atom coordinates (possibly with many
                invalid entries); the first 4 atoms are assumed to be
                N, CA, C, O.
            batch_ids: per-residue batch assignment forwarded to
                ``graph_to_batch`` (presumably shape [N] -- confirm against
                that helper).

        Returns:
            Coordinates gathered back from the batched layout with ``mask``,
            i.e. in graph (flattened) form again.
        """
        # Flat graph layout -> padded batch layout.
        # NOTE(review): with mask_is_pad=False the mask presumably flags the
        # real (non-padding) positions -- confirm in utils.nn_utils.
        batched_X, mask = graph_to_batch(X, batch_ids, mask_is_pad=False)
        frame_mask = mask.long()

        # Rectify backbones: coordinates -> (R, t) -> coordinates. The third
        # value returned by inverse() is not needed here.
        rotation, translation, _ = self.backbone_builder.inverse(
            batched_X, frame_mask
        )
        backbone = self.backbone_builder(
            rotation, translation, frame_mask
        )  # [bs, L, 4, 3]

        # Keep every atom slot after the four backbone ones as predicted.
        side_chain = batched_X[:, :, 4:]
        rebuilt = torch.cat([backbone, side_chain], dim=-2)  # [bs, L, 14, 3]

        # Back to the graph representation.
        return rebuilt[mask]