File size: 965 Bytes
ca7299e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import esm
import torch
from Bio import SeqIO
class ESMFold_Pred():
    """Thin wrapper around the pretrained ESMFold v1 model for structure prediction."""

    def __init__(self, device):
        """Load ESMFold v1 for inference and move it to *device*.

        Args:
            device: Target passed to the model's ``.to(...)`` call
                (e.g. ``"cuda"``, ``"cpu"`` or a ``torch.device``).
        """
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        # Inference only: freeze every parameter so autograd tracks nothing.
        self._folding_model.requires_grad_(False)
        self._folding_model.to(device)

    def predict_str(self, pdbfile, save_path, max_seq_len=1500):
        """Fold the first sequence parsed from *pdbfile* and save the prediction.

        Every sequence Biopython parses from *pdbfile* (``"pdb-atom"`` format)
        whose length is at most *max_seq_len* is printed. Only the first of
        these is passed to ESMFold. The resulting PDB text is written to
        *save_path*, overwriting any existing file.

        Args:
            pdbfile: Path or handle accepted by ``Bio.SeqIO.parse``.
            save_path: Destination file for the predicted structure.
            max_seq_len: Sequences longer than this are skipped.

        Returns:
            The predicted structure as the string returned by
            ``infer_pdb``. Returns ``None`` when no sequence passes the
            length filter; in that case nothing is written.
        """
        seq_list = []
        for record in SeqIO.parse(pdbfile, "pdb-atom"):
            seq = str(record.seq)
            # NOTE(review): an earlier variant stripped "X" residues here
            # (left disabled); confirm unknown residues should reach the model.
            if len(seq) > max_seq_len:
                continue
            print(f'seq {len(seq_list)}:', seq)
            seq_list.append(seq)

        if not seq_list:
            # Previously this case was a silent no-op; make it visible.
            print(f'no sequence of length <= {max_seq_len} found in {pdbfile}; nothing written')
            return None

        # Only the first qualifying sequence is folded; the others are just listed.
        with torch.no_grad():
            output = self._folding_model.infer_pdb(seq_list[0])
        with open(save_path, "w") as f:
            f.write(output)
        return output
|