# PepFlow / models_con / flow_model.py
# NOTE: repository-page residue converted to a comment so the file parses.
# (uploader: Irwiny123, commit ef423c5, message: "Add initial PepFlow model code")
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
from tqdm.auto import tqdm
import functools
from torch.utils.data import DataLoader
import os
import argparse
import pandas as pd
from models_con.edge import EdgeEmbedder
from models_con.node import NodeEmbedder
from pepflow.modules.common.layers import sample_from, clampped_one_hot
from models_con.ga import GAEncoder
from pepflow.modules.protein.constants import AA, BBHeavyAtom, max_num_heavyatoms
from pepflow.modules.common.geometry import construct_3d_basis
from pepflow.utils.data import mask_select_data, find_longest_true_segment, PaddingCollate
from pepflow.utils.misc import seed_all
from pepflow.utils.train import sum_weighted_losses
from torch.nn.utils import clip_grad_norm_
from pepflow.modules.so3.dist import centered_gaussian,uniform_so3
from pepflow.modules.common.geometry import batch_align, align
from tqdm import tqdm
import wandb
from data import so3_utils
from data import all_atom
from models_con.pep_dataloader import PepDataset
from pepflow.utils.misc import load_config
from pepflow.utils.train import recursive_to
from easydict import EasyDict
from models_con.utils import process_dic
from models_con.torsion import get_torsion_angle, torsions_mask
import models_con.torus as torus
import gc
from copy import deepcopy
from pepflow.utils.data import PaddingCollate
# Module-level collate fn for batching variable-length examples; eight=False
# presumably disables padding lengths to a multiple of 8 — TODO confirm in PaddingCollate.
collate_fn = PaddingCollate(eight=False)
from pepflow.utils.train import recursive_to
# Structure resolution name -> number of heavy atoms kept per residue.
resolution_to_num_atoms = {
    'backbone+CB': 5,   # presumably N, CA, C, O, CB — confirm against atom ordering
    'full': max_num_heavyatoms
}
class FlowModel(nn.Module):
    """Flow-matching generative model for peptide design.

    Jointly models backbone frames (rotation matrices + CA translations),
    torsion angles on the torus, and residue types (as points on a scaled
    simplex), conditioned on receptor context. A GA/IPA encoder predicts the
    clean (t=1) state from a noised (t) state; `forward` computes training
    losses and `sample` integrates the learned flow from noise to data.
    """

    def __init__(self, cfg):
        super().__init__()
        self._model_cfg = cfg.encoder
        self._interpolant_cfg = cfg.interpolant
        # Per-residue node features and pairwise edge features.
        self.node_embedder = NodeEmbedder(cfg.encoder.node_embed_size, max_num_heavyatoms)
        self.edge_embedder = EdgeEmbedder(cfg.encoder.edge_embed_size, max_num_heavyatoms)
        self.ga_encoder = GAEncoder(cfg.encoder.ipa)
        # Which modalities are corrupted/generated; disabled ones stay at ground truth.
        self.sample_structure = self._interpolant_cfg.sample_structure
        self.sample_sequence = self._interpolant_cfg.sample_sequence
        self.K = self._interpolant_cfg.seqs.num_classes    # number of residue classes
        self.k = self._interpolant_cfg.seqs.simplex_value  # magnitude of simplex logits
    def encode(self, batch):
        """Extract ground-truth (t=1) modalities and embed nodes/edges.

        Returns:
            rotmats_1: (B, L, 3, 3) backbone frames built from CA/C/N coordinates.
            trans_1:   (B, L, 3) CA coordinates.
            angles_1:  (B, L, n_torsion) torsion angles from the dataset.
            seqs_1:    (B, L) integer residue types.
            node_embed, edge_embed: encoder features; when a modality is being
                sampled, only context residues expose it to the embedders.
        """
        rotmats_1 = construct_3d_basis(batch['pos_heavyatom'][:, :, BBHeavyAtom.CA],
                                       batch['pos_heavyatom'][:, :, BBHeavyAtom.C],
                                       batch['pos_heavyatom'][:, :, BBHeavyAtom.N])
        trans_1 = batch['pos_heavyatom'][:, :, BBHeavyAtom.CA]
        seqs_1 = batch['aa']
        # ignore psi
        # batch['torsion_angle'] = batch['torsion_angle'][:,:,1:]
        # batch['torsion_angle_mask'] = batch['torsion_angle_mask'][:,:,1:]
        angles_1 = batch['torsion_angle']
        # Context = residues that exist (CA present) and are NOT being generated.
        context_mask = torch.logical_and(batch['mask_heavyatom'][:, :, BBHeavyAtom.CA], ~batch['generate_mask'])
        structure_mask = context_mask if self.sample_structure else None
        sequence_mask = context_mask if self.sample_sequence else None
        node_embed = self.node_embedder(batch['aa'], batch['res_nb'], batch['chain_nb'], batch['pos_heavyatom'],
                                        batch['mask_heavyatom'], structure_mask=structure_mask, sequence_mask=sequence_mask)
        edge_embed = self.edge_embedder(batch['aa'], batch['res_nb'], batch['chain_nb'], batch['pos_heavyatom'],
                                        batch['mask_heavyatom'], structure_mask=structure_mask, sequence_mask=sequence_mask)
        return rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed
