import warnings

import esm
import torch
from Bio import SeqIO
| |
|
class ESMFold_Pred:
    """Predict a structure for a sequence read from a PDB file using ESMFold."""

    def __init__(self, device):
        """Load the pretrained ESMFold v1 model and place it on *device*.

        The model is switched to eval mode and gradient tracking is disabled
        on it, since it is only used for inference here.

        Args:
            device: Target passed straight to the model's ``.to()`` call.
        """
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model.requires_grad_(False)
        self._folding_model.to(device)

    def predict_str(self, pdbfile, save_path, max_seq_len=1500):
        """Fold the first eligible sequence in *pdbfile* and write it out.

        Every record yielded by ``SeqIO.parse(pdbfile, "pdb-atom")`` is
        inspected; records whose sequence is longer than *max_seq_len* are
        skipped, and each remaining sequence is printed. Only the first
        remaining sequence is passed to the folding model, and the model's
        output is written to *save_path*.

        Args:
            pdbfile: Input handed to ``SeqIO.parse`` with the "pdb-atom"
                format.
            save_path: File that receives the output of ``infer_pdb``
                (presumably PDB-format text -- confirm against the ESM docs).
            max_seq_len: Sequences longer than this are ignored.

        Returns:
            None. If no sequence is eligible, a warning is issued and no file
            is written.
        """
        seq_list = []
        for record in SeqIO.parse(pdbfile, "pdb-atom"):
            seq = str(record.seq)
            if len(seq) > max_seq_len:
                continue
            # The running index is simply the number of sequences kept so far.
            print(f'seq {len(seq_list)}:', seq)
            seq_list.append(seq)

        if not seq_list:
            # Previously this case fell through silently; surface it without
            # raising so existing best-effort callers keep working.
            warnings.warn(
                f"no sequence of length <= {max_seq_len} found in "
                f"{pdbfile!r}; nothing written to {save_path!r}",
                stacklevel=2,
            )
            return

        # Only the first eligible sequence is folded (matches the original
        # loop, which broke out after its first iteration).
        with torch.no_grad():
            output = self._folding_model.infer_pdb(seq_list[0])
        with open(save_path, "w") as f:
            f.write(output)
| |
|
| |
|
| |
|
| |
|