# PepFlow / models_con/sample.py
# Irwiny123's picture
# Add initial PepFlow model code (添加PepFlow模型初始代码)
# ef423c5
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import copy
import math
from tqdm.auto import tqdm
import functools
import os
import argparse
import pandas as pd
from copy import deepcopy
from models_con.pep_dataloader import PepDataset
from pepflow.utils.train import recursive_to
from pepflow.modules.common.geometry import reconstruct_backbone, reconstruct_backbone_partially, align, batch_align
from pepflow.modules.protein.writers import save_pdb
from pepflow.utils.data import PaddingCollate
from models_con.utils import process_dic
from models_con.flow_model import FlowModel
from models_con.torsion import full_atom_reconstruction, get_heavyatom_mask
collate_fn = PaddingCollate(eight=False)
import argparse
def item_to_batch(item, nums=32):
    """Replicate a single item ``nums`` times and collate the copies into one batch.

    Every replica is an independent ``deepcopy`` of ``item``; the list of
    replicas is handed to the module-level ``collate_fn`` and its result is
    returned unchanged.

    Args:
        item: A single (un-batched) data item.
        nums: Number of copies to place in the batch.

    Returns:
        Whatever ``collate_fn`` produces for the list of ``nums`` copies.
    """
    replicas = []
    for _ in range(nums):
        replicas.append(deepcopy(item))
    return collate_fn(replicas)
def sample_for_data_bb(data, model, device, save_root, num_steps=200, sample_structure=True, sample_sequence=True, nums=8):
    """Sample ``nums`` backbone designs for one item and write them as PDB files.

    The item is replicated into a batch of size ``nums``, ``model.sample`` is
    run on it, and the last element of the returned trajectory is converted to
    backbone coordinates. Residues selected by ``generate_mask`` receive the
    sampled backbone (4 atoms); all other residues keep the heavy-atom
    coordinates and mask from the input batch.

    Output files are written under ``<save_root>/<data['id']>/``:
    ``<id>_<i>.pdb`` for every sample ``i`` and ``<id>_gt.pdb`` for the input
    item itself.

    Args:
        data: A single (un-batched) item. This function reads the keys
            ``id``, ``chain_nb``, ``chain_id``, ``resseq`` and ``icode`` and
            passes the whole item to ``save_pdb`` for the ground-truth file.
        model: Object exposing ``sample(batch, num_steps=..., sample_structure=...,
            sample_sequence=...)`` that returns an indexable trajectory whose
            last element has the keys ``rotmats``, ``trans`` and ``seqs``.
        device: Target device passed to ``recursive_to``.
        save_root: Root directory; a sub-directory named after ``data['id']``
            is created if it does not exist.
        num_steps: Forwarded to ``model.sample``.
        sample_structure: Forwarded to ``model.sample``.
        sample_sequence: Forwarded to ``model.sample``.
        nums: Number of samples to draw (the batch size).
    """
    out_dir = os.path.join(save_root, data["id"])
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    batch = recursive_to(item_to_batch(data, nums=nums), device=device)
    traj = model.sample(batch, num_steps=num_steps, sample_structure=sample_structure, sample_sequence=sample_sequence)
    final = recursive_to(traj[-1], device=device)

    # Backbone atoms only; per the original author's note the result is (nums, L, 4, 3).
    pos_bb = reconstruct_backbone(
        R=final['rotmats'], t=final['trans'], aa=final['seqs'],
        chain_nb=batch['chain_nb'], res_nb=batch['res_nb'], mask=batch['res_mask'],
    )
    # Zero-pad the atom axis (second to last) by 11 slots: 4 -> 15 atoms.
    # NOTE(review): assumes batch['pos_heavyatom'] uses 15 atom slots - confirm.
    pos_ha = F.pad(pos_bb, pad=(0, 0, 0, 15 - 4), value=0.)

    gen_mask = batch['generate_mask']
    # Generated residues take the sampled backbone, the rest keep the input atoms.
    pos_new = torch.where(gen_mask[:, :, None, None], pos_ha, batch['pos_heavyatom'])
    # Only the first four atom slots are marked present for generated residues.
    mask_bb_atoms = torch.zeros_like(batch['mask_heavyatom'])
    mask_bb_atoms[:, :, :4] = True
    mask_new = torch.where(gen_mask[:, :, None], mask_bb_atoms, batch['mask_heavyatom'])
    aa_new = final['seqs']

    # Per-residue metadata is taken from the un-batched item. The previous
    # version also built several unused locals here, one of which indexed
    # data['pos_heavyatom'] with the sample index and could raise IndexError.
    for i in range(nums):
        data_saved = {
            'chain_nb': data['chain_nb'], 'chain_id': data['chain_id'],
            'resseq': data['resseq'], 'icode': data['icode'],
            'aa': aa_new[i].cpu(),
            'mask_heavyatom': mask_new[i].cpu(),
            'pos_heavyatom': pos_new[i].cpu(),
        }
        save_pdb(data_saved, path=os.path.join(out_dir, f'{data["id"]}_{i}.pdb'))

    # Ground truth: the input item, written as-is.
    save_pdb(data, path=os.path.join(out_dir, f'{data["id"]}_gt.pdb'))
def save_samples_bb(samples,save_dir):
    """Write backbone-level samples and their ground truth as PDB files.

    For every entry ``i`` of the batch a file ``sample_<i>.pdb`` is written to
    ``save_dir``; a single ``gt.pdb`` is written from batch entry 0. Residues
    selected by ``generate_mask`` receive the reconstructed backbone (4 atoms),
    all other residues keep the heavy-atom coordinates and mask of the batch.

    Args:
        samples: Mapping with the keys ``batch``, ``rotmats``, ``trans`` and
            ``seqs``. ``samples['batch']`` is passed through ``recursive_to``
            with ``'cpu'``; the other entries are used as given.
            NOTE(review): presumably they must already be on the CPU so that
            they can be combined with the batch - confirm against the caller.
        save_dir: Existing directory that receives the PDB files.
    """
    # Metadata shared by every written structure.
    batch = recursive_to(samples['batch'],'cpu')
    # Transpose batch['chain_id'] and keep the first row (original author's
    # note: works around how the collate function stores chain ids).
    chain_id = [list(item) for item in zip(*batch['chain_id'])][0]
    # Insertion codes are replaced by blanks (original note: the batched icode
    # has the same collate problem as chain_id).
    icode = [' ' for _ in range(len(chain_id))]
    nums = len(batch['id'])
    meta = {
        'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,
        'resseq':batch['resseq'][0],'icode':icode,
    }

    # Backbone atoms only; per the original author's note the result is (N, L, 4, 3).
    # The original also noted that aa=batch['aa'] could be used for backbone-only models.
    pos_bb = reconstruct_backbone(
        R=samples['rotmats'],t=samples['trans'],aa=samples['seqs'],
        chain_nb=batch['chain_nb'],res_nb=batch['res_nb'],mask=batch['res_mask'],
    )
    # Zero-pad the atom axis (second to last) by 11 slots: 4 -> 15 atoms.
    pos_ha = F.pad(pos_bb, pad=(0,0,0,15-4), value=0.)
    gen_mask = batch['generate_mask']
    pos_new = torch.where(gen_mask[:,:,None,None],pos_ha,batch['pos_heavyatom'])
    # Only the first four atom slots are marked present for generated residues.
    mask_bb_atoms = torch.zeros_like(batch['mask_heavyatom'])
    mask_bb_atoms[:,:,:4] = True
    mask_new = torch.where(gen_mask[:,:,None],mask_bb_atoms,batch['mask_heavyatom'])
    aa_new = samples['seqs']

    for i in range(nums):
        data_saved = {
            **meta,
            'aa':aa_new[i], 'mask_heavyatom':mask_new[i], 'pos_heavyatom':pos_new[i],
        }
        save_pdb(data_saved,path=os.path.join(save_dir,f'sample_{i}.pdb'))

    # Ground truth is taken from batch entry 0.
    data_saved = {
        **meta,
        'aa':batch['aa'][0], 'mask_heavyatom':batch['mask_heavyatom'][0], 'pos_heavyatom':batch['pos_heavyatom'][0],
    }
    save_pdb(data_saved,path=os.path.join(save_dir,'gt.pdb'))
def save_samples_sc(samples,save_dir):
    """Write full-atom (side-chain level) samples and their ground truth as PDB files.

    For every entry ``i`` of the batch a file ``sample_<i>.pdb`` is written to
    ``save_dir``; a single ``gt.pdb`` is written from batch entry 0. Residues
    selected by ``generate_mask`` receive the coordinates produced by
    ``full_atom_reconstruction``; all other residues keep the heavy-atom
    coordinates of the batch.

    Args:
        samples: Mapping with the keys ``batch``, ``rotmats``, ``trans``,
            ``angles`` and ``seqs``. ``samples['batch']`` is passed through
            ``recursive_to`` with ``'cpu'``; the other entries are used as given.
            NOTE(review): presumably they must already be on the CPU so that
            they can be combined with the batch - confirm against the caller.
        save_dir: Existing directory that receives the PDB files.
    """
    # Metadata shared by every written structure.
    batch = recursive_to(samples['batch'],'cpu')
    # Transpose batch['chain_id'] and keep the first row (original author's
    # note: works around how the collate function stores chain ids).
    chain_id = [list(item) for item in zip(*batch['chain_id'])][0]
    # Insertion codes are replaced by blanks (original note: the batched icode
    # has the same collate problem as chain_id).
    icode = [' ' for _ in range(len(chain_id))]
    nums = len(batch['id'])
    meta = {
        'chain_nb':batch['chain_nb'][0],'chain_id':chain_id,
        'resseq':batch['resseq'][0],'icode':icode,
    }

    # Per the original author's note the reconstruction yields 14 atoms per
    # residue (N, L, 14, 3) rather than 15, because OXT is left out.
    pos_ha,_,_ = full_atom_reconstruction(R_bb=samples['rotmats'],t_bb=samples['trans'],angles=samples['angles'],aa=samples['seqs'])
    # Zero-pad the atom axis (second to last) by one slot: 14 -> 15 atoms.
    pos_ha = F.pad(pos_ha, pad=(0,0,0,15-14), value=0.)
    pos_new = torch.where(batch['generate_mask'][:,:,None,None],pos_ha,batch['pos_heavyatom'])
    # NOTE(review): the mask is derived from the sequence for every residue,
    # whereas save_samples_bb keeps batch['mask_heavyatom'] for residues that
    # were not generated - confirm this difference is intended.
    mask_new = get_heavyatom_mask(samples['seqs'])
    aa_new = samples['seqs']

    for i in range(nums):
        data_saved = {
            **meta,
            'aa':aa_new[i], 'mask_heavyatom':mask_new[i], 'pos_heavyatom':pos_new[i],
        }
        save_pdb(data_saved,path=os.path.join(save_dir,f'sample_{i}.pdb'))

    # Ground truth is taken from batch entry 0.
    data_saved = {
        **meta,
        'aa':batch['aa'][0], 'mask_heavyatom':batch['mask_heavyatom'][0], 'pos_heavyatom':batch['pos_heavyatom'][0],
    }
    save_pdb(data_saved,path=os.path.join(save_dir,'gt.pdb'))
if __name__ == '__main__':
    # Convert every saved sample <SAMPLEDIR>/outputs/<name>.pt into PDB files
    # under <SAMPLEDIR>/pdbs/<name>/ using the full-atom writer.
    # save_samples_bb takes the same arguments and can be swapped in for
    # backbone-only outputs.
    parser = argparse.ArgumentParser(
        description='Convert saved sample files (.pt) into PDB files.'
    )
    # required=True: without a directory os.path.join(None, ...) would fail
    # with an unhelpful TypeError.
    parser.add_argument('--SAMPLEDIR', type=str, required=True,
                        help="directory containing an 'outputs' folder of .pt sample files")
    args = parser.parse_args()
    SAMPLE_DIR = args.SAMPLEDIR

    output_dir = os.path.join(SAMPLE_DIR,'outputs')
    # splitext keeps names that contain extra dots intact; other files are skipped.
    names = [os.path.splitext(n)[0] for n in os.listdir(output_dir) if n.endswith('.pt')]
    for name in tqdm(names):
        # Load onto the CPU: the writers move samples['batch'] to the CPU and
        # combine it with the remaining tensors of the sample.
        sample = torch.load(os.path.join(output_dir,f'{name}.pt'), map_location='cpu')
        os.makedirs(os.path.join(SAMPLE_DIR,'pdbs',name),exist_ok=True)
        save_samples_sc(sample,os.path.join(SAMPLE_DIR,'pdbs',name))