| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| import copy |
| import math |
| from tqdm.auto import tqdm |
| import functools |
| from torch.utils.data import DataLoader |
| import os |
| import argparse |
|
|
| import pandas as pd |
|
|
| from models_con.edge import EdgeEmbedder |
| from models_con.node import NodeEmbedder |
| from pepflow.modules.common.layers import sample_from, clampped_one_hot |
| from models_con.ga import GAEncoder |
| from pepflow.modules.protein.constants import AA, BBHeavyAtom, max_num_heavyatoms |
| from pepflow.modules.common.geometry import construct_3d_basis |
| from pepflow.utils.data import mask_select_data, find_longest_true_segment, PaddingCollate |
| from pepflow.utils.misc import seed_all |
| from pepflow.utils.train import sum_weighted_losses |
| from torch.nn.utils import clip_grad_norm_ |
|
|
| from pepflow.modules.so3.dist import centered_gaussian,uniform_so3 |
| from pepflow.modules.common.geometry import batch_align, align |
|
|
| from tqdm import tqdm |
|
|
| import wandb |
|
|
| from data import so3_utils |
| from data import all_atom |
|
|
| from models_con.pep_dataloader import PepDataset |
|
|
| from pepflow.utils.misc import load_config |
| from pepflow.utils.train import recursive_to |
| from easydict import EasyDict |
|
|
| from models_con.utils import process_dic |
| from models_con.torsion import get_torsion_angle, torsions_mask |
| import models_con.torus as torus |
|
|
| import gc |
|
|
| from copy import deepcopy |
| from pepflow.utils.data import PaddingCollate |
# Module-level collate function for DataLoaders over variable-length peptide
# batches. NOTE(review): eight=False presumably disables padding sequence
# lengths up to a multiple of eight -- confirm against PaddingCollate.
collate_fn = PaddingCollate(eight=False)
| from pepflow.utils.train import recursive_to |
|
|
# Number of heavy atoms kept per residue for each structural resolution mode.
# 'backbone+CB' keeps N, CA, C, O, CB (5 atoms); 'full' keeps every heavy atom.
resolution_to_num_atoms = {
    'backbone+CB': 5,
    'full': max_num_heavyatoms
}
|
|
class FlowModel(nn.Module):
    """Multi-modal flow-matching model over peptide residues.

    Jointly models four modalities per residue:
      * backbone rotations (SO(3) frames)      -- geodesic interpolation
      * backbone CA translations (R^3)         -- linear interpolation
      * torsion angles (torus)                 -- toroidal geodesic interpolation
      * residue types (categorical, K classes) -- linear flow on a +/-k simplex

    ``batch['generate_mask']`` marks the residues to be generated; all other
    residues act as fixed context during both training and sampling.
    """

    def __init__(self,cfg):
        super().__init__()
        self._model_cfg = cfg.encoder
        self._interpolant_cfg = cfg.interpolant

        # Per-residue (node) and pairwise (edge) feature embedders.
        self.node_embedder = NodeEmbedder(cfg.encoder.node_embed_size,max_num_heavyatoms)
        self.edge_embedder = EdgeEmbedder(cfg.encoder.edge_embed_size,max_num_heavyatoms)
        # Geometric-attention denoiser: predicts clean data from the noisy state.
        self.ga_encoder = GAEncoder(cfg.encoder.ipa)

        # Toggles controlling which modalities are noised during training.
        self.sample_structure = self._interpolant_cfg.sample_structure
        self.sample_sequence = self._interpolant_cfg.sample_sequence

        # K: number of sequence classes; k: magnitude used when mapping one-hot
        # sequences onto the simplex representation (values in {-k, +k}).
        self.K = self._interpolant_cfg.seqs.num_classes
        self.k = self._interpolant_cfg.seqs.simplex_value

    def encode(self, batch):
        """Extract ground-truth (t=1) variables and conditioning embeddings.

        Returns:
            rotmats_1: (B, N, 3, 3) backbone frames built from CA/C/N atoms.
            trans_1:   (B, N, 3) CA coordinates.
            angles_1:  torsion angles from ``batch['torsion_angle']``.
            seqs_1:    (B, N) integer residue types.
            node_embed, edge_embed: conditioning features; the structure and/or
                sequence of residues under ``generate_mask`` are hidden from
                the embedders via the masks below.
        """
        rotmats_1 = construct_3d_basis(batch['pos_heavyatom'][:, :, BBHeavyAtom.CA],batch['pos_heavyatom'][:, :, BBHeavyAtom.C],batch['pos_heavyatom'][:, :, BBHeavyAtom.N] )
        trans_1 = batch['pos_heavyatom'][:, :, BBHeavyAtom.CA]
        seqs_1 = batch['aa']

        angles_1 = batch['torsion_angle']

        # Context = residues with a valid CA atom that are NOT being generated.
        context_mask = torch.logical_and(batch['mask_heavyatom'][:, :, BBHeavyAtom.CA], ~batch['generate_mask'])
        # Only hide a modality from the embedders if it is actually generated.
        structure_mask = context_mask if self.sample_structure else None
        sequence_mask = context_mask if self.sample_sequence else None
        node_embed = self.node_embedder(batch['aa'], batch['res_nb'], batch['chain_nb'], batch['pos_heavyatom'],
                                        batch['mask_heavyatom'], structure_mask=structure_mask, sequence_mask=sequence_mask)
        edge_embed = self.edge_embedder(batch['aa'], batch['res_nb'], batch['chain_nb'], batch['pos_heavyatom'],
                                        batch['mask_heavyatom'], structure_mask=structure_mask, sequence_mask=sequence_mask)

        return rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed

    def zero_center_part(self,pos,gen_mask,res_mask):
        """
        move pos by center of gen_mask
        pos: (B,N,3)
        gen_mask, res_mask: (B,N)

        Returns the shifted positions (zeroed at padding residues via
        res_mask) and the (B,1,3) centroid that was subtracted.
        """
        # Centroid over generated residues only; epsilon guards an empty mask.
        center = torch.sum(pos * gen_mask[...,None], dim=1) / (torch.sum(gen_mask,dim=-1,keepdim=True) + 1e-8)
        center = center.unsqueeze(1)

        pos = pos - center
        # Zero out positions of padding residues.
        pos = pos * res_mask[...,None]
        return pos,center

    def seq_to_simplex(self,seqs):
        """Map integer sequences (B, N) to (B, N, K) simplex values in {-k, +k}."""
        return clampped_one_hot(seqs, self.K).float() * self.k * 2 - self.k

    def forward(self, batch):
        """Compute flow-matching training losses for one batch.

        Samples one time t per example, interpolates each modality between
        noise (t=0) and data (t=1) on generated residues only, predicts the
        clean data with the encoder, and returns a dict of scalar losses:
        trans_loss, rot_loss, bb_atom_loss, seqs_loss, angle_loss,
        torsion_loss.
        """
        num_batch, num_res = batch['aa'].shape
        gen_mask,res_mask,angle_mask = batch['generate_mask'].long(),batch['res_mask'].long(),batch['torsion_angle_mask'].long()

        rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed = self.encode(batch)

        trans_1_c,_ = self.zero_center_part(trans_1,gen_mask,res_mask)
        # NOTE(review): the centered result above is immediately overwritten,
        # so zero-centering of the ground-truth translations is effectively
        # disabled -- presumably a deliberate "no-center" ablation; confirm.
        trans_1_c = trans_1
        seqs_1_simplex = self.seq_to_simplex(seqs_1)
        seqs_1_prob = F.softmax(seqs_1_simplex,dim=-1)

        # ---- Corrupt data to time t (no gradients through the noising) ----
        with torch.no_grad():
            # t ~ U(clip, 1-clip): avoids the degenerate endpoints t=0 and t=1.
            t = torch.rand((num_batch,1), device=batch['aa'].device)
            t = t*(1-2 * self._interpolant_cfg.t_normalization_clip) + self._interpolant_cfg.t_normalization_clip
            if self.sample_structure:
                # Translations: Gaussian noise -> linear interpolation to data.
                trans_0 = torch.randn((num_batch,num_res,3), device=batch['aa'].device) * self._interpolant_cfg.trans.sigma
                trans_0_c,_ = self.zero_center_part(trans_0,gen_mask,res_mask)
                trans_t = (1-t[...,None])*trans_0_c + t[...,None]*trans_1_c
                # Context residues keep their ground-truth values.
                trans_t_c = torch.where(batch['generate_mask'][...,None],trans_t,trans_1_c)
                # Rotations: uniform SO(3) noise -> geodesic interpolation.
                rotmats_0 = uniform_so3(num_batch,num_res,device=batch['aa'].device)
                rotmats_t = so3_utils.geodesic_t(t[..., None], rotmats_1, rotmats_0)
                rotmats_t = torch.where(batch['generate_mask'][...,None,None],rotmats_t,rotmats_1)
                # Torsions: uniform noise on the torus -> toroidal geodesic.
                angles_0 = torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype)
                angles_t = torus.tor_geodesic_t(t[..., None], angles_1, angles_0)
                angles_t = torch.where(batch['generate_mask'][...,None],angles_t,angles_1)
            else:
                trans_t_c = trans_1_c.detach().clone()
                rotmats_t = rotmats_1.detach().clone()
                angles_t = angles_1.detach().clone()
            if self.sample_sequence:
                # Sequences: Gaussian noise in simplex space -> linear
                # interpolation, then sample discrete types from its softmax.
                seqs_0_simplex = self.k * torch.randn_like(seqs_1_simplex)
                seqs_0_prob = F.softmax(seqs_0_simplex,dim=-1)
                seqs_t_simplex = ((1 - t[..., None]) * seqs_0_simplex) + (t[..., None] * seqs_1_simplex)
                seqs_t_simplex = torch.where(batch['generate_mask'][...,None],seqs_t_simplex,seqs_1_simplex)
                seqs_t_prob = F.softmax(seqs_t_simplex,dim=-1)
                seqs_t = sample_from(seqs_t_prob)
                seqs_t = torch.where(batch['generate_mask'],seqs_t,seqs_1)
            else:
                seqs_t = seqs_1.detach().clone()
                seqs_t_simplex = seqs_1_simplex.detach().clone()
                seqs_t_prob = seqs_1_prob.detach().clone()

        # ---- Predict clean data from the corrupted state ----
        pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t, trans_t_c, angles_t, seqs_t, node_embed, edge_embed, gen_mask, res_mask)
        pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob,dim=-1))
        # Ground-truth types clamped to [0, 19] so special/unknown tokens
        # remain valid indices for the torsion-mask lookup below.
        pred_seqs_1 = torch.where(batch['generate_mask'],pred_seqs_1,torch.clamp(seqs_1,0,19))
        pred_trans_1_c,_ = self.zero_center_part(pred_trans_1,gen_mask,res_mask)
        # NOTE(review): as above -- the centered prediction is overwritten.
        pred_trans_1_c = pred_trans_1

        # Loss reweighting 1/(1-t), clipped so the factor stays finite near t=1.
        norm_scale = 1 / (1 - torch.min(t[...,None], torch.tensor(self._interpolant_cfg.t_normalization_clip)))

        # Translation MSE, averaged over generated residues then the batch.
        trans_loss = torch.sum((pred_trans_1_c - trans_1_c)**2*gen_mask[...,None],dim=(-1,-2)) / (torch.sum(gen_mask,dim=-1) + 1e-8)
        trans_loss = torch.mean(trans_loss)

        # Rotation loss: match vector fields (rotation logmaps from the noisy
        # frame) rather than the frames themselves.
        gt_rot_vf = so3_utils.calc_rot_vf(rotmats_t, rotmats_1)
        pred_rot_vf = so3_utils.calc_rot_vf(rotmats_t, pred_rotmats_1)
        rot_loss = torch.sum(((gt_rot_vf - pred_rot_vf) * norm_scale)**2*gen_mask[...,None],dim=(-1,-2)) / (torch.sum(gen_mask,dim=-1) + 1e-8)
        rot_loss = torch.mean(rot_loss)

        # Backbone-atom loss on the first three atom37 slots (N, CA, C).
        gt_bb_atoms = all_atom.to_atom37(trans_1_c, rotmats_1)[:, :, :3]
        pred_bb_atoms = all_atom.to_atom37(pred_trans_1_c, pred_rotmats_1)[:, :, :3]

        bb_atom_loss = torch.sum(
            (gt_bb_atoms - pred_bb_atoms) ** 2 * gen_mask[..., None, None],
            dim=(-1, -2, -3)
        ) / (torch.sum(gen_mask,dim=-1) + 1e-8)
        bb_atom_loss = torch.mean(bb_atom_loss)

        # Sequence cross-entropy over generated residues.
        seqs_loss = F.cross_entropy(pred_seqs_1_prob.view(-1,pred_seqs_1_prob.shape[-1]),torch.clamp(seqs_1,0,19).view(-1), reduction='none').view(pred_seqs_1_prob.shape[:-1])
        seqs_loss = torch.sum(seqs_loss * gen_mask, dim=-1) / (torch.sum(gen_mask,dim=-1) + 1e-8)
        seqs_loss = torch.mean(seqs_loss)

        # Torsion validity mask looked up by PREDICTED residue type, duplicated
        # because angles are compared as concatenated (sin, cos) pairs below,
        # and restricted to generated residues.
        angle_mask_loss = torsions_mask.to(batch['aa'].device)
        angle_mask_loss = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch,num_res,-1)
        angle_mask_loss = torch.cat([angle_mask_loss,angle_mask_loss],dim=-1)
        angle_mask_loss = torch.logical_and(batch['generate_mask'][...,None].bool(),angle_mask_loss)
        # Torsion vector-field matching in (sin, cos) space.
        gt_angle_vf = torus.tor_logmap(angles_t, angles_1)
        gt_angle_vf_vec = torch.cat([torch.sin(gt_angle_vf),torch.cos(gt_angle_vf)],dim=-1)
        pred_angle_vf = torus.tor_logmap(angles_t, pred_angles_1)
        pred_angle_vf_vec = torch.cat([torch.sin(pred_angle_vf),torch.cos(pred_angle_vf)],dim=-1)

        angle_loss = torch.sum(((gt_angle_vf_vec - pred_angle_vf_vec) * norm_scale)**2*angle_mask_loss,dim=(-1,-2)) / (torch.sum(angle_mask_loss,dim=(-1,-2)) + 1e-8)
        angle_loss = torch.mean(angle_loss)

        # Direct torsion reconstruction loss, also in (sin, cos) space.
        angles_1_vec = torch.cat([torch.sin(angles_1),torch.cos(angles_1)],dim=-1)
        pred_angles_1_vec = torch.cat([torch.sin(pred_angles_1),torch.cos(pred_angles_1)],dim=-1)

        torsion_loss = torch.sum((pred_angles_1_vec - angles_1_vec)**2*angle_mask_loss,dim=(-1,-2)) / (torch.sum(angle_mask_loss,dim=(-1,-2)) + 1e-8)
        torsion_loss = torch.mean(torsion_loss)

        return {
            "trans_loss": trans_loss,
            'rot_loss': rot_loss,
            'bb_atom_loss': bb_atom_loss,
            'seqs_loss': seqs_loss,
            'angle_loss': angle_loss,
            'torsion_loss': torsion_loss,
        }

    @torch.no_grad()
    def sample(self, batch, num_steps = 100, sample_bb=True, sample_ang=True, sample_seq=True):
        """Euler-integrate the learned flow from noise (t=0.01) to data (t=1).

        Args:
            batch: padded batch providing context and ``generate_mask``.
            num_steps: number of points on the uniform time grid.
            sample_bb / sample_ang / sample_seq: which modalities to generate;
                disabled modalities are pinned to their ground-truth values.

        Returns:
            clean_traj: list of per-step dicts (tensors moved to CPU) holding
            the model's current clean prediction ('rotmats', 'trans', 'angles',
            'seqs', 'seqs_simplex') alongside the ground-truth reference
            ('rotmats_1', 'trans_1', 'angles_1', 'seqs_1').
        """
        num_batch, num_res = batch['aa'].shape
        gen_mask,res_mask = batch['generate_mask'],batch['res_mask']
        K = self._interpolant_cfg.seqs.num_classes
        k = self._interpolant_cfg.seqs.simplex_value
        # Residue-type -> valid-torsion lookup table.
        angle_mask_loss = torsions_mask.to(batch['aa'].device)

        rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed = self.encode(batch)

        # Ground-truth translations are used uncentered (matches forward()).
        trans_1_c = trans_1
        seqs_1_simplex = self.seq_to_simplex(seqs_1)
        seqs_1_prob = F.softmax(seqs_1_simplex,dim=-1)

        # ---- Draw the t=0 state for each modality ----
        if sample_bb:
            rotmats_0 = uniform_so3(num_batch,num_res,device=batch['aa'].device)
            rotmats_0 = torch.where(batch['generate_mask'][...,None,None],rotmats_0,rotmats_1)
            trans_0 = torch.randn((num_batch,num_res,3), device=batch['aa'].device)

            trans_0_c,center = self.zero_center_part(trans_0,gen_mask,res_mask)
            trans_0_c = torch.where(batch['generate_mask'][...,None],trans_0_c,trans_1_c)
        else:
            rotmats_0 = rotmats_1.detach().clone()
            trans_0_c = trans_1_c.detach().clone()
        if sample_ang:
            angles_0 = torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype)
            angles_0 = torch.where(batch['generate_mask'][...,None],angles_0,angles_1)
        else:
            angles_0 = angles_1.detach().clone()
        if sample_seq:
            seqs_0_simplex = k * torch.randn((num_batch,num_res,K), device=batch['aa'].device)
            seqs_0_prob = F.softmax(seqs_0_simplex,dim=-1)
            seqs_0 = sample_from(seqs_0_prob)
            seqs_0 = torch.where(batch['generate_mask'],seqs_0,seqs_1)
            seqs_0_simplex = torch.where(batch['generate_mask'][...,None],seqs_0_simplex,seqs_1_simplex)
        else:
            seqs_0 = seqs_1.detach().clone()
            seqs_0_prob = seqs_1_prob.detach().clone()
            seqs_0_simplex = seqs_1_simplex.detach().clone()

        # Uniform time grid; starts at 0.01 rather than 0.
        ts = torch.linspace(1.e-2, 1.0, num_steps)
        t_1 = ts[0]

        clean_traj = []
        rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, seqs_t_1_simplex = rotmats_0, trans_0_c, angles_0, seqs_0, seqs_0_simplex

        # ---- Euler integration over consecutive grid points ----
        for t_2 in ts[1:]:
            t = torch.ones((num_batch, 1), device=batch['aa'].device) * t_1

            # Predict the clean (t=1) state from the current noisy state.
            pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, node_embed, edge_embed, batch['generate_mask'].long(), batch['res_mask'].long())
            pred_rotmats_1 = torch.where(batch['generate_mask'][...,None,None],pred_rotmats_1,rotmats_1)

            pred_trans_1_c = torch.where(batch['generate_mask'][...,None],pred_trans_1,trans_1_c)

            pred_angles_1 = torch.where(batch['generate_mask'][...,None],pred_angles_1,angles_1)

            pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob,dim=-1))
            pred_seqs_1 = torch.where(batch['generate_mask'],pred_seqs_1,seqs_1)
            pred_seqs_1_simplex = self.seq_to_simplex(pred_seqs_1)

            # Zero out torsions that do not exist for the predicted residue type.
            torsion_mask = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch,num_res,-1)
            pred_angles_1 = torch.where(torsion_mask.bool(),pred_angles_1,torch.zeros_like(pred_angles_1))
            if not sample_bb:
                pred_trans_1_c = trans_1_c.detach().clone()

                pred_rotmats_1 = rotmats_1.detach().clone()
            if not sample_ang:
                pred_angles_1 = angles_1.detach().clone()
            if not sample_seq:
                pred_seqs_1 = seqs_1.detach().clone()
                pred_seqs_1_simplex = seqs_1_simplex.detach().clone()
            # Record the current clean prediction (on CPU) for the trajectory.
            clean_traj.append({'rotmats':pred_rotmats_1.cpu(),'trans':pred_trans_1_c.cpu(),'angles':pred_angles_1.cpu(),'seqs':pred_seqs_1.cpu(),'seqs_simplex':pred_seqs_1_simplex.cpu(),
                               'rotmats_1':rotmats_1.cpu(),'trans_1':trans_1_c.cpu(),'angles_1':angles_1.cpu(),'seqs_1':seqs_1.cpu()})

            d_t = (t_2-t_1) * torch.ones((num_batch, 1), device=batch['aa'].device)

            # Translation Euler step. NOTE(review): the drift is the constant
            # endpoint difference (pred - x_0), not (pred - x_t)/(1 - t) --
            # presumably intentional for this interpolant; confirm.
            trans_t_2 = trans_t_1_c + (pred_trans_1_c-trans_0_c)*d_t[...,None]

            trans_t_2_c = torch.where(batch['generate_mask'][...,None],trans_t_2,trans_1_c)

            # Rotation step along the geodesic toward the prediction.
            # NOTE(review): the *10 factor looks like an exponential sampling
            # rate for SO(3) -- confirm against so3_utils.geodesic_t.
            rotmats_t_2 = so3_utils.geodesic_t(d_t[...,None] * 10, pred_rotmats_1, rotmats_t_1)
            rotmats_t_2 = torch.where(batch['generate_mask'][...,None,None],rotmats_t_2,rotmats_1)

            # Torsion step along the toroidal geodesic toward the prediction.
            angles_t_2 = torus.tor_geodesic_t(d_t[...,None],pred_angles_1, angles_t_1)
            angles_t_2 = torch.where(batch['generate_mask'][...,None],angles_t_2,angles_1)

            # Sequence step in simplex space, then resample discrete types.
            seqs_t_2_simplex = seqs_t_1_simplex + (pred_seqs_1_simplex - seqs_0_simplex) * d_t[...,None]
            seqs_t_2 = sample_from(F.softmax(seqs_t_2_simplex,dim=-1))
            seqs_t_2 = torch.where(batch['generate_mask'],seqs_t_2,seqs_1)

            torsion_mask = angle_mask_loss[seqs_t_2.reshape(-1)].reshape(num_batch,num_res,-1)
            angles_t_2 = torch.where(torsion_mask.bool(),angles_t_2,torch.zeros_like(angles_t_2))

            if not sample_bb:
                trans_t_2_c = trans_1_c.detach().clone()
                rotmats_t_2 = rotmats_1.detach().clone()
            if not sample_ang:
                angles_t_2 = angles_1.detach().clone()
            if not sample_seq:
                seqs_t_2 = seqs_1.detach().clone()
            rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, seqs_t_1_simplex = rotmats_t_2, trans_t_2_c, angles_t_2, seqs_t_2, seqs_t_2_simplex
            t_1 = t_2

        # ---- Final prediction pass at t = 1 ----
        t_1 = ts[-1]
        t = torch.ones((num_batch, 1), device=batch['aa'].device) * t_1
        pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, node_embed, edge_embed, batch['generate_mask'].long(), batch['res_mask'].long())
        pred_rotmats_1 = torch.where(batch['generate_mask'][...,None,None],pred_rotmats_1,rotmats_1)

        pred_trans_1_c = torch.where(batch['generate_mask'][...,None],pred_trans_1,trans_1_c)

        pred_angles_1 = torch.where(batch['generate_mask'][...,None],pred_angles_1,angles_1)

        pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob,dim=-1))
        pred_seqs_1 = torch.where(batch['generate_mask'],pred_seqs_1,seqs_1)
        pred_seqs_1_simplex = self.seq_to_simplex(pred_seqs_1)

        # Zero out torsions invalid for the final predicted residue types.
        torsion_mask = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch,num_res,-1)
        pred_angles_1 = torch.where(torsion_mask.bool(),pred_angles_1,torch.zeros_like(pred_angles_1))
        if not sample_bb:
            pred_trans_1_c = trans_1_c.detach().clone()

            pred_rotmats_1 = rotmats_1.detach().clone()
        if not sample_ang:
            pred_angles_1 = angles_1.detach().clone()
        if not sample_seq:
            pred_seqs_1 = seqs_1.detach().clone()
            pred_seqs_1_simplex = seqs_1_simplex.detach().clone()
        clean_traj.append({'rotmats':pred_rotmats_1.cpu(),'trans':pred_trans_1_c.cpu(),'angles':pred_angles_1.cpu(),'seqs':pred_seqs_1.cpu(),'seqs_simplex':pred_seqs_1_simplex.cpu(),
                           'rotmats_1':rotmats_1.cpu(),'trans_1':trans_1_c.cpu(),'angles_1':angles_1.cpu(),'seqs_1':seqs_1.cpu()})

        return clean_traj
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
|
|
| |
| |