def zero_center_part(self,pos,gen_mask,res_mask):
"""
move pos by center of gen_mask
pos: (B,N,3)
gen_mask, res_mask: (B,N)
"""
center = torch.sum(pos * gen_mask[...,None], dim=1) / (torch.sum(gen_mask,dim=-1,keepdim=True) + 1e-8) # (B,N,3)*(B,N,1)->(B,3)/(B,1)->(B,3)
center = center.unsqueeze(1) # (B,1,3)
# center = 0. it seems not center didnt influence the result, but its good for training stabilty
pos = pos - center
pos = pos * res_mask[...,None]
return pos,center
def seq_to_simplex(self,seqs):
return clampped_one_hot(seqs, self.K).float() * self.k * 2 - self.k # (B,L,K)
    def forward(self, batch):
        """Compute flow-matching training losses for one batch.

        Samples a time t per example, interpolates ground truth toward noise on
        each enabled modality (translations, rotations, torsions, sequence),
        denoises with the GA encoder, and scores the predictions.

        Returns:
            Dict of scalar losses: trans_loss, rot_loss, bb_atom_loss,
            seqs_loss, angle_loss, torsion_loss.
        """
        num_batch, num_res = batch['aa'].shape
        gen_mask, res_mask, angle_mask = batch['generate_mask'].long(), batch['res_mask'].long(), batch['torsion_angle_mask'].long()
        # encode
        rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed = self.encode(batch)  # no generate mask
        # prepare for denoise
        trans_1_c, _ = self.zero_center_part(trans_1, gen_mask, res_mask)
        # NOTE(review): the centered result above is discarded on the next line;
        # coordinates are presumably pre-centered at dataset construction — confirm.
        trans_1_c = trans_1  # already centered when constructing dataset
        seqs_1_simplex = self.seq_to_simplex(seqs_1)
        seqs_1_prob = F.softmax(seqs_1_simplex, dim=-1)
        with torch.no_grad():
            # Draw t in (clip, 1-clip) to avoid the degenerate endpoints.
            t = torch.rand((num_batch, 1), device=batch['aa'].device)
            t = t * (1 - 2 * self._interpolant_cfg.t_normalization_clip) + self._interpolant_cfg.t_normalization_clip  # avoid 0
            if self.sample_structure:
                # corrupt trans: gaussian prior, linear interpolation toward gt
                trans_0 = torch.randn((num_batch, num_res, 3), device=batch['aa'].device) * self._interpolant_cfg.trans.sigma  # scale with sigma?
                trans_0_c, _ = self.zero_center_part(trans_0, gen_mask, res_mask)
                trans_t = (1 - t[..., None]) * trans_0_c + t[..., None] * trans_1_c
                trans_t_c = torch.where(batch['generate_mask'][..., None], trans_t, trans_1_c)
                # corrupt rotmats: uniform SO(3) prior, geodesic interpolation
                rotmats_0 = uniform_so3(num_batch, num_res, device=batch['aa'].device)
                rotmats_t = so3_utils.geodesic_t(t[..., None], rotmats_1, rotmats_0)
                rotmats_t = torch.where(batch['generate_mask'][..., None, None], rotmats_t, rotmats_1)
                # corrupt angles: uniform prior on the torus, geodesic interpolation
                angles_0 = torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype)  # (B,L,5)
                angles_t = torus.tor_geodesic_t(t[..., None], angles_1, angles_0)
                angles_t = torch.where(batch['generate_mask'][..., None], angles_t, angles_1)
            else:
                trans_t_c = trans_1_c.detach().clone()
                rotmats_t = rotmats_1.detach().clone()
                angles_t = angles_1.detach().clone()
            if self.sample_sequence:
                # corrupt seqs: gaussian noise in simplex space, then sample discrete types
                seqs_0_simplex = self.k * torch.randn_like(seqs_1_simplex)  # (B,L,K)
                seqs_0_prob = F.softmax(seqs_0_simplex, dim=-1)  # (B,L,K)
                seqs_t_simplex = ((1 - t[..., None]) * seqs_0_simplex) + (t[..., None] * seqs_1_simplex)  # (B,L,K)
                seqs_t_simplex = torch.where(batch['generate_mask'][..., None], seqs_t_simplex, seqs_1_simplex)
                seqs_t_prob = F.softmax(seqs_t_simplex, dim=-1)  # (B,L,K)
                seqs_t = sample_from(seqs_t_prob)  # (B,L)
                seqs_t = torch.where(batch['generate_mask'], seqs_t, seqs_1)
            else:
                seqs_t = seqs_1.detach().clone()
                seqs_t_simplex = seqs_1_simplex.detach().clone()
                seqs_t_prob = seqs_1_prob.detach().clone()
        # denoise: predict the clean (t=1) state from the noised state at time t
        pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t, trans_t_c, angles_t, seqs_t, node_embed, edge_embed, gen_mask, res_mask)
        pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob, dim=-1))
        pred_seqs_1 = torch.where(batch['generate_mask'], pred_seqs_1, torch.clamp(seqs_1, 0, 19))
        pred_trans_1_c, _ = self.zero_center_part(pred_trans_1, gen_mask, res_mask)
        # NOTE(review): centering discarded again (same pattern as above).
        pred_trans_1_c = pred_trans_1  # implicitly enforce zero center in gen_mask, in this way, we dont need to move receptor when sampling
        # Loss scaling 1/(1-t) with t capped at the clip value (Yim et al. trick).
        norm_scale = 1 / (1 - torch.min(t[..., None], torch.tensor(self._interpolant_cfg.t_normalization_clip)))  # yim etal.trick, 1/1-t
        # trans vf loss: masked MSE on predicted clean translations
        trans_loss = torch.sum((pred_trans_1_c - trans_1_c)**2 * gen_mask[..., None], dim=(-1, -2)) / (torch.sum(gen_mask, dim=-1) + 1e-8)  # (B,)
        trans_loss = torch.mean(trans_loss)
        # rots vf loss: compare rotation vector fields measured from the noised frame
        gt_rot_vf = so3_utils.calc_rot_vf(rotmats_t, rotmats_1)
        pred_rot_vf = so3_utils.calc_rot_vf(rotmats_t, pred_rotmats_1)
        rot_loss = torch.sum(((gt_rot_vf - pred_rot_vf) * norm_scale)**2 * gen_mask[..., None], dim=(-1, -2)) / (torch.sum(gen_mask, dim=-1) + 1e-8)  # (B,)
        rot_loss = torch.mean(rot_loss)
        # bb aux loss: MSE over the first 3 atom37 slots (presumably N, CA, C — confirm)
        gt_bb_atoms = all_atom.to_atom37(trans_1_c, rotmats_1)[:, :, :3]
        pred_bb_atoms = all_atom.to_atom37(pred_trans_1_c, pred_rotmats_1)[:, :, :3]
        # gt_bb_atoms = all_atom.to_bb_atoms(trans_1_c, rotmats_1, angles_1[:,:,0]) # N,CA,C,O,CB
        # pred_bb_atoms = all_atom.to_bb_atoms(pred_trans_1_c, pred_rotmats_1, pred_angles_1[:,:,0])
        bb_atom_loss = torch.sum(
            (gt_bb_atoms - pred_bb_atoms) ** 2 * gen_mask[..., None, None],
            dim=(-1, -2, -3)
        ) / (torch.sum(gen_mask, dim=-1) + 1e-8)  # (B,)
        bb_atom_loss = torch.mean(bb_atom_loss)
        # bb_atom_loss = torch.mean(torch.where(t[:,0]>=0.75,bb_atom_loss,torch.zeros_like(bb_atom_loss))) # penalty for near gt point
        # seqs vf loss: cross-entropy of predicted logits vs (clamped) gt residue types
        seqs_loss = F.cross_entropy(pred_seqs_1_prob.view(-1, pred_seqs_1_prob.shape[-1]), torch.clamp(seqs_1, 0, 19).view(-1), reduction='none').view(pred_seqs_1_prob.shape[:-1])  # (N,L), not softmax
        seqs_loss = torch.sum(seqs_loss * gen_mask, dim=-1) / (torch.sum(gen_mask, dim=-1) + 1e-8)
        seqs_loss = torch.mean(seqs_loss)
        # we should not use angle mask, as you dont know aa type when generating
        # angle_mask_loss = torch.cat([angle_mask,angle_mask],dim=-1) # (B,L,10)
        # angle vf loss: torsion validity derived from the *predicted* residue types
        angle_mask_loss = torsions_mask.to(batch['aa'].device)
        angle_mask_loss = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch, num_res, -1)  # (B,L,5)
        angle_mask_loss = torch.cat([angle_mask_loss, angle_mask_loss], dim=-1)  # (B,L,10) — duplicated for (sin, cos) halves
        angle_mask_loss = torch.logical_and(batch['generate_mask'][..., None].bool(), angle_mask_loss)
        # Compare angle vector fields in (sin, cos) embedding to respect 2*pi wrap-around.
        gt_angle_vf = torus.tor_logmap(angles_t, angles_1)
        gt_angle_vf_vec = torch.cat([torch.sin(gt_angle_vf), torch.cos(gt_angle_vf)], dim=-1)
        pred_angle_vf = torus.tor_logmap(angles_t, pred_angles_1)
        pred_angle_vf_vec = torch.cat([torch.sin(pred_angle_vf), torch.cos(pred_angle_vf)], dim=-1)
        # angle_loss = torch.sum(((gt_angle_vf_vec - pred_angle_vf_vec) * norm_scale)**2*gen_mask[...,None],dim=(-1,-2)) / ((torch.sum(gen_mask,dim=-1)) + 1e-8) # (B,)
        angle_loss = torch.sum(((gt_angle_vf_vec - pred_angle_vf_vec) * norm_scale)**2 * angle_mask_loss, dim=(-1, -2)) / (torch.sum(angle_mask_loss, dim=(-1, -2)) + 1e-8)  # (B,)
        angle_loss = torch.mean(angle_loss)
        # angle aux loss: direct (sin, cos) reconstruction of the clean angles
        angles_1_vec = torch.cat([torch.sin(angles_1), torch.cos(angles_1)], dim=-1)
        pred_angles_1_vec = torch.cat([torch.sin(pred_angles_1), torch.cos(pred_angles_1)], dim=-1)
        # torsion_loss = torch.sum((pred_angles_1_vec - angles_1_vec)**2*gen_mask[...,None],dim=(-1,-2)) / (torch.sum(gen_mask,dim=-1) + 1e-8) # (B,)
        torsion_loss = torch.sum((pred_angles_1_vec - angles_1_vec)**2 * angle_mask_loss, dim=(-1, -2)) / (torch.sum(angle_mask_loss, dim=(-1, -2)) + 1e-8)  # (B,)
        torsion_loss = torch.mean(torsion_loss)
        return {
            "trans_loss": trans_loss,
            'rot_loss': rot_loss,
            'bb_atom_loss': bb_atom_loss,
            'seqs_loss': seqs_loss,
            'angle_loss': angle_loss,
            'torsion_loss': torsion_loss,
        }
    @torch.no_grad()
    def sample(self, batch, num_steps=100, sample_bb=True, sample_ang=True, sample_seq=True):
        """Euler-integrate the learned flow from noise (t ~ 0.01) to data (t = 1).

        Args:
            batch: collated feature dict (see encode()).
            num_steps: number of time points on the linspace [1e-2, 1].
            sample_bb / sample_ang / sample_seq: which modalities to generate;
                disabled ones are pinned to ground truth throughout.

        Returns:
            clean_traj: list (one entry per step, plus a final refinement) of
            dicts of CPU tensors holding the model's current clean-state
            prediction alongside the ground truth for reference.
        """
        num_batch, num_res = batch['aa'].shape
        gen_mask, res_mask = batch['generate_mask'], batch['res_mask']
        K = self._interpolant_cfg.seqs.num_classes
        k = self._interpolant_cfg.seqs.simplex_value
        # Per-residue-type torsion validity table, indexed by sampled sequences below.
        angle_mask_loss = torsions_mask.to(batch['aa'].device)
        # encode
        rotmats_1, trans_1, angles_1, seqs_1, node_embed, edge_embed = self.encode(batch)
        # trans_1_c,center = self.zero_center_part(trans_1,gen_mask,res_mask)
        trans_1_c = trans_1  # presumably pre-centered at dataset construction (see forward)
        seqs_1_simplex = self.seq_to_simplex(seqs_1)
        seqs_1_prob = F.softmax(seqs_1_simplex, dim=-1)
        # # # only sample bb, angle and seq with noise
        # angles_1 = torch.where(batch['generate_mask'][...,None],angles_1,torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype))
        # seqs_1 = torch.where(batch['generate_mask'],seqs_1,torch.randint_like(seqs_1,0,20))
        # seqs_1_simplex = self.seq_to_simplex(seqs_1)
        # seqs_1_prob = F.softmax(seqs_1_simplex,dim=-1)
        # initial noise (generated region only; context kept at ground truth)
        if sample_bb:
            rotmats_0 = uniform_so3(num_batch, num_res, device=batch['aa'].device)
            rotmats_0 = torch.where(batch['generate_mask'][..., None, None], rotmats_0, rotmats_1)
            trans_0 = torch.randn((num_batch, num_res, 3), device=batch['aa'].device)  # scale with sigma?
            # move center and receptor
            trans_0_c, center = self.zero_center_part(trans_0, gen_mask, res_mask)
            trans_0_c = torch.where(batch['generate_mask'][..., None], trans_0_c, trans_1_c)
        else:
            rotmats_0 = rotmats_1.detach().clone()
            trans_0_c = trans_1_c.detach().clone()
        if sample_ang:
            # angle noise: uniform on the torus
            angles_0 = torus.tor_random_uniform(angles_1.shape, device=batch['aa'].device, dtype=angles_1.dtype)  # (B,L,5)
            angles_0 = torch.where(batch['generate_mask'][..., None], angles_0, angles_1)
        else:
            angles_0 = angles_1.detach().clone()
        if sample_seq:
            # sequence noise: gaussian simplex logits, then a discrete sample
            seqs_0_simplex = k * torch.randn((num_batch, num_res, K), device=batch['aa'].device)
            seqs_0_prob = F.softmax(seqs_0_simplex, dim=-1)
            seqs_0 = sample_from(seqs_0_prob)
            seqs_0 = torch.where(batch['generate_mask'], seqs_0, seqs_1)
            seqs_0_simplex = torch.where(batch['generate_mask'][..., None], seqs_0_simplex, seqs_1_simplex)
        else:
            seqs_0 = seqs_1.detach().clone()
            seqs_0_prob = seqs_1_prob.detach().clone()
            seqs_0_simplex = seqs_1_simplex.detach().clone()
        # Set-up time
        ts = torch.linspace(1.e-2, 1.0, num_steps)
        t_1 = ts[0]
        # prot_traj = [{'rotmats':rotmats_0,'trans':trans_0_c,'seqs':seqs_0,'seqs_simplex':seqs_0_simplex,'rotmats_1':rotmats_1,'trans_1':trans_1-center,'seqs_1':seqs_1}]
        clean_traj = []
        rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, seqs_t_1_simplex = rotmats_0, trans_0_c, angles_0, seqs_0, seqs_0_simplex
        # denoise loop: at each step predict the clean state, record it, then Euler-step
        for t_2 in ts[1:]:
            t = torch.ones((num_batch, 1), device=batch['aa'].device) * t_1
            # rots
            pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, node_embed, edge_embed, batch['generate_mask'].long(), batch['res_mask'].long())
            pred_rotmats_1 = torch.where(batch['generate_mask'][..., None, None], pred_rotmats_1, rotmats_1)
            # trans, move center
            # pred_trans_1_c,center = self.zero_center_part(pred_trans_1,gen_mask,res_mask)
            pred_trans_1_c = torch.where(batch['generate_mask'][..., None], pred_trans_1, trans_1_c)  # move receptor also
            # angles
            pred_angles_1 = torch.where(batch['generate_mask'][..., None], pred_angles_1, angles_1)
            # seqs
            pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob, dim=-1))
            pred_seqs_1 = torch.where(batch['generate_mask'], pred_seqs_1, seqs_1)
            pred_seqs_1_simplex = self.seq_to_simplex(pred_seqs_1)
            # seq-angle: zero out torsions that are invalid for the sampled residue type
            torsion_mask = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch, num_res, -1)  # (B,L,5)
            pred_angles_1 = torch.where(torsion_mask.bool(), pred_angles_1, torch.zeros_like(pred_angles_1))
            if not sample_bb:
                pred_trans_1_c = trans_1_c.detach().clone()
                # _,center = self.zero_center_part(trans_1,gen_mask,res_mask)
                pred_rotmats_1 = rotmats_1.detach().clone()
            if not sample_ang:
                pred_angles_1 = angles_1.detach().clone()
            if not sample_seq:
                pred_seqs_1 = seqs_1.detach().clone()
                pred_seqs_1_simplex = seqs_1_simplex.detach().clone()
            # Record the current clean-state prediction (moved to CPU).
            clean_traj.append({'rotmats': pred_rotmats_1.cpu(), 'trans': pred_trans_1_c.cpu(), 'angles': pred_angles_1.cpu(), 'seqs': pred_seqs_1.cpu(), 'seqs_simplex': pred_seqs_1_simplex.cpu(),
                               'rotmats_1': rotmats_1.cpu(), 'trans_1': trans_1_c.cpu(), 'angles_1': angles_1.cpu(), 'seqs_1': seqs_1.cpu()})
            # reverse step, also only for gen mask region
            d_t = (t_2 - t_1) * torch.ones((num_batch, 1), device=batch['aa'].device)
            # Euler step
            trans_t_2 = trans_t_1_c + (pred_trans_1_c - trans_0_c) * d_t[..., None]
            # trans_t_2_c,center = self.zero_center_part(trans_t_2,gen_mask,res_mask)
            trans_t_2_c = torch.where(batch['generate_mask'][..., None], trans_t_2, trans_1_c)  # move receptor also
            # rotmats_t_2 = so3_utils.geodesic_t(d_t[...,None] / (1-t[...,None]), pred_rotmats_1, rotmats_t_1)
            # NOTE(review): fixed exp-rate 10 is used instead of the 1/(1-t) schedule
            # commented out above — confirm this matches the trained schedule.
            rotmats_t_2 = so3_utils.geodesic_t(d_t[..., None] * 10, pred_rotmats_1, rotmats_t_1)
            rotmats_t_2 = torch.where(batch['generate_mask'][..., None, None], rotmats_t_2, rotmats_1)
            # angles
            angles_t_2 = torus.tor_geodesic_t(d_t[..., None], pred_angles_1, angles_t_1)
            angles_t_2 = torch.where(batch['generate_mask'][..., None], angles_t_2, angles_1)
            # seqs
            seqs_t_2_simplex = seqs_t_1_simplex + (pred_seqs_1_simplex - seqs_0_simplex) * d_t[..., None]
            seqs_t_2 = sample_from(F.softmax(seqs_t_2_simplex, dim=-1))
            seqs_t_2 = torch.where(batch['generate_mask'], seqs_t_2, seqs_1)
            # seq-angle
            torsion_mask = angle_mask_loss[seqs_t_2.reshape(-1)].reshape(num_batch, num_res, -1)  # (B,L,5)
            angles_t_2 = torch.where(torsion_mask.bool(), angles_t_2, torch.zeros_like(angles_t_2))
            if not sample_bb:
                trans_t_2_c = trans_1_c.detach().clone()
                rotmats_t_2 = rotmats_1.detach().clone()
            if not sample_ang:
                angles_t_2 = angles_1.detach().clone()
            if not sample_seq:
                seqs_t_2 = seqs_1.detach().clone()
            rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, seqs_t_1_simplex = rotmats_t_2, trans_t_2_c, angles_t_2, seqs_t_2, seqs_t_2_simplex
            t_1 = t_2
        # final step: one more clean-state prediction at t = 1
        t_1 = ts[-1]
        t = torch.ones((num_batch, 1), device=batch['aa'].device) * t_1
        pred_rotmats_1, pred_trans_1, pred_angles_1, pred_seqs_1_prob = self.ga_encoder(t, rotmats_t_1, trans_t_1_c, angles_t_1, seqs_t_1, node_embed, edge_embed, batch['generate_mask'].long(), batch['res_mask'].long())
        pred_rotmats_1 = torch.where(batch['generate_mask'][..., None, None], pred_rotmats_1, rotmats_1)
        # move center
        # pred_trans_1_c,center = self.zero_center_part(pred_trans_1,gen_mask,res_mask)
        pred_trans_1_c = torch.where(batch['generate_mask'][..., None], pred_trans_1, trans_1_c)  # move receptor also
        # angles
        pred_angles_1 = torch.where(batch['generate_mask'][..., None], pred_angles_1, angles_1)
        # seqs
        pred_seqs_1 = sample_from(F.softmax(pred_seqs_1_prob, dim=-1))
        pred_seqs_1 = torch.where(batch['generate_mask'], pred_seqs_1, seqs_1)
        pred_seqs_1_simplex = self.seq_to_simplex(pred_seqs_1)
        # seq-angle
        torsion_mask = angle_mask_loss[pred_seqs_1.reshape(-1)].reshape(num_batch, num_res, -1)  # (B,L,5)
        pred_angles_1 = torch.where(torsion_mask.bool(), pred_angles_1, torch.zeros_like(pred_angles_1))
        if not sample_bb:
            pred_trans_1_c = trans_1_c.detach().clone()
            # _,center = self.zero_center_part(trans_1,gen_mask,res_mask)
            pred_rotmats_1 = rotmats_1.detach().clone()
        if not sample_ang:
            pred_angles_1 = angles_1.detach().clone()
        if not sample_seq:
            pred_seqs_1 = seqs_1.detach().clone()
            pred_seqs_1_simplex = seqs_1_simplex.detach().clone()
        clean_traj.append({'rotmats': pred_rotmats_1.cpu(), 'trans': pred_trans_1_c.cpu(), 'angles': pred_angles_1.cpu(), 'seqs': pred_seqs_1.cpu(), 'seqs_simplex': pred_seqs_1_simplex.cpu(),
                           'rotmats_1': rotmats_1.cpu(), 'trans_1': trans_1_c.cpu(), 'angles_1': angles_1.cpu(), 'seqs_1': seqs_1.cpu()})
        return clean_traj
# if __name__ == '__main__':
# prefix_dir = './pepflowww'
# # config,cfg_name = load_config("../configs/angle/learn_sc.yaml")
# config,cfg_name = load_config(os.path.join(prefix_dir,"configs/angle/learn_sc.yaml"))
# # print(config)
# device = 'cuda:0'
# dataset = PepDataset(structure_dir = config.dataset.val.structure_dir, dataset_dir = config.dataset.val.dataset_dir,
# name = config.dataset.val.name, transform=None, reset=config.dataset.val.reset)
# dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=PaddingCollate(eight=False), num_workers=4, pin_memory=True)
# ckpt = torch.load("./checkpoints/600000.pt", map_location=device)
# seed_all(114514)
# model = FlowModel(config.model).to(device)
# model.load_state_dict(process_dic(ckpt['model']))
# model.eval()
# # print(model)
# # print(dataset[0]['chain_id'])
# # print(dataset[0]['id'])
# # print(dataset[0]['resseq'])
# # print(dataset[0]['res_nb'])
# # print(dataset[0]['icode'])
# dic = {'id':[],'len':[],'tran':[],'aar':[],'rot':[],'trans_loss':[],'rot_loss':[]}
# # for batch in tqdm(dataloader):
# # batch = recursive_to(batch,device)
# for i in tqdm(range(len(dataset))):
# item = dataset[i]
# data_list = [deepcopy(item) for _ in range(16)]
# batch = recursive_to(collate_fn(data_list),device)
# loss_dic = model(batch)
# # traj_1 = model.sample(batch,num_steps=50,sample_bb=False,sample_ang=True,sample_seq=False)
# traj_1 = model.sample(batch,num_steps=50,sample_bb=True,sample_ang=True,sample_seq=True)
# ca_dist = torch.sqrt(torch.sum((traj_1[-1]['trans']-traj_1[-1]['trans_1'])**2*batch['generate_mask'][...,None].cpu().long()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
# rot_dist = torch.sqrt(torch.sum((traj_1[-1]['rotmats']-traj_1[-1]['rotmats_1'])**2*batch['generate_mask'][...,None,None].long().cpu()) / (torch.sum(batch['generate_mask']) + 1e-8).cpu()) # rmsd
# aar = torch.sum((traj_1[-1]['seqs']==traj_1[-1]['seqs_1']) * batch['generate_mask'].long().cpu()) / (torch.sum(batch['generate_mask']).cpu() + 1e-8)
# print(loss_dic)
# print(f'tran:{ca_dist},rot:{rot_dist},aar:{aar},len:{batch["generate_mask"].sum().item()}')
# # free
# torch.cuda.empty_cache()
# gc.collect()
# # dic['tran'].append(ca_dist.item())
# # dic['rot'].append(rot_dist.item())
# dic['aar'].append(aar.item())
# dic['trans_loss'].append(loss_dic['trans_loss'].item())
# dic['rot_loss'].append(loss_dic['rot_loss'].item())
# dic['id'].append(batch['id'][0])
# dic['len'].append(batch['generate_mask'].sum().item())
# # # break
# # traj_1[-1]['batch'] = batch
# # torch.save(traj_1[-1],f'/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Models_new/Pack_new/outputs/{batch["id"][0]}.pt')
# # print(dic)
# # dic = pd.DataFrame(dic)
# # dic.to_csv(f'/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/Data/Models_new/Pack/outputs.csv',index=None)
# print(np.mean(dic['aar']))
# print(np.mean(dic['trans_loss']))
# if __name__ == '__main__':
# config,cfg_name = load_config("./configs/angle/learn_angle.yaml")
# seed_all(114514)
# device = 'cpu'
# dataset = PepDataset(structure_dir = config.dataset.train.structure_dir, dataset_dir = config.dataset.train.dataset_dir,
# name = config.dataset.train.name, transform=None, reset=config.dataset.train.reset)
# dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=PaddingCollate(), num_workers=4, pin_memory=True)
# model = FlowModel(config.model).to(device)
# optimizer = torch.optim.Adam(model.parameters(),lr=1.e-4)
# # ckpt = torch.load('./checkpoints/90000.pt',map_location=device)
# # model.load_state_dict(process_dic(ckpt['model']))
# # optimizer.load_state_dict(ckpt['optimizer'])
# # torch.autograd.set_detect_anomaly(True)
# for i,batch in tqdm(enumerate(dataloader)):
# batch = recursive_to(batch,device)
# loss_dict = model(batch)
# loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
# # if torch.isnan(loss):
# # print(i)
# # print(batch['id'])
# loss.backward()
# orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)
# print(f'{loss_dict},{loss},{orig_grad_norm}')
# optimizer.step()
# optimizer.zero_grad()