Irwiny123's picture
添加PepGLAD初始代码
52007f8
#!/usr/bin/python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.nn_utils import graph_to_batch
from data.format import VOCAB
from .sidechain import SideChainBuilder, ChiAngles
from .constants import AA20
class SideChainModel(nn.Module):
    '''
    Rectify predicted side chains: compute chi angles from the predicted
    all-atom coordinates, optionally refine those angles by gradient descent,
    and rebuild the atoms with the side-chain builder from the backbone
    (first 4 atom channels) and the chi angles.
    '''

    def __init__(self):
        super().__init__()
        self.sidechain_builder = SideChainBuilder()
        self.chi_angle_calc = ChiAngles()

        # Lookup table from our vocabulary index to the position of the symbol
        # in AA20 (0~19). Every other slot holds 20, which stands for unknown.
        # The table only spans indexes up to the largest AA20 vocabulary index.
        aa_index_inverse_mapping = torch.tensor(
            [VOCAB.symbol_to_idx(a) for a in AA20], dtype=torch.long
        )
        aa_index_mapping = torch.ones(aa_index_inverse_mapping.max() + 1, dtype=torch.long) * 20
        aa_index_mapping[aa_index_inverse_mapping] = torch.arange(20)
        self.register_buffer('aa_index_mapping', aa_index_mapping)

    def forward(self, X, S, batch_ids, optimize=True, max_steps=None):
        '''
        X: [N, 14, 3], predicted all-atom coordinates (obviously with a lot of invalidities)
        S: [N], predicted sequence (indexes in our vocabulary)
        batch_ids: handed to graph_to_batch to group the N residues into samples
        optimize: if True, refine chi with Adam so that the rebuilt side-chain
            atoms stay close to the predicted ones
        max_steps: optional upper bound on the number of refinement steps;
            None (default) keeps the convergence-only stopping rule

        return: rebuilt coordinates, indexed back with the batch mask
            (presumably [N, 14, 3] again -- confirm against graph_to_batch)
        '''
        # do sequence index mapping from our vocabulary to the sidechain builder native indexes
        S = self.aa_index_mapping[S]

        # to batch-form representations
        X, mask = graph_to_batch(X, batch_ids, mask_is_pad=False)
        S, _ = graph_to_batch(S, batch_ids)
        C = mask.long()

        # rectify sidechains
        chi, _ = self.chi_angle_calc(X, C, S)
        ori_X = X.clone()

        if optimize:  # optimize chi so that the resulted atoms have similar positions with the predicted ones
            # enable_grad lets the refinement run even when the caller is inside torch.no_grad()
            with torch.enable_grad():
                # Detach first: the optimizer needs a leaf tensor, and chi may
                # carry autograd history when X does.
                chi = chi.detach().clone().requires_grad_(True)
                # Loop invariants: backbone input and the side-chain fit target.
                # Detached so each backward pass only reaches chi-side tensors.
                backbone = ori_X[:, :, :4].detach()
                target = ori_X[:, :, 4:].detach()

                delta, lr, step, last_mse = 1e-4, 1, 0, 100
                optimizer = torch.optim.Adam([chi], lr=lr)
                # last chi values known to give a finite loss
                good_chi = chi.detach().clone()

                while max_steps is None or step < max_steps:
                    X, mask_X = self.sidechain_builder(backbone, C, S, chi)
                    mask_X = mask_X.squeeze(-1)  # [bs, L, 14]
                    mask_X = mask_X[:, :, 4:].bool()
                    mse = F.mse_loss(X[:, :, 4:][mask_X], target[mask_X])  # only on sidechain

                    # A NaN/inf loss (e.g. the mask selects no side-chain atom)
                    # never satisfies the convergence test below, so without
                    # this guard the loop would spin forever.
                    if not torch.isfinite(mse):
                        chi = good_chi
                        break
                    if torch.abs(mse - last_mse) < delta:
                        break

                    good_chi = chi.detach().clone()
                    mse.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                    last_mse = mse.detach()
                    step += 1

                chi = chi.detach()

        # rebuild from the original (non-detached) backbone with the final chi
        X, _ = self.sidechain_builder(ori_X[:, :, :4], C, S, chi)

        # get back to our graph representations
        return X[mask]