#!/usr/bin/python
# -*- coding:utf-8 -*-
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean

from data.format import VOCAB
from utils import register as R
from utils.oom_decorator import oom_decorator
from utils.const import aas
from utils.nn_utils import variadic_meshgrid

from .sidechain.api import SideChainModel
from .backbone.api import BackboneModel
from ..dyMEAN.modules.am_egnn import AMEGNN  # adaptive-multichannel egnn
from ..dyMEAN.nn_utils import SeparatedAminoAcidFeature, ProteinFeature


def create_encoder(
        name,
        atom_embed_size,
        embed_size,
        hidden_size,
        n_channel,
        n_layers,
        dropout,
        n_rbf,
        cutoff
    ):
    if name == 'dyMEAN':
        encoder = AMEGNN(
            embed_size, hidden_size, hidden_size, n_channel,
            channel_nf=atom_embed_size, radial_nf=hidden_size,
            in_edge_nf=0, n_layers=n_layers, residual=True,
            dropout=dropout, dense=False, n_rbf=n_rbf, cutoff=cutoff)
    else:
        raise NotImplementedError(f'Encoder {name} not implemented')
    return encoder


@R.register('AutoEncoder')
class AutoEncoder(nn.Module):
    def __init__(
            self,
            embed_size,
            hidden_size,
            latent_size,
            n_channel,
            latent_n_channel=1,
            mask_id=VOCAB.get_mask_idx(),
            latent_id=VOCAB.symbol_to_idx(VOCAB.LAT),
            max_position=2048,
            relative_position=False,
            CA_channel_idx=VOCAB.backbone_atoms.index('CA'),
            n_layers=3,
            dropout=0.1,
            mask_ratio=0.0,
            fix_alpha_carbon=False,
            h_kl_weight=0.1,
            z_kl_weight=0.5,
            coord_loss_weights={
                'Xloss': 1.0,
                'ca_Xloss': 0.0,
                'bb_bond_lengths_loss': 1.0,
                'sc_bond_lengths_loss': 1.0,
                'bb_dihedral_angles_loss': 0.0,  # this significantly poisons training
                'sc_chi_angles_loss': 0.5
            },
            coord_loss_ratio=0.5,   # (1 - r) * seq + r * coord
            coord_prior_var=1.0,    # sigma^2
            anchor_at_ca=False,
            share_decoder=False,
            n_rbf=0,
            cutoff=0,
            encoder='dyMEAN',
            mode='codesign',  # codesign, fixbb (inverse folding), fixseq (structure prediction)
            additional_noise_scale=0.0  # scale of additional noise added to coordinates to enhance robustness
        ) -> None:
        super().__init__()
        self.mask_id = mask_id
        self.latent_id = latent_id
        self.ca_channel_idx = CA_channel_idx
        self.n_channel = n_channel
        self.mask_ratio = mask_ratio
        self.fix_alpha_carbon = fix_alpha_carbon
        self.h_kl_weight = h_kl_weight
        self.z_kl_weight = z_kl_weight
        self.coord_loss_weights = coord_loss_weights
        self.coord_loss_ratio = coord_loss_ratio
        self.mode = mode
        self.latent_size = 0 if self.mode == 'fixseq' else latent_size
        self.latent_n_channel = 0 if self.mode == 'fixbb' else latent_n_channel
        self.anchor_at_ca = anchor_at_ca
        self.coord_prior_var = coord_prior_var
        self.additional_noise_scale = additional_noise_scale

        if self.fix_alpha_carbon:
            assert self.latent_n_channel == 1, 'Specifying fix_alpha_carbon (use CA as the latent coordinate) but the number of latent channels is not 1'
        if self.anchor_at_ca:
            assert self.latent_n_channel == 1, 'Specifying anchor_at_ca as True but the number of latent channels is not 1'
        if self.mode == 'fixseq':
            assert self.coord_loss_ratio == 1.0, f'Specifying fixseq mode but coordinate loss ratio is not 1.0: {self.coord_loss_ratio}'
        if self.mode == 'fixbb':
            assert self.coord_loss_ratio == 0.0, f'Specifying fixbb mode but coordinate loss ratio is not 0.0: {self.coord_loss_ratio}'

        atom_embed_size = embed_size // 4
        self.aa_feature = SeparatedAminoAcidFeature(
            embed_size, atom_embed_size,
            max_position=max_position,
            relative_position=relative_position,
            fix_atom_weights=True
        )
        self.protein_feature = ProteinFeature()

        self.encoder = create_encoder(
            name = encoder,
            atom_embed_size = atom_embed_size,
            embed_size = embed_size,
            hidden_size = hidden_size,
            n_channel = n_channel,
            n_layers = n_layers,
            dropout = dropout,
            n_rbf = n_rbf,
            cutoff = cutoff
        )
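        # Decoders are instantiated according to the mode:
        #   fixbb    -> sequence decoding only (no structure decoder / latent coordinates)
        #   fixseq   -> structure decoding only (no sequence decoder / latent features)
        #   codesign -> both, optionally sharing one decoder (share_decoder)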
        if self.mode != 'fixbb':
            self.sidechain_decoder = create_encoder(
                name = encoder,
                atom_embed_size = atom_embed_size,
                embed_size = embed_size,
                hidden_size = hidden_size,
                n_channel = n_channel,
                n_layers = n_layers,
                dropout = dropout,
                n_rbf = n_rbf,
                cutoff = cutoff
            )
            self.backbone_model = BackboneModel()
            self.sidechain_model = SideChainModel()
            self.W_Z_log_var = nn.Linear(hidden_size, latent_n_channel * 3)

        if self.mode != 'fixseq':
            self.W_mean = nn.Linear(hidden_size, latent_size)
            self.W_log_var = nn.Linear(hidden_size, latent_size)
            # self.hidden2latent = nn.Linear(hidden_size, latent_size)
            self.latent2hidden = nn.Linear(latent_size, hidden_size)
            self.merge_S_H = nn.Linear(hidden_size * 2, hidden_size)
            if share_decoder:
                self.seq_decoder = self.sidechain_decoder
            else:
                self.seq_decoder = create_encoder(
                    name = encoder,
                    atom_embed_size = atom_embed_size,
                    embed_size = embed_size,
                    hidden_size = hidden_size,
                    n_channel = n_channel,
                    n_layers = n_layers,
                    dropout = dropout,
                    n_rbf = n_rbf,
                    cutoff = cutoff
                )

        # residue type index mapping, from the original index to 0~20, where 0 is unk
        self.unk_idx = 0
        self.s_map = [0 for _ in range(len(VOCAB))]
        self.s_remap = [0 for _ in range(len(aas) + 1)]
        self.s_remap[0] = VOCAB.symbol_to_idx(VOCAB.UNK)
        for i, (a, _) in enumerate(aas):
            original_idx = VOCAB.symbol_to_idx(a)
            self.s_map[original_idx] = i + 1  # start from 1
            self.s_remap[i + 1] = original_idx
        self.s_map = nn.Parameter(torch.tensor(self.s_map, dtype=torch.long), requires_grad=False)
        self.s_remap = nn.Parameter(torch.tensor(self.s_remap, dtype=torch.long), requires_grad=False)
        if self.mode != 'fixseq':
            self.seq_linear = nn.Linear(hidden_size, len(self.s_remap))

    @torch.no_grad()
    def prepare_inputs(self, X, S, mask, atom_mask, lengths):
        # batch ids
        batch_ids = self.get_batch_ids(S, lengths)

        # edges
        row, col = variadic_meshgrid(
            input1=torch.arange(batch_ids.shape[0], device=batch_ids.device),
            size1=lengths,
            input2=torch.arange(batch_ids.shape[0], device=batch_ids.device),
            size2=lengths,
        )  # (row, col)

        is_ctx = mask[row] == mask[col]
        is_inter = ~is_ctx
        ctx_edges = torch.stack([row[is_ctx], col[is_ctx]], dim=0)  # [2, Ec]
        inter_edges = torch.stack([row[is_inter], col[is_inter]], dim=0)  # [2, Ei]
        return ctx_edges, inter_edges, batch_ids

    @torch.no_grad()
    def get_batch_ids(self, S, lengths):
        batch_ids = torch.zeros_like(S)
        batch_ids[torch.cumsum(lengths, dim=0)[:-1]] = 1
        batch_ids.cumsum_(dim=0)
        return batch_ids

    def rsample(self, H, Z, Z_centers, no_randomness=False):
        '''
        H: [N, latent_size]
        Z: [N, latent_channel, 3]
        Z_centers: [N, latent_channel, 3]
        '''
        if self.mode != 'fixseq':
            data_size = H.shape[0]
            H_mean = self.W_mean(H)
            H_log_var = -torch.abs(self.W_log_var(H))  # following Mueller et al., H_log_var is log(\sigma^2)
            H_kl_loss = -0.5 * torch.sum(1.0 + H_log_var - H_mean * H_mean - torch.exp(H_log_var)) / data_size
            H_vecs = H_mean if no_randomness else H_mean + torch.exp(H_log_var / 2) * torch.randn_like(H_mean)
        else:
            H_vecs, H_kl_loss = None, 0

        if self.mode != 'fixbb':
            data_size = Z.shape[0]
            Z_mean_delta = Z - Z_centers
            Z_log_var = -torch.abs(self.W_Z_log_var(H)).view(-1, self.latent_n_channel, 3)
            Z_kl_loss = -0.5 * torch.sum(
                1.0 + Z_log_var - math.log(self.coord_prior_var)
                - Z_mean_delta * Z_mean_delta / self.coord_prior_var
                - torch.exp(Z_log_var) / self.coord_prior_var) / data_size
            Z_vecs = Z if no_randomness else Z + torch.exp(Z_log_var / 2) * torch.randn_like(Z)
        else:
            Z_vecs, Z_kl_loss = None, 0

        return H_vecs, Z_vecs, H_kl_loss, Z_kl_loss
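    # Latent coordinate channels (latent_n_channel):
    #   1 -> a single point per residue: the atom-mask-weighted mean of all atoms,
    #        or the alpha carbon itself when fix_alpha_carbon is set
    #   5 -> the four backbone atoms plus the full-atom mean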
    def _get_latent_channels(self, X, atom_mask):
        atom_weights = atom_mask.float()  # 1 for atom, 0 for padding/missing, [N, 14]
        if hasattr(self, 'fix_alpha_carbon') and self.fix_alpha_carbon:
            return X[:, self.ca_channel_idx].unsqueeze(1)  # use alpha carbon as the latent channel
        elif self.latent_n_channel == 1:
            X = (X * atom_weights.unsqueeze(-1)).sum(1)  # [N, 3]
            X = X / atom_weights.sum(-1).unsqueeze(-1)  # [N, 3]
            return X.unsqueeze(1)
        elif self.latent_n_channel == 5:
            bb_X = X[:, :4]
            X = (X * atom_weights.unsqueeze(-1)).sum(1)  # [N, 3]
            X = X / atom_weights.sum(-1).unsqueeze(-1)  # [N, 3]
            X = torch.cat([bb_X, X.unsqueeze(1)], dim=1)  # [N, 5, 3]
            return X
        else:
            raise NotImplementedError(f'Latent number of channels: {self.latent_n_channel} not implemented')

    def _get_latent_channel_anchors(self, X, atom_mask):
        if self.anchor_at_ca:
            return X[:, self.ca_channel_idx].unsqueeze(1)
        else:
            return self._get_latent_channels(X, atom_mask)

    def _fill_latent_channels(self, latent_X):
        if self.latent_n_channel == 1:
            return latent_X.repeat(1, self.n_channel, 1)
        elif self.latent_n_channel == 5:
            bb_X = latent_X[:, :4]
            sc_X = latent_X[:, 4].unsqueeze(1).repeat(1, self.n_channel - 4, 1)
            return torch.cat([bb_X, sc_X], dim=1)
        else:
            raise NotImplementedError(f'Latent number of channels: {self.latent_n_channel} not implemented')

    def _remove_sidechain_atom_mask(self, atom_mask, mask_generate):
        atom_mask = atom_mask.clone()
        bb_mask = atom_mask[mask_generate]
        bb_mask[:, 4:] = 0  # only backbone atoms are visible
        atom_mask[mask_generate] = bb_mask
        return atom_mask

    @torch.no_grad()
    def _mask_pep(self, S, atom_mask, mask_generate):
        assert self.mask_ratio > 0
        S, atom_mask = S.clone(), atom_mask.clone()
        pep_S = S[mask_generate]
        do_mask = torch.rand_like(pep_S, dtype=torch.float) < self.mask_ratio
        pep_S[do_mask] = self.mask_id
        S[mask_generate] = pep_S
        atom_mask[mask_generate] = self._remove_sidechain_atom_mask(atom_mask[mask_generate], do_mask)
        return S, atom_mask

    def encode(self, X, S, mask, position_ids, lengths, atom_mask, no_randomness=False):
        true_X = X.clone()
        ctx_edges, inter_edges, batch_ids = self.prepare_inputs(X, S, mask, atom_mask, lengths)
        H_0, (atom_embeddings, _) = self.aa_feature(S, position_ids)
        edges = torch.cat([ctx_edges, inter_edges], dim=1)
        atom_weights = atom_mask.float()  # 1 for atom, 0 for padding/missing, [N, 14]
        H, pred_X = self.encoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_weights)
        H = H[mask]

        if self.mode != 'fixbb':
            if hasattr(self, 'fix_alpha_carbon') and self.fix_alpha_carbon:
                Z = self._get_latent_channels(true_X, atom_mask)
            else:
                Z = self._get_latent_channels(pred_X, atom_mask)
            Z_centers = self._get_latent_channel_anchors(true_X, atom_mask)
            Z, Z_centers = Z[mask], Z_centers[mask]
        else:
            Z, Z_centers = None, None

        # resample
        latent_H, latent_X, H_kl_loss, X_kl_loss = self.rsample(H, Z, Z_centers, no_randomness)

        return latent_H, latent_X, H_kl_loss, X_kl_loss
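    # decode() reconstructs the masked (generated) region from the latent variables:
    # masked coordinates are filled from the latent points Z, masked residue types are
    # replaced by the latent token, then the sequence decoder and the sidechain decoder
    # recover residue types and full-atom coordinates respectively.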
    def decode(self, X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing):
        X, S, atom_mask = X.clone(), S.clone(), atom_mask.clone()
        true_S = S[mask].clone()
        if self.mode != 'fixbb':
            # fill coordinates with latent points
            X[mask] = self._fill_latent_channels(Z)
        if self.mode != 'fixseq':
            # fill sequences with the latent token
            S[mask] = self.latent_id
            H_from_latent = self.latent2hidden(H)

        if self.mode == 'fixbb':
            # only backbone atoms are visible
            atom_mask = self._remove_sidechain_atom_mask(atom_mask, mask)
        elif self.mode == 'codesign':
            # all channels are visible when deciding the sequence (all dummy atoms)
            atom_mask[mask] = 1
        else:
            # fixseq mode does not need to change the atom mask
            pass

        ctx_edges, inter_edges, batch_ids = self.prepare_inputs(X, S, mask, atom_mask, lengths)
        edges = torch.cat([ctx_edges, inter_edges], dim=1)

        # decode sequence
        if self.mode != 'fixseq':
            H_0, (atom_embeddings, _) = self.aa_feature(S, position_ids)
            H_0 = H_0.clone()
            H_0[mask] = H_from_latent  # TODO: how about the position encoding
            H, _ = self.seq_decoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_mask.float())
            pred_S_logits = self.seq_linear(H[mask])  # [Ntgt, 21]
            S = S.clone()
            if teacher_forcing:  # teacher forcing
                S[mask] = true_S
            else:  # inference
                S[mask] = self.s_remap[torch.argmax(pred_S_logits, dim=-1)]
        else:
            pred_S_logits = None

        # decode sidechain
        if self.mode != 'fixbb':
            H_0, (atom_embeddings, atom_weights) = self.aa_feature(S, position_ids)
            H_0 = H_0.clone()
            if self.mode != 'fixseq':
                H_0[mask] = self.merge_S_H(torch.cat([H_from_latent, H_0[mask]], dim=-1))
                # H_0[mask] = H_from_latent
            atom_mask = atom_mask.clone()
            atom_mask[mask] = atom_weights.bool()[mask] & atom_mask[mask]  # reset atomic visibility of the reconstruction part with the decoded sequence
            _, pred_X = self.sidechain_decoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_mask.float())
            pred_X = pred_X[mask]
        else:
            pred_X = None

        return pred_S_logits, pred_X
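    # Training objective (forward): the reconstruction loss mixes the sequence and
    # structure terms by coord_loss_ratio r:
    #   loss = (1 - r) * (seq_CE + h_kl_weight * H_KL) + r * (struct_losses + z_kl_weight * Z_KL)
    # where struct_losses aggregates the weighted terms in coord_loss_weights.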
    @oom_decorator
    def forward(self, X, S, mask, position_ids, lengths, atom_mask, teacher_forcing=True):
        true_X, true_S = X[mask].clone(), S[mask].clone()

        # encode: H (N*d), Z (N*3)
        if self.mask_ratio > 0:
            input_S, input_atom_mask = self._mask_pep(S, atom_mask, mask)
        else:
            input_S, input_atom_mask = S, atom_mask
        H, Z, H_kl_loss, Z_kl_loss = self.encode(X, input_S, mask, position_ids, lengths, input_atom_mask)
        if self.mode != 'fixbb':
            coord_reg_loss = F.mse_loss(Z, self._get_latent_channel_anchors(true_X, atom_mask[mask]))
        else:
            coord_reg_loss = 0

        # add noise to improve robustness (Z is None in fixbb mode)
        if self.training and Z is not None:
            noise = torch.randn_like(Z) * getattr(self, 'additional_noise_scale', 0.0)
            Z = Z + noise

        # decode: S (N), Z (N * 14 * 3) with atom mask
        recon_S_logits, recon_X = self.decode(X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing)

        # sequence reconstruction loss
        if self.mode != 'fixseq':
            seq_recon_loss = F.cross_entropy(recon_S_logits, self.s_map[true_S])
            # aar (amino acid recovery)
            with torch.no_grad():
                aar = (torch.argmax(recon_S_logits, dim=-1) == self.s_map[true_S]).sum() / len(recon_S_logits)
        else:
            seq_recon_loss, aar = 0, 1.0

        # coordinates reconstruction loss
        if self.mode != 'fixbb':
            xloss_mask = atom_mask[mask]
            batch_ids = self.get_batch_ids(S, lengths)[mask]
            segment_ids = torch.ones_like(true_S, device=true_S.device, dtype=torch.long)
            if self.n_channel == 4:  # backbone only
                loss_profile = {}
            else:
                true_struct_profile = self.protein_feature.get_struct_profile(true_X, true_S, batch_ids, self.aa_feature, segment_ids, xloss_mask)
                recon_struct_profile = self.protein_feature.get_struct_profile(recon_X, true_S, batch_ids, self.aa_feature, segment_ids, xloss_mask)
                loss_profile = {
                    key + '_loss': F.l1_loss(recon_struct_profile[key], true_struct_profile[key])
                    for key in recon_struct_profile
                }
            # mse
            xloss = F.mse_loss(recon_X[xloss_mask], true_X[xloss_mask])
            loss_profile['Xloss'] = xloss
            # CA mse
            ca_xloss_mask = xloss_mask[:, self.ca_channel_idx]
            ca_xloss = F.mse_loss(recon_X[:, self.ca_channel_idx][ca_xloss_mask], true_X[:, self.ca_channel_idx][ca_xloss_mask])
            loss_profile['ca_Xloss'] = ca_xloss
            struct_recon_loss = 0
            for name in loss_profile:
                struct_recon_loss = struct_recon_loss + self.coord_loss_weights[name] * loss_profile[name]
        else:
            struct_recon_loss, loss_profile = 0, {}

        recon_loss = (1 - self.coord_loss_ratio) * (seq_recon_loss + self.h_kl_weight * H_kl_loss) + \
                     self.coord_loss_ratio * (struct_recon_loss + self.z_kl_weight * Z_kl_loss)

        return recon_loss, (seq_recon_loss, aar), (struct_recon_loss, loss_profile), (H_kl_loss, Z_kl_loss, coord_reg_loss)

    def _reconstruct(self, X, S, mask, position_ids, lengths, atom_mask, given_laten_H=None, given_latent_X=None, allow_unk=False, optimize_sidechain=True, idealize=False, no_randomness=False):
        if given_laten_H is None and given_latent_X is None:
            # encode: H (N*d), Z (N*3)
            H, Z, _, _ = self.encode(X, S, mask, position_ids, lengths, atom_mask, no_randomness=no_randomness)
        else:
            H, Z = given_laten_H, given_latent_X

        # decode: S (N), Z (N * 14 * 3) with atom mask
        recon_S_logits, recon_X = self.decode(X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing=False)

        batch_ids = self.get_batch_ids(S, lengths)[mask]

        if self.mode != 'fixseq':
            if not allow_unk:
                recon_S_logits[:, 0] = float('-inf')
            # map aa index back
            recon_S = self.s_remap[torch.argmax(recon_S_logits, dim=-1)]
            # ppls
            snll_all = F.cross_entropy(recon_S_logits, torch.argmax(recon_S_logits, dim=-1), reduction='none')
            batch_ppls = scatter_mean(snll_all, batch_ids, dim=0)
        else:
            recon_S = S[mask]
            batch_ppls = torch.zeros(batch_ids.max() + 1, device=recon_X.device).float()

        if self.mode == 'fixseq' or (self.mode != 'fixbb' and idealize):
            # rectify backbone
            recon_X = self.backbone_model(recon_X, batch_ids)
            # rectify sidechain
            recon_X = self.sidechain_model(recon_X, recon_S, batch_ids, optimize_sidechain)

        return recon_X, recon_S, batch_ppls, batch_ids

    @torch.no_grad()
    def test(self, X, S, mask, position_ids, lengths, atom_mask, given_laten_H=None, given_latent_X=None, return_tensor=False, allow_unk=False, optimize_sidechain=True, idealize=False, n_iter=1):
        no_randomness = given_laten_H is not None  # in reconstruction mode, with the latent variable derived from the diffusion model

        for i in range(n_iter):
            recon_X, recon_S, batch_ppls, batch_ids = self._reconstruct(
                X, S, mask, position_ids, lengths, atom_mask,
                given_laten_H, given_latent_X, allow_unk, optimize_sidechain, idealize, no_randomness)
            X, S = X.clone(), S.clone()
            if self.mode != 'fixbb':
                X[mask] = recon_X
            if self.mode != 'fixseq':
                S[mask] = recon_S
            given_laten_H, given_latent_X = None, None  # let the model encode and decode for later iterations

        if return_tensor:
            return recon_X, recon_S, batch_ppls

        batch_X, batch_S = [], []
        batch_ppls = batch_ppls.tolist()
        for i, l in enumerate(lengths):
            cur_mask = batch_ids == i
            if self.mode == 'fixbb':
                batch_X.append(None)
            else:
                batch_X.append(recon_X[cur_mask].tolist())
            if self.mode == 'fixseq':
                batch_S.append(None)
            else:
                batch_S.append(''.join([VOCAB.idx_to_symbol(s) for s in recon_S[cur_mask]]))
        return batch_X, batch_S, batch_ppls
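

# The block below is a standalone illustration (not used by the model) of the
# reparameterization trick and the closed-form Gaussian KL term applied in
# `AutoEncoder.rsample`; the tensor sizes are arbitrary demo values, not model defaults.
if __name__ == '__main__':
    torch.manual_seed(0)
    n_residues, demo_latent_size = 8, 16  # hypothetical sizes for this demo only
    H_mean = torch.randn(n_residues, demo_latent_size)
    H_log_var = -torch.abs(torch.randn(n_residues, demo_latent_size))  # log(sigma^2), kept non-positive as in rsample
    # KL( N(mu, sigma^2 I) || N(0, I) ), summed over dimensions and averaged over residues
    h_kl = -0.5 * torch.sum(1.0 + H_log_var - H_mean * H_mean - torch.exp(H_log_var)) / n_residues
    # reparameterized sample: mu + sigma * eps, with eps ~ N(0, I)
    h_sample = H_mean + torch.exp(H_log_var / 2) * torch.randn_like(H_mean)
    print(f'demo H KL: {h_kl.item():.4f}, sample shape: {tuple(h_sample.shape)}')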