#!/usr/bin/python
# -*- coding:utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean
from data.format import VOCAB
from utils import register as R
from utils.oom_decorator import oom_decorator
from utils.const import aas
from utils.nn_utils import variadic_meshgrid
from .sidechain.api import SideChainModel
from .backbone.api import BackboneModel
from ..dyMEAN.modules.am_egnn import AMEGNN # adaptive-multichannel egnn
from ..dyMEAN.nn_utils import SeparatedAminoAcidFeature, ProteinFeature
def create_encoder(
name,
atom_embed_size,
embed_size,
hidden_size,
n_channel,
n_layers,
dropout,
n_rbf,
cutoff
):
if name == 'dyMEAN':
encoder = AMEGNN(
embed_size, hidden_size, hidden_size, n_channel,
channel_nf=atom_embed_size, radial_nf=hidden_size,
in_edge_nf=0, n_layers=n_layers, residual=True,
dropout=dropout, dense=False, n_rbf=n_rbf, cutoff=cutoff)
else:
        raise NotImplementedError(f'Encoder {name} not implemented')
return encoder
@R.register('AutoEncoder')
class AutoEncoder(nn.Module):
def __init__(
self,
embed_size,
hidden_size,
latent_size,
n_channel,
latent_n_channel=1,
mask_id=VOCAB.get_mask_idx(),
latent_id=VOCAB.symbol_to_idx(VOCAB.LAT),
max_position=2048,
relative_position=False,
CA_channel_idx=VOCAB.backbone_atoms.index('CA'),
n_layers=3,
dropout=0.1,
mask_ratio=0.0,
fix_alpha_carbon=False,
h_kl_weight=0.1,
z_kl_weight=0.5,
coord_loss_weights={
'Xloss': 1.0,
'ca_Xloss': 0.0,
'bb_bond_lengths_loss': 1.0,
'sc_bond_lengths_loss': 1.0,
            'bb_dihedral_angles_loss': 0.0, # this significantly poisons the training
'sc_chi_angles_loss': 0.5
},
coord_loss_ratio=0.5, # (1 - r)*seq + r * coord
coord_prior_var=1.0, # sigma^2
anchor_at_ca=False,
share_decoder=False,
n_rbf=0,
cutoff=0,
encoder='dyMEAN',
mode='codesign', # codesign, fixbb (inverse folding), fixseq (structure prediction)
        additional_noise_scale=0.0 # scale of additional Gaussian noise added to the latent coordinates to enhance robustness (0 disables it)
) -> None:
super().__init__()
self.mask_id = mask_id
self.latent_id = latent_id
self.ca_channel_idx = CA_channel_idx
self.n_channel = n_channel
self.mask_ratio = mask_ratio
self.fix_alpha_carbon = fix_alpha_carbon
self.h_kl_weight = h_kl_weight
self.z_kl_weight = z_kl_weight
self.coord_loss_weights = coord_loss_weights
self.coord_loss_ratio = coord_loss_ratio
self.mode = mode
self.latent_size = 0 if self.mode == 'fixseq' else latent_size
self.latent_n_channel = 0 if self.mode == 'fixbb' else latent_n_channel
self.anchor_at_ca = anchor_at_ca
self.coord_prior_var = coord_prior_var
self.additional_noise_scale = additional_noise_scale
        if self.fix_alpha_carbon:
            assert self.latent_n_channel == 1, 'fix_alpha_carbon uses CA as the latent coordinate, so the number of latent channels must be 1'
        if self.anchor_at_ca:
            assert self.latent_n_channel == 1, 'anchor_at_ca requires the number of latent channels to be 1'
        if self.mode == 'fixseq':
            assert self.coord_loss_ratio == 1.0, f'fixseq mode requires coord_loss_ratio == 1.0, got {self.coord_loss_ratio}'
        if self.mode == 'fixbb':
            assert self.coord_loss_ratio == 0.0, f'fixbb mode requires coord_loss_ratio == 0.0, got {self.coord_loss_ratio}'
atom_embed_size = embed_size // 4
self.aa_feature = SeparatedAminoAcidFeature(
embed_size, atom_embed_size,
max_position=max_position,
relative_position=relative_position,
fix_atom_weights=True
)
self.protein_feature = ProteinFeature()
self.encoder = create_encoder(
name = encoder,
atom_embed_size = atom_embed_size,
embed_size = embed_size,
hidden_size = hidden_size,
n_channel = n_channel,
n_layers = n_layers,
dropout = dropout,
n_rbf = n_rbf,
cutoff = cutoff
)
if self.mode != 'fixbb':
self.sidechain_decoder = create_encoder(
name = encoder,
atom_embed_size = atom_embed_size,
embed_size = embed_size,
hidden_size = hidden_size,
n_channel = n_channel,
n_layers = n_layers,
dropout = dropout,
n_rbf = n_rbf,
cutoff = cutoff
)
self.backbone_model = BackboneModel()
self.sidechain_model = SideChainModel()
self.W_Z_log_var = nn.Linear(hidden_size, latent_n_channel * 3)
if self.mode != 'fixseq':
self.W_mean = nn.Linear(hidden_size, latent_size)
self.W_log_var = nn.Linear(hidden_size, latent_size)
# self.hidden2latent = nn.Linear(hidden_size, latent_size)
self.latent2hidden = nn.Linear(latent_size, hidden_size)
self.merge_S_H = nn.Linear(hidden_size * 2, hidden_size)
if share_decoder:
self.seq_decoder = self.sidechain_decoder
else:
self.seq_decoder = create_encoder(
name = encoder,
atom_embed_size = atom_embed_size,
embed_size = embed_size,
hidden_size = hidden_size,
n_channel = n_channel,
n_layers = n_layers,
dropout = dropout,
n_rbf = n_rbf,
cutoff = cutoff
)
# residue type index mapping, from original index to 0~20, 0 is unk
self.unk_idx = 0
self.s_map = [0 for _ in range(len(VOCAB))]
self.s_remap = [0 for _ in range(len(aas) + 1)]
self.s_remap[0] = VOCAB.symbol_to_idx(VOCAB.UNK)
for i, (a, _) in enumerate(aas):
original_idx = VOCAB.symbol_to_idx(a)
self.s_map[original_idx] = i + 1 # start from 1
self.s_remap[i + 1] = original_idx
self.s_map = nn.Parameter(torch.tensor(self.s_map, dtype=torch.long), requires_grad=False)
self.s_remap = nn.Parameter(torch.tensor(self.s_remap, dtype=torch.long), requires_grad=False)
if self.mode != 'fixseq':
self.seq_linear = nn.Linear(hidden_size, len(self.s_remap))
@torch.no_grad()
def prepare_inputs(self, X, S, mask, atom_mask, lengths):
# batch ids
batch_ids = self.get_batch_ids(S, lengths)
# edges
row, col = variadic_meshgrid(
input1=torch.arange(batch_ids.shape[0], device=batch_ids.device),
size1=lengths,
input2=torch.arange(batch_ids.shape[0], device=batch_ids.device),
size2=lengths,
) # (row, col)
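        # an illustrative sketch (assuming variadic_meshgrid enumerates all index pairs
        # within each sample): lengths = [2, 3] yields pairs among residues {0, 1} and
        # among residues {2, 3, 4}, i.e. a fully connected graph per sample; pairs whose
        # endpoints share the same `mask` value become ctx_edges, pairs crossing the
        # masked/unmasked boundary become inter_edges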
is_ctx = mask[row] == mask[col]
is_inter = ~is_ctx
ctx_edges = torch.stack([row[is_ctx], col[is_ctx]], dim=0) # [2, Ec]
inter_edges = torch.stack([row[is_inter], col[is_inter]], dim=0) # [2, Ei]
return ctx_edges, inter_edges, batch_ids
@torch.no_grad()
def get_batch_ids(self, S, lengths):
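        # a small worked example: lengths = [3, 2] -> cumsum boundaries [3, 5]; placing a 1
        # at index 3 and taking a running sum yields batch_ids = [0, 0, 0, 1, 1]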
batch_ids = torch.zeros_like(S)
batch_ids[torch.cumsum(lengths, dim=0)[:-1]] = 1
batch_ids.cumsum_(dim=0)
return batch_ids
def rsample(self, H, Z, Z_centers, no_randomness=False):
        '''
        H: [N, hidden_size]
        Z: [N, latent_n_channel, 3]
        Z_centers: [N, latent_n_channel, 3]
        '''
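        # closed-form Gaussian KL terms used below (written out for reference):
        #   sequence latent:   KL( N(mu_H, sigma_H^2) || N(0, I) )
        #                      = -0.5 * sum(1 + log sigma_H^2 - mu_H^2 - sigma_H^2)
        #   coordinate latent: KL( N(Z, sigma_Z^2) || N(Z_centers, s^2 I) ), with s^2 = coord_prior_var
        #                      = -0.5 * sum(1 + log sigma_Z^2 - log s^2 - (Z - Z_centers)^2 / s^2 - sigma_Z^2 / s^2)
        # both are averaged over the number of latent nodes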
if self.mode != 'fixseq':
data_size = H.shape[0]
H_mean = self.W_mean(H)
            H_log_var = -torch.abs(self.W_log_var(H))  # following Mueller et al., this is log(sigma^2), kept non-positive via -|x|
H_kl_loss = -0.5 * torch.sum(1.0 + H_log_var - H_mean * H_mean - torch.exp(H_log_var)) / data_size
H_vecs = H_mean if no_randomness else H_mean + torch.exp(H_log_var / 2) * torch.randn_like(H_mean)
else:
H_vecs, H_kl_loss = None, 0
if self.mode != 'fixbb':
data_size = Z.shape[0]
Z_mean_delta = Z - Z_centers
Z_log_var = -torch.abs(self.W_Z_log_var(H)).view(-1, self.latent_n_channel, 3)
Z_kl_loss = -0.5 * torch.sum(1.0 + Z_log_var - math.log(self.coord_prior_var) - Z_mean_delta * Z_mean_delta / self.coord_prior_var - torch.exp(Z_log_var) / self.coord_prior_var) / data_size
Z_vecs = Z if no_randomness else Z + torch.exp(Z_log_var / 2) * torch.randn_like(Z)
else:
Z_vecs, Z_kl_loss = None, 0
return H_vecs, Z_vecs, H_kl_loss, Z_kl_loss
def _get_latent_channels(self, X, atom_mask):
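        # compress the full-atom coordinates X [N, n_channel, 3] into latent channels:
        #   fix_alpha_carbon      -> the CA coordinate, [N, 1, 3]
        #   latent_n_channel == 1 -> mask-weighted centroid over existing atoms, [N, 1, 3]
        #   latent_n_channel == 5 -> the 4 backbone channels plus the centroid, [N, 5, 3]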
atom_weights = atom_mask.float() # 1 for atom, 0 for padding/missing, [N, 14]
if hasattr(self, 'fix_alpha_carbon') and self.fix_alpha_carbon:
return X[:, self.ca_channel_idx].unsqueeze(1) # use alpha carbon as latent channel
elif self.latent_n_channel == 1:
X = (X * atom_weights.unsqueeze(-1)).sum(1) # [N, 3]
X = X / atom_weights.sum(-1).unsqueeze(-1) # [N, 3]
return X.unsqueeze(1)
elif self.latent_n_channel == 5:
bb_X = X[:, :4]
X = (X * atom_weights.unsqueeze(-1)).sum(1) # [N, 3]
X = X / atom_weights.sum(-1).unsqueeze(-1) # [N, 3]
X = torch.cat([bb_X, X.unsqueeze(1)], dim=1) # [N, 5, 3]
return X
else:
raise NotImplementedError(f'Latent number of channels: {self.latent_n_channel} not implemented')
def _get_latent_channel_anchors(self, X, atom_mask):
if self.anchor_at_ca:
return X[:, self.ca_channel_idx].unsqueeze(1)
else:
return self._get_latent_channels(X, atom_mask)
def _fill_latent_channels(self, latent_X):
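        # inverse of _get_latent_channels: broadcast latent_X [N, latent_n_channel, 3] back
        # to the full atom layout, e.g. a single latent point is copied to all n_channel
        # slots, and with 5 latent channels the 4 backbone points are kept while the 5th
        # point fills every side-chain slot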
if self.latent_n_channel == 1:
return latent_X.repeat(1, self.n_channel, 1)
elif self.latent_n_channel == 5:
bb_X = latent_X[:, :4]
sc_X = latent_X[:, 4].unsqueeze(1).repeat(1, self.n_channel - 4, 1)
return torch.cat([bb_X, sc_X], dim=1)
else:
raise NotImplementedError(f'Latent number of channels: {self.latent_n_channel} not implemented')
def _remove_sidechain_atom_mask(self, atom_mask, mask_generate):
atom_mask = atom_mask.clone()
bb_mask = atom_mask[mask_generate]
bb_mask[:, 4:] = 0 # only backbone atoms are visible
atom_mask[mask_generate] = bb_mask
return atom_mask
@torch.no_grad()
def _mask_pep(self, S, atom_mask, mask_generate):
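        # randomly corrupt the generation region: a fraction (mask_ratio) of peptide
        # residues is replaced with the mask token and their side-chain atoms are hidden,
        # so the encoder cannot rely on ground-truth side chains at those positions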
assert self.mask_ratio > 0
S, atom_mask = S.clone(), atom_mask.clone()
pep_S = S[mask_generate]
do_mask = torch.rand_like(pep_S, dtype=torch.float) < self.mask_ratio
pep_S[do_mask] = self.mask_id
S[mask_generate] = pep_S
        atom_mask[mask_generate] = self._remove_sidechain_atom_mask(atom_mask[mask_generate], do_mask)
return S, atom_mask
def encode(self, X, S, mask, position_ids, lengths, atom_mask, no_randomness=False):
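        # encode the full-atom complex into per-residue latent variables for the region
        # selected by `mask`:
        #   latent_H [Ntgt, latent_size]          (None in fixseq mode)
        #   latent_X [Ntgt, latent_n_channel, 3]  (None in fixbb mode)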
true_X = X.clone()
ctx_edges, inter_edges, batch_ids = self.prepare_inputs(X, S, mask, atom_mask, lengths)
H_0, (atom_embeddings, _) = self.aa_feature(S, position_ids)
edges = torch.cat([ctx_edges, inter_edges], dim=1)
atom_weights = atom_mask.float() # 1 for atom, 0 for padding/missing, [N, 14]
H, pred_X = self.encoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_weights)
H = H[mask]
if self.mode != 'fixbb':
if hasattr(self, 'fix_alpha_carbon') and self.fix_alpha_carbon:
Z = self._get_latent_channels(true_X, atom_mask)
else:
Z = self._get_latent_channels(pred_X, atom_mask)
Z_centers = self._get_latent_channel_anchors(true_X, atom_mask)
Z, Z_centers = Z[mask], Z_centers[mask]
else:
Z, Z_centers = None, None
# resample
latent_H, latent_X, H_kl_loss, X_kl_loss = self.rsample(H, Z, Z_centers, no_randomness)
return latent_H, latent_X, H_kl_loss, X_kl_loss
def decode(self, X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing):
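        # two-stage decoding (stages are skipped according to `mode`): first predict the
        # sequence from the latent state, then rebuild full-atom coordinates with the
        # side-chain decoder conditioned on the predicted (or teacher-forced) sequence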
X, S, atom_mask = X.clone(), S.clone(), atom_mask.clone()
true_S = S[mask].clone()
if self.mode != 'fixbb': # fill coordinates with latent points
X[mask] = self._fill_latent_channels(Z)
if self.mode != 'fixseq': # fill sequences with mask token
S[mask] = self.latent_id
H_from_latent = self.latent2hidden(H)
if self.mode == 'fixbb': # only backbone atoms are visible
atom_mask = self._remove_sidechain_atom_mask(atom_mask, mask)
elif self.mode == 'codesign': # all channels are visible when deciding the sequence (all dummy atoms)
atom_mask[mask] = 1
else: # fixseq mode does not need to change atom mask
pass
ctx_edges, inter_edges, batch_ids = self.prepare_inputs(X, S, mask, atom_mask, lengths)
edges = torch.cat([ctx_edges, inter_edges], dim=1)
# decode sequence
if self.mode != 'fixseq':
H_0, (atom_embeddings, _) = self.aa_feature(S, position_ids)
H_0 = H_0.clone()
H_0[mask] = H_from_latent # TODO: how about the position encoding
H, _ = self.seq_decoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_mask.float())
pred_S_logits = self.seq_linear(H[mask]) # [Ntgt, 21]
S = S.clone()
if teacher_forcing: # teacher forcing
S[mask] = true_S
else: # inference
S[mask] = self.s_remap[torch.argmax(pred_S_logits, dim=-1)]
else:
pred_S_logits = None
# decode sidechain
if self.mode != 'fixbb':
H_0, (atom_embeddings, atom_weights) = self.aa_feature(S, position_ids)
H_0 = H_0.clone()
if self.mode != 'fixseq':
H_0[mask] = self.merge_S_H(torch.cat([H_from_latent, H_0[mask]], dim=-1))
# H_0[mask] = H_from_latent
atom_mask = atom_mask.clone()
atom_mask[mask] = atom_weights.bool()[mask] & atom_mask[mask] # reset atomic visibility of the reconstruction part with the decoded sequence
_, pred_X = self.sidechain_decoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_mask.float())
pred_X = pred_X[mask]
else:
pred_X = None
return pred_S_logits, pred_X
@oom_decorator
def forward(self, X, S, mask, position_ids, lengths, atom_mask, teacher_forcing=True):
true_X, true_S = X[mask].clone(), S[mask].clone()
# encode: H (N*d), Z (N*3)
if self.mask_ratio > 0:
input_S, input_atom_mask = self._mask_pep(S, atom_mask, mask)
else:
input_S, input_atom_mask = S, atom_mask
H, Z, H_kl_loss, Z_kl_loss = self.encode(X, input_S, mask, position_ids, lengths, input_atom_mask)
if self.mode != 'fixbb':
coord_reg_loss = F.mse_loss(Z, self._get_latent_channel_anchors(true_X, atom_mask[mask]))
else:
coord_reg_loss = 0
# add noise to improve robustness
        if self.training and Z is not None:  # Z is None in fixbb mode
            noise = torch.randn_like(Z) * getattr(self, 'additional_noise_scale', 0.0)
            Z = Z + noise
# decode: S (N), Z (N * 14 * 3) with atom mask
recon_S_logits, recon_X = self.decode(X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing)
# sequence reconstruction loss
if self.mode != 'fixseq':
seq_recon_loss = F.cross_entropy(recon_S_logits, self.s_map[true_S])
# aar
with torch.no_grad():
aar = (torch.argmax(recon_S_logits, dim=-1) == self.s_map[true_S]).sum() / len(recon_S_logits)
else:
seq_recon_loss, aar = 0, 1.0
# coordinates reconstruction loss
if self.mode != 'fixbb':
xloss_mask = atom_mask[mask]
batch_ids = self.get_batch_ids(S, lengths)[mask]
segment_ids = torch.ones_like(true_S, device=true_S.device, dtype=torch.long)
if self.n_channel == 4: # backbone only
loss_profile = {}
else:
true_struct_profile = self.protein_feature.get_struct_profile(true_X, true_S, batch_ids, self.aa_feature, segment_ids, xloss_mask)
recon_struct_profile = self.protein_feature.get_struct_profile(recon_X, true_S, batch_ids, self.aa_feature, segment_ids, xloss_mask)
loss_profile = { key + '_loss': F.l1_loss(recon_struct_profile[key], true_struct_profile[key]) for key in recon_struct_profile }
# mse
xloss = F.mse_loss(recon_X[xloss_mask], true_X[xloss_mask])
loss_profile['Xloss'] = xloss
# CA mse
ca_xloss_mask = xloss_mask[:, self.ca_channel_idx]
ca_xloss = F.mse_loss(recon_X[:, self.ca_channel_idx][ca_xloss_mask], true_X[:, self.ca_channel_idx][ca_xloss_mask])
loss_profile['ca_Xloss'] = ca_xloss
struct_recon_loss = 0
for name in loss_profile:
struct_recon_loss = struct_recon_loss + self.coord_loss_weights[name] * loss_profile[name]
else:
struct_recon_loss, loss_profile = 0, {}
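        # total objective, with r = coord_loss_ratio:
        #   loss = (1 - r) * (seq_recon_loss + h_kl_weight * H_kl_loss)
        #        +       r * (struct_recon_loss + z_kl_weight * Z_kl_loss)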
recon_loss = (1 - self.coord_loss_ratio) * (seq_recon_loss + self.h_kl_weight * H_kl_loss) + \
self.coord_loss_ratio * (struct_recon_loss + self.z_kl_weight * Z_kl_loss)
return recon_loss, (seq_recon_loss, aar), (struct_recon_loss, loss_profile), (H_kl_loss, Z_kl_loss, coord_reg_loss)
    def _reconstruct(self, X, S, mask, position_ids, lengths, atom_mask, given_latent_H=None, given_latent_X=None, allow_unk=False, optimize_sidechain=True, idealize=False, no_randomness=False):
        if given_latent_H is None and given_latent_X is None:
# encode: H (N*d), Z (N*3)
H, Z, _, _ = self.encode(X, S, mask, position_ids, lengths, atom_mask, no_randomness=no_randomness)
else:
            H, Z = given_latent_H, given_latent_X
# decode: S (N), Z (N * 14 * 3) with atom mask
recon_S_logits, recon_X = self.decode(X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing=False)
batch_ids = self.get_batch_ids(S, lengths)[mask]
if self.mode != 'fixseq':
if not allow_unk:
recon_S_logits[:, 0] = float('-inf')
# map aa index back
recon_S = self.s_remap[torch.argmax(recon_S_logits, dim=-1)]
# ppls
snll_all = F.cross_entropy(recon_S_logits, torch.argmax(recon_S_logits, dim=-1), reduction='none')
batch_ppls = scatter_mean(snll_all, batch_ids, dim=0)
else:
recon_S = S[mask]
batch_ppls = torch.zeros(batch_ids.max() + 1, device=recon_X.device).float()
if self.mode == 'fixseq' or (self.mode != 'fixbb' and idealize):
# rectify backbone
recon_X = self.backbone_model(recon_X, batch_ids)
# rectify sidechain
recon_X = self.sidechain_model(recon_X, recon_S, batch_ids, optimize_sidechain)
return recon_X, recon_S, batch_ppls, batch_ids
@torch.no_grad()
    def test(self, X, S, mask, position_ids, lengths, atom_mask, given_latent_H=None, given_latent_X=None, return_tensor=False, allow_unk=False, optimize_sidechain=True, idealize=False, n_iter=1):
        no_randomness = given_latent_H is not None # in reconstruction mode, with latent variables derived from the diffusion model
for i in range(n_iter):
            recon_X, recon_S, batch_ppls, batch_ids = self._reconstruct(X, S, mask, position_ids, lengths, atom_mask, given_latent_H, given_latent_X, allow_unk, optimize_sidechain, idealize, no_randomness)
X, S = X.clone(), S.clone()
if self.mode != 'fixbb':
X[mask] = recon_X
if self.mode != 'fixseq':
S[mask] = recon_S
            given_latent_H, given_latent_X = None, None # let the model encode and decode on its own in later iterations
if return_tensor:
return recon_X, recon_S, batch_ppls
batch_X, batch_S = [], []
batch_ppls = batch_ppls.tolist()
for i, l in enumerate(lengths):
cur_mask = batch_ids == i
if self.mode == 'fixbb':
batch_X.append(None)
else:
batch_X.append(recon_X[cur_mask].tolist())
if self.mode == 'fixseq':
batch_S.append(None)
else:
batch_S.append(''.join([VOCAB.idx_to_symbol(s) for s in recon_S[cur_mask]]))
return batch_X, batch_S, batch_ppls
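# Illustrative reference only (shapes inferred from the code above, not part of the
# original implementation notes):
#   X            [N, n_channel, 3]   full-atom coordinates (the comments above assume 14 channels)
#   S            [N]                 residue type indices in VOCAB
#   mask         [N] bool            True for the region to reconstruct/generate
#   position_ids [N]                 residue position indices
#   lengths      [bsz]               residues per sample, with sum(lengths) == N
#   atom_mask    [N, n_channel] bool 1 for existing atoms, 0 for padding/missing
# forward(...) returns the weighted reconstruction loss together with per-term diagnostics,
# while test(...) decodes per-sample coordinates, sequences and the batch_ppls scores.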