| | |
| | |
| | |
| | import os |
| | from scipy.interpolate import CubicSpline, PchipInterpolator, Akima1DInterpolator |
| | import numpy as np |
| | import math |
| | import matplotlib.pyplot as plt |
| |
|
| | from Bio.PDB import PDBParser |
| | from Bio.PDB.DSSP import DSSP |
| | from Bio.PDB import PDBList |
| |
|
| | import torch |
| | from einops import rearrange |
| | import esm |
| | |
| | |
def create_path(this_path):
    """Create *this_path* on disk if it does not already exist.

    Args:
        this_path: directory path to create.

    Returns:
        int: 1 if the directory was created, 2 if it already existed.
    """
    if not os.path.exists(this_path):
        print('Creating the given path...')
        # BUG FIX: os.mkdir fails when intermediate directories are missing;
        # os.makedirs creates the whole chain and is backward-compatible here.
        os.makedirs(this_path)
        path_stat = 1
        print('Done.')
    else:
        print('The given path already exists!')
        path_stat = 2
    return path_stat
| | |
| | |
| |
|
| | |
def params (model):
    """Print the total and trainable parameter counts of a torch model.

    Args:
        model: any torch.nn.Module-like object exposing ``parameters()``.

    Returns:
        None; the counts are only printed.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    print ("Total parameters: ", total," trainable parameters: ", trainable)
| | |
| | |
| | |
def prepare_UNet_keys(write_dict):
    """Build the full UNet config dict: every known key defaults to None,
    then values from ``write_dict`` are copied over. Unknown keys are
    reported and ignored.

    Args:
        write_dict: partial config mapping.

    Returns:
        dict with all known UNet keys.
    """
    Full_Keys=['dim', 'text_embed_dim', 'num_resnet_blocks', 'cond_dim', 'num_image_tokens', 'num_time_tokens', 'learned_sinu_pos_emb_dim', 'out_dim', 'dim_mults', 'cond_images_channels', 'channels', 'channels_out', 'attn_dim_head', 'attn_heads', 'ff_mult', 'lowres_cond', 'layer_attns', 'layer_attns_depth', 'layer_attns_add_text_cond', 'attend_at_middle', 'layer_cross_attns', 'use_linear_attn', 'use_linear_cross_attn', 'cond_on_text', 'max_text_len', 'init_dim', 'resnet_groups', 'init_conv_kernel_size', 'init_cross_embed', 'init_cross_embed_kernel_sizes', 'cross_embed_downsample', 'cross_embed_downsample_kernel_sizes', 'attn_pool_text', 'attn_pool_num_latents', 'dropout', 'memory_efficient', 'init_conv_to_final_conv_residual', 'use_global_context_attn', 'scale_skip_connection', 'final_resnet_block', 'final_conv_kernel_size', 'cosine_sim_attn', 'self_cond', 'combine_upsample_fmaps', 'pixel_shuffle_upsample', 'beginning_and_final_conv_present']

    PKeys = {key: None for key in Full_Keys}

    for write_key, value in write_dict.items():
        if write_key in PKeys:
            PKeys[write_key] = value
        else:
            print("Wrong key found: ", write_key)

    return PKeys
| |
|
def prepare_ModelB_keys(write_dict):
    """Build the full ModelB config dict: every known key defaults to None,
    then values from ``write_dict`` are copied over. Unknown keys are
    reported and ignored.

    Args:
        write_dict: partial config mapping.

    Returns:
        dict with all known ModelB keys.
    """
    Full_Keys=['timesteps', 'dim', 'pred_dim', 'loss_type', 'elucidated', 'padding_idx', 'cond_dim', 'text_embed_dim', 'input_tokens', 'sequence_embed', 'embed_dim_position', 'max_text_len', 'cond_images_channels', 'max_length', 'device']

    PKeys = {key: None for key in Full_Keys}

    for write_key, value in write_dict.items():
        if write_key in PKeys:
            PKeys[write_key] = value
        else:
            print("Wrong key found: ", write_key)

    return PKeys
| |
|
def modify_keys(old_dict,write_dict):
    """Return a copy of ``old_dict`` with values overwritten from
    ``write_dict`` for keys that already exist; unknown keys are reported
    and ignored. ``old_dict`` itself is not modified.
    """
    new_dict = dict(old_dict)
    for key, value in write_dict.items():
        if key in old_dict:
            new_dict[key] = value
        else:
            print("Alien key found: ", key)
    return new_dict
| |
|
| | |
| | |
| | |
def mixing_two_FORCE_for_AA_Len(NGap1,Force1,NGap2,Force2,LenAA,mix_fac):
    """Blend two force-vs-gap profiles and resample for a given AA length.

    Each profile is interpolated with a monotone PCHIP spline, blended as
    ``mix_fac * profile1 + (1 - mix_fac) * profile2`` on a dense grid, and
    re-fit so it can be sampled at LenAA+1 evenly spaced points on [0, 1].

    Returns:
        (fun_PI, x1, y1): blended interpolator, sample grid, sampled values.
    """
    # dense base grid: twice the longer of the two input profiles
    n_base = math.ceil(2 * np.amax([len(NGap1), len(NGap2)]))
    spline_a = PchipInterpolator(NGap1, Force1)
    spline_b = PchipInterpolator(NGap2, Force2)

    grid = np.linspace(0, 1, n_base)
    blended = spline_a(grid) * mix_fac + spline_b(grid) * (1 - mix_fac)
    fun_PI = PchipInterpolator(grid, blended)

    x1 = np.linspace(0, 1, LenAA + 1)
    return fun_PI, x1, fun_PI(x1)
| |
|
| | |
| | |
| | |
def get_Model_A_error (fname, cond, plotit=True, ploterror=False):
    """Compare the DSSP secondary-structure content of a PDB file against
    a target condition vector.

    Args:
        fname: path to a PDB file (passed to get_DSSP_result).
        cond: length-8 target content vector, order H, E, T, ~, B, G, I, S.
        plotit: if True, show GT-vs-prediction bar charts.
        ploterror: if True, also plot the per-type absolute errors.

    Returns:
        np.ndarray: absolute per-type error over the 8 SS types.
    """
    sec_structure,sec_structure_3state, sequence=get_DSSP_result (fname)

    # 8-state content fractions, same order as `cond`
    sscount=[]
    length = len (sec_structure)
    for ss_type in ['H', 'E', 'T', '~', 'B', 'G', 'I', 'S']:
        sscount.append (sec_structure.count(ss_type)/length)
    sscount=np.asarray (sscount)

    error=np.abs(sscount-cond)
    print ("Abs error per SS structure type (H, E, T, ~, B, G, I S): ", error)

    if ploterror:
        fig, ax = plt.subplots(1, 1, figsize=(6,3))
        plt.plot (error, 'o-', label='Error over SS type')
        plt.legend()
        plt.ylabel ('SS content')
        plt.show()

    # BUG FIX: the bar charts previously ignored the `plotit` flag
    if plotit:
        x=np.linspace (0, 7, 8)
        sslabels=['H','E','T','~','B','G','I','S']
        fig, ax = plt.subplots(1, 1, figsize=(6,3))
        ax.bar(x-0.15, cond, width=0.3, color='b', align='center')
        ax.bar(x+0.15, sscount, width=0.3, color='r', align='center')
        ax.set_ylim([0, 1])
        plt.xticks(range(len(sslabels)), sslabels, size='medium')
        plt.legend (['GT','Prediction'])
        plt.ylabel ('SS content')
        plt.show()

    # 3-state comparison (H, E, ~) using the collapsed DSSP string
    sscount=[]
    length = len (sec_structure)
    sscount.append (sec_structure_3state.count('H')/length)
    sscount.append (sec_structure_3state.count('E')/length)
    sscount.append (sec_structure_3state.count('~')/length)
    # collapse the 8-state condition the same way: H+G+I, E+B, T+~+S
    cond_p=[np.sum([cond[0],cond[5], cond[6]]), np.sum ([cond[1], cond[4]]), np.sum([cond[2],cond[3],cond[7]]) ]

    print ("cond 3type: ",cond_p)
    sscount=np.asarray (sscount)

    error3=np.abs(sscount-cond_p)
    # BUG FIX: previously printed the 8-state `error` here instead of `error3`
    print ("Abs error per 3-type SS structure type (C, H, E): ", error3)

    if ploterror:
        fig, ax = plt.subplots(1, 1, figsize=(6,3))
        plt.plot (error3, 'o-', label='Error over SS type')
        plt.legend()
        plt.ylabel ('SS content')
        plt.show()

    if plotit:
        x=np.linspace (0,2, 3)
        sslabels=['H','E', '~' ]
        fig, ax = plt.subplots(1, 1, figsize=(6,3))
        ax.bar(x-0.15, cond_p, width=0.3, color='b', align='center')
        ax.bar(x+0.15, sscount, width=0.3, color='r', align='center')
        ax.set_ylim([0, 1])
        plt.xticks(range(len(sslabels)), sslabels, size='medium')
        plt.legend (['GT','Prediction'])
        plt.ylabel ('SS content')
        plt.show()

    return error
| |
|
def get_DSSP_result (fname):
    """Run DSSP on a PDB file and return its secondary-structure strings.

    Args:
        fname: path to a PDB file.

    Returns:
        (sec_structure, sec_structure_3state, sequence):
        the 8-state DSSP string (with '-' normalized to '~'), the same
        string collapsed to 3 states (H/E/~), and the residue sequence.
    """
    parser = PDBParser()
    structure = parser.get_structure(fname, fname)
    model = structure[0]
    dssp = DSSP(model, fname, file_type='PDB' )

    # DSSP entry layout: [1] = residue one-letter code, [2] = SS symbol
    sequence = ''
    sec_structure = ''
    for a_key in dssp.keys():
        sequence += dssp[a_key][1]
        sec_structure += dssp[a_key][2]

    # normalize DSSP's coil symbol '-' to '~'
    sec_structure = sec_structure.replace('-', '~')

    # collapse 8 states to 3: helix-like (G, I) -> H, strand-like (B) -> E,
    # everything else (T, S) -> coil '~'; H/E/~ map to themselves.
    three_state = str.maketrans({'T': '~', 'B': 'E', 'G': 'H', 'I': 'H', 'S': '~'})
    sec_structure_3state = sec_structure.translate(three_state)

    return sec_structure,sec_structure_3state, sequence
| | |
| | |
def string_diff (seq1, seq2):
    """Count mismatched positions between two sequences, plus the length gap."""
    mismatches = 0
    for left, right in zip(seq1, seq2):
        if left != right:
            mismatches += 1
    return mismatches + abs(len(seq1) - len(seq2))
| |
|
| |
|
| | |
| | |
| | |
| | import esm |
| |
|
def decode_one_ems_token_rec(this_token, esm_alphabet):
    """Decode one ESM token tensor back into a residue string.

    Decoding spans the tokens strictly between the last <cls> (or position 0
    if none) and the first <eos> (or the end if none).

    Args:
        this_token: 1-D token tensor.
        esm_alphabet: ESM alphabet exposing cls_idx, eos_idx, get_tok.

    Returns:
        str: decoded residue string.
    """
    cls_hits = (this_token == esm_alphabet.cls_idx).nonzero(as_tuple=True)[0]
    eos_hits = (this_token == esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]

    # end: first <eos>, or the very end when no <eos> is present
    end = len(this_token) if len(eos_hits) == 0 else eos_hits[0]
    # start: last <cls>, or position 0 when no <cls> is present
    beg = 0 if len(cls_hits) == 0 else cls_hits[-1]

    pieces = [
        esm_alphabet.get_tok(this_token[pos])
        for pos in range(beg + 1, end, 1)
    ]
    return "".join(pieces)
| |
|
| |
|
def decode_many_ems_token_rec(batch_tokens, esm_alphabet):
    """Decode a batch of ESM token tensors into a list of residue strings."""
    return [
        decode_one_ems_token_rec(batch_tokens[idx], esm_alphabet)
        for idx in range(len(batch_tokens))
    ]
| |
|
| | |
# Logit columns blanked to -inf before argmax in the *_for_folding decoders,
# so these token ids can never be selected. Presumably the ESM alphabet's
# special tokens (<cls>, <pad>, <eos>, ...) plus uncommon residue tokens —
# TODO confirm against esm_alphabet.all_toks.
uncomm_idx_list = [0, 1, 2, 3, 24, 25, 26, 27, 28, 29, 30, 31, 32]
| |
|
| | |
def decode_one_ems_token_rec_for_folding(
    this_token,
    this_logits,
    esm_alphabet,
    esm_model):
    """Decode one token/logits pair into a residue string for folding.

    Decoding always starts right after position 0 and stops at the first
    usable <eos> (a degenerate <eos> immediately after the start is skipped
    in favor of the next one, or the sequence end). Uncommon/special token
    columns are suppressed before the argmax so only common residues can
    be emitted.

    NOTE(review): the slice of `this_logits` is a view, so the -inf write
    modifies the caller's tensor in place.

    Returns:
        str: decoded residue string.
    """
    # cls positions are computed but unused (kept from the original logic)
    cls_hits = (this_token == esm_alphabet.cls_idx).nonzero(as_tuple=True)[0]
    eos_hits = (this_token == esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]

    beg = 0
    if len(eos_hits) == 0:
        end = len(this_token)
    else:
        end = eos_hits[0]

    # a degenerate first <eos> (immediately after the start) is skipped
    if end <= beg + 1:
        end = eos_hits[1] if len(eos_hits) > 1 else len(this_token)

    print("start at: ", beg)
    print("end at: ", end)

    use_logits = this_logits[beg + 1:end]
    use_logits[:, uncomm_idx_list] = -float('inf')
    picked = use_logits.max(1).indices

    return "".join(esm_alphabet.get_tok(tok) for tok in picked)
| |
|
| |
|
def decode_many_ems_token_rec_for_folding(
    batch_tokens,
    batch_logits,
    esm_alphabet,
    esm_model):
    """Decode a batch of token/logits pairs into residue strings."""
    return [
        decode_one_ems_token_rec_for_folding(
            batch_tokens[idx],
            batch_logits[idx],
            esm_alphabet,
            esm_model,
        )
        for idx in range(len(batch_tokens))
    ]
| |
|
| |
|
def convert_into_logits(esm_model, result):
    """Run the ESM language-model head over a permuted representation.

    Args:
        esm_model: model exposing ``lm_head``.
        result: tensor of shape (b, l, c).

    Returns:
        lm_head output computed without gradient tracking.
    """
    # equivalent to einops rearrange(result, 'b l c -> b c l')
    permuted = result.permute(0, 2, 1)
    with torch.no_grad():
        logits = esm_model.lm_head(permuted)
    return logits
| |
|
| | |
def convert_into_tokens(model, result, pLM_Model_Name):
    """Convert a (batch, channels, length) representation into logits and
    argmax token ids via the model's LM head.

    Args:
        model: model exposing ``lm_head``.
        result: tensor of shape (b, c, l).
        pLM_Model_Name: one of the supported ESM-2 checkpoint names.

    Returns:
        (tokens, logits): argmax token ids (b, l) and logits (b, l, vocab).

    Raises:
        ValueError: if pLM_Model_Name is not a supported model.
    """
    if pLM_Model_Name in (
        'esm2_t33_650M_UR50D',
        'esm2_t36_3B_UR50D',
        'esm2_t30_150M_UR50D',
        'esm2_t12_35M_UR50D',
    ):
        # 'b c l -> b l c' (equivalent to the previous einops rearrange)
        repre = result.permute(0, 2, 1)
        with torch.no_grad():
            logits = model.lm_head(repre)
        tokens = logits.max(2).indices
    else:
        # BUG FIX: previously printed a warning and then returned undefined
        # names, raising UnboundLocalError; fail explicitly instead.
        raise ValueError("pLM_Model is not defined...")
    return tokens, logits
| | |
def convert_into_tokens_using_prob(prob_result, pLM_Model_Name):
    """Convert a (batch, channels, length) probability/logit tensor into
    per-position argmax token ids (no LM head is applied).

    Args:
        prob_result: tensor of shape (b, c, l).
        pLM_Model_Name: one of the supported ESM-2 checkpoint names.

    Returns:
        (tokens, logits): argmax token ids (b, l) and the permuted input
        (b, l, c) — named "logits" for symmetry with convert_into_tokens.

    Raises:
        ValueError: if pLM_Model_Name is not a supported model.
    """
    if pLM_Model_Name in (
        'esm2_t33_650M_UR50D',
        'esm2_t36_3B_UR50D',
        'esm2_t30_150M_UR50D',
        'esm2_t12_35M_UR50D',
    ):
        # 'b c l -> b l c' (equivalent to the previous einops rearrange);
        # the input is already in probability/logit space, so no LM head.
        logits = prob_result.permute(0, 2, 1)
        tokens = logits.max(2).indices
    else:
        # BUG FIX: previously printed a warning and then returned undefined
        # names, raising UnboundLocalError; fail explicitly instead.
        raise ValueError("pLM_Model is not defined...")
    return tokens, logits
| |
|
| |
|
| | |
def read_mask_from_input(
        tokenized_data=None,
        mask_value=None,
        seq_data=None,
        max_seq_length=None,
        ):
    """Build a boolean mask over token positions.

    Exactly one of ``seq_data`` or ``tokenized_data`` should be given:
    - seq_data: list of sequences; positions 1..len(seq) are True
      (position 0 is left False, presumably the <cls> slot — TODO confirm).
    - tokenized_data: tensor of token ids; True where the id differs from
      ``mask_value``, with everything between position 1 and the first
      kept token additionally forced True.

    Args:
        tokenized_data: (n, L) token tensor, used with mask_value.
        mask_value: padding token id to exclude.
        seq_data: list of sequences, used with max_seq_length.
        max_seq_length: mask width when building from seq_data.

    Returns:
        torch.BoolTensor mask.

    Raises:
        ValueError: if neither input is provided.
    """
    # BUG FIX: use `is not None` — `!=` on tensors is elementwise and on
    # lists is merely unidiomatic; identity checks are the safe test here.
    if seq_data is not None:
        n_seq = len(seq_data)
        mask = torch.zeros(n_seq, max_seq_length)
        for ii in range(n_seq):
            this_len = len(seq_data[ii])
            mask[ii, 1:1 + this_len] = 1
        mask = mask == 1

    elif tokenized_data is not None:
        n_seq = len(tokenized_data)
        mask = tokenized_data != mask_value

        for ii in range(n_seq):
            # first kept position; everything from 1 up to it is re-marked True
            id_1 = (mask[ii] == True).nonzero(as_tuple=True)[0]
            mask[ii, 1:id_1[0]] = True

    else:
        # BUG FIX: previously fell through and returned an undefined name
        raise ValueError("Either seq_data or tokenized_data must be provided")

    return mask
| |
|
| | |
def read_one_len_from_padding_vec(
    in_np_array,
    padding_val=0.0,
):
    """Return the length up to and including the last non-padding entry.

    Args:
        in_np_array: 1-D numpy array, padded with ``padding_val``.
        padding_val: value treated as padding.

    Returns:
        index of the last non-padding element, plus one.
    """
    non_pad_positions = (in_np_array != padding_val).nonzero()[0]
    return non_pad_positions[-1] + 1
| |
|
| |
|
| | |
def decode_one_ems_token_rec_for_folding_with_mask(
    this_token,
    this_logits,
    esm_alphabet,
    esm_model,
    this_mask,
):
    """Decode one logits matrix into a residue string, keeping only
    positions where ``this_mask`` is True. Uncommon/special token columns
    are suppressed before the argmax.

    NOTE(review): writes -inf into `this_logits` in place (no copy is
    made), so the caller's tensor is modified.

    Returns:
        str: decoded residue string over the masked positions.
    """
    use_logits = this_logits
    use_logits[:, uncomm_idx_list] = -float('inf')
    best_tokens = use_logits.max(1).indices

    print(best_tokens)
    kept = best_tokens[this_mask == True]

    return "".join(esm_alphabet.get_tok(tok) for tok in kept)
| |
|
def decode_many_ems_token_rec_for_folding_with_mask(
    batch_tokens,
    batch_logits,
    esm_alphabet,
    esm_model,
    mask):
    """Decode a batch of token/logits pairs into residue strings,
    restricted per item by the corresponding row of ``mask``."""
    return [
        decode_one_ems_token_rec_for_folding_with_mask(
            batch_tokens[idx],
            batch_logits[idx],
            esm_alphabet,
            esm_model,
            mask[idx]
        )
        for idx in range(len(batch_tokens))
    ]
| |
|
| | |
| | |
| | |
| | from scipy import interpolate |
| |
|
def interpolate_and_resample_ForcPath(y0,seq_len1):
    """Linearly resample a force path onto ``seq_len1 + 1`` evenly spaced
    points on [0, 1].

    NOTE(review): np.arange with a float step can occasionally emit an
    extra endpoint; kept as in the original to preserve behavior.

    Returns:
        dict with 'y1' (resampled values), 'x1' (new grid), 'x0' (old grid).
    """
    seq_len0 = len(y0) - 1
    x0 = np.arange(0., 1. + 1. / seq_len0, 1. / seq_len0)
    path_fun = interpolate.interp1d(x0, y0)

    x1 = np.arange(0., 1. + 1. / seq_len1, 1. / seq_len1)
    return {'y1': path_fun(x1), 'x1': x1, 'x0': x0}
| | |
def mix_two_ForcPath(y0,y1,seq_len2):
    """Resample two force paths onto a common grid and combine them.

    NOTE(review): the combination divides by 1.0, so this is a plain SUM of
    the two paths, not an average — kept as-is to preserve behavior, but
    worth confirming the intent.

    Returns:
        dict with 'y2' (combined values), 'x2', 'x1', 'x0' (the grids).
    """
    n0 = len(y0) - 1
    x0 = np.arange(0., 1. + 1. / n0, 1. / n0)
    n1 = len(y1) - 1
    x1 = np.arange(0., 1. + 1. / n1, 1. / n1)
    path0 = interpolate.interp1d(x0, y0)
    path1 = interpolate.interp1d(x1, y1)

    x2 = np.arange(0., 1. + 1. / seq_len2, 1. / seq_len2)
    y2 = (path0(x2) + path1(x2)) / 1.

    return {'y2': y2, 'x2': x2, 'x1': x1, 'x0': x0}
| | |
| | |
| | |
| | |
| | import esm |
| |
|
def load_in_pLM(pLM_Model_Name,device):
    """Load a pretrained protein language model by name.

    Args:
        pLM_Model_Name: 'trivial' or one of the supported ESM-2 checkpoint
            names (esm2_t33_650M_UR50D, esm2_t36_3B_UR50D,
            esm2_t30_150M_UR50D, esm2_t12_35M_UR50D).
        device: torch device the model is moved to.

    Returns:
        (pLM_Model, esm_alphabet, esm_layer, len_toks); for 'trivial' all
        of (None, None, 0, 0).

    Raises:
        ValueError: if the model name is not recognized.
    """
    if pLM_Model_Name=='trivial':
        # no language model is used in this configuration
        return None, None, 0, 0

    # final representation layer index of each supported ESM-2 checkpoint
    esm_layers = {
        'esm2_t33_650M_UR50D': 33,
        'esm2_t36_3B_UR50D': 36,
        'esm2_t30_150M_UR50D': 30,
        'esm2_t12_35M_UR50D': 12,
    }
    if pLM_Model_Name not in esm_layers:
        # BUG FIX: previously printed a warning and then returned undefined
        # names, raising UnboundLocalError; fail explicitly instead.
        raise ValueError("pLM model is missing...")

    esm_layer = esm_layers[pLM_Model_Name]
    # each checkpoint's loader in esm.pretrained shares the model's name
    pLM_Model, esm_alphabet = getattr(esm.pretrained, pLM_Model_Name)()
    len_toks = len(esm_alphabet.all_toks)
    pLM_Model.eval()
    pLM_Model.to(device)

    return pLM_Model, esm_alphabet, esm_layer, len_toks