| """ |
| Evaluate a finetuned peptide model checkpoint by sampling sequences |
| and computing metrics for the De Novo Peptide Generation table: |
| Validity (%), Affinity (↑), Solubility (↑), Hemolysis (↑), |
| Nonfouling (↑), Permeability (↑), Sampling Time (↓) |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import time |
| import torch |
| import numpy as np |
| import pandas as pd |
|
|
| |
# REPO_ROOT is the parent of the directory that contains this file. It is
# placed at the front of sys.path before the project-local imports run
# (presumably so those packages resolve when this file is executed directly
# as a script -- confirm against the repository layout).
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)
|
|
| from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT |
| from lightning_modules import AnyOrderInsertionFlowModule |
| from inference_quality import sample_peptides_eval |
| from pep_scoring.scoring_functions import ScoringFunctions |
| from pep_utils.analyzer import PeptideAnalyzer |
| from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
| from finetune_quality import PeptideFinetuner |
| from pep_utils.utils import str2bool, set_seed |
| from tdc import Evaluator |
|
|
|
|
| |
# Built-in target proteins: short name -> sequence string (presumably
# single-letter amino-acid codes). main() looks up --prot_name in this table
# and hands the selected sequence to ScoringFunctions as the binding target.
PROTEINS = {
    'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
    'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
    'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
    'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
    'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
    'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
    'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
    'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
    'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
}
|
|
|
|
def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
    """Rebuild the finetuned policy model from a Lightning checkpoint.

    The model is configured from the ``config`` stored in the pretrained base
    checkpoint, overlaid with schedule-related settings read from the
    ``args`` object saved in the finetuned checkpoint's hyper-parameters.
    The finetuned weights are then copied from the ``policy_model.*`` entries
    of the finetuned checkpoint's ``state_dict``.

    Args:
        checkpoint_path: Path to the finetuned Lightning checkpoint.
        pretrained_ckpt_path: Path to the pretrained base checkpoint. It
            supplies the model config and is also forwarded to the model
            constructor as ``pretrained_checkpoint``.
        device: Device the model is moved to before it is returned.

    Returns:
        Tuple ``(policy_model, args, config)``: the model in eval mode, the
        ``args`` object stored in the finetuned checkpoint (``None`` when the
        checkpoint has none), and the config used to build the model.

    Raises:
        ValueError: If the base checkpoint holds no config, or if the
            finetuned checkpoint has no ``policy_model.*`` weights (which
            would otherwise silently evaluate a non-finetuned model).
    """
    import warnings

    from omegaconf import OmegaConf, DictConfig

    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    hparams = ckpt.get('hyper_parameters', {})
    args = hparams.get('args', None)

    # The architecture config lives in the base checkpoint, either under the
    # Lightning hyper-parameters or directly at the top level.
    base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
    base_hparams = base_ckpt.get('hyper_parameters') or {}
    if 'config' in base_hparams:
        config = base_hparams['config']
    elif 'config' in base_ckpt:
        config = base_ckpt['config']
    else:
        raise ValueError("Cannot find config in base checkpoint")

    if not OmegaConf.is_config(config):
        config = DictConfig(config)

    # Overlay the finetuning-time settings. getattr() falls back to the
    # listed default when `args` is None or lacks the attribute. Struct mode
    # is lifted only while the new keys are being added.
    finetune_defaults = {
        'use_adaptive_schedule': True,
        'schedule_hidden_dim': 256,
        'schedule_num_layers': 2,
        'schedule_loss_weight': 0.1,
        'freeze_base_model': False,
        'schedule_warmup_epochs': 0,
    }
    OmegaConf.set_struct(config, False)
    for name, default in finetune_defaults.items():
        config.training[name] = getattr(args, name, default)
    OmegaConf.set_struct(config, True)

    disable_planner = getattr(args, 'disable_planner', False)

    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=pretrained_ckpt_path,
        insertion_planner=not disable_planner,
    )

    # Keep only the policy sub-module's weights and strip the prefix so the
    # keys line up with the freshly built module.
    prefix = 'policy_model.'
    policy_state = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }
    if not policy_state:
        raise ValueError(
            f"No '{prefix}*' weights found in {checkpoint_path}; "
            "refusing to evaluate a model without its finetuned weights"
        )

    # strict=False tolerates partial mismatches, but they should not go
    # unnoticed in an evaluation run, so surface them as a warning.
    load_result = policy_model.load_state_dict(policy_state, strict=False)
    missing = list(load_result.missing_keys)
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"Finetuned weights loaded non-strictly: {len(missing)} missing "
            f"key(s) {missing[:5]}, {len(unexpected)} unexpected key(s) "
            f"{unexpected[:5]}"
        )

    policy_model = policy_model.to(device)
    policy_model.eval()

    return policy_model, args, config
|
|
|
|
def _as_score_list(values):
    """Return a batch of scores as a Python list, or ``[]`` if it is not a sized batch."""
    if isinstance(values, (list, tuple)):
        return list(values)
    if isinstance(values, (np.ndarray, torch.Tensor)) and values.ndim > 0:
        return values.tolist()
    return []


@torch.no_grad()
def evaluate_checkpoint(policy_model, tokenizer, reward_model, analyzer,
                        num_samples=1000, batch_size=50, max_length=512,
                        total_num_steps=256, quality_mode="both", num_remasking=3,
                        quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'):
    """Sample ``num_samples`` peptides in batches and aggregate the table metrics.

    Each batch is produced by ``sample_peptides_eval``, which is expected to
    return the 7-tuple ``(valid_seqs, affinity, sol, hemo, nf, permeability,
    valid_fraction)``. Per-sequence scores are pooled over all batches.

    Args:
        policy_model: Model to sample from; its ``interpolant.mask_token``
            and ``interpolant.pad_token`` are forwarded to the sampler.
        tokenizer: Tokenizer forwarded to the sampler.
        reward_model: Scoring model forwarded to the sampler.
        analyzer: Peptide analyzer forwarded to the sampler.
        num_samples: Total number of sequences to request.
        batch_size: Maximum number of sequences requested per sampler call.
        max_length: Forwarded to the sampler as ``max_length``.
        total_num_steps: Forwarded to the sampler as ``steps``.
        quality_mode: Forwarded to the sampler as ``quality_mode``.
        num_remasking: Forwarded to the sampler as ``num_remasking``.
        quality_threshold: Forwarded to the sampler.
        unmask_quality_threshold: Forwarded to the sampler.
        device: Accepted for interface compatibility; not used here.

    Returns:
        Dict with the keys ``'Validity (%)'``, ``'Uniqueness (%)'``,
        ``'Diversity'``, a mean and ``'<name> Std'`` entry for each of
        Affinity / Solubility / Hemolysis / Nonfouling / Permeability,
        ``'Sampling Time (s)'`` (total wall-clock time spent inside the
        sampler calls), ``'Num Generated'``, ``'Num Valid'`` and
        ``'Num Unique'``. Score entries are ``0.0`` when nothing was scored.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    score_columns = ('Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability')
    pooled = {column: [] for column in score_columns}
    all_valid_seqs = []
    total_valid = 0
    total_generated = 0
    total_time = 0.0

    num_batches = (num_samples + batch_size - 1) // batch_size
    remaining = num_samples

    for b in range(num_batches):
        # The final batch may be smaller than batch_size.
        bs = min(batch_size, remaining)
        remaining -= bs

        # perf_counter is monotonic, so the interval cannot be skewed by
        # wall-clock adjustments the way time.time() can.
        t_start = time.perf_counter()
        result = sample_peptides_eval(
            model=policy_model,
            reward_model=reward_model,
            analyzer=analyzer,
            tokenizer=tokenizer,
            steps=total_num_steps,
            mask=policy_model.interpolant.mask_token,
            pad=policy_model.interpolant.pad_token,
            batch_size=bs,
            max_length=max_length,
            quality_mode=quality_mode,
            num_remasking=num_remasking,
            quality_threshold=quality_threshold,
            unmask_quality_threshold=unmask_quality_threshold,
            return_valid=True,
        )
        elapsed = time.perf_counter() - t_start

        # Validity is recomputed from counts below, so the sampler's own
        # per-batch fraction is not needed.
        valid_seqs, affinity, sol, hemo, nf, permeability, _valid_fraction = result

        batch_valid = len(valid_seqs)
        total_valid += batch_valid
        total_generated += bs
        total_time += elapsed
        all_valid_seqs.extend(valid_seqs)

        # Convert each metric on its own so a tuple/tensor return is not
        # silently dropped and no metric depends on the type of another.
        batch_scores = (affinity, sol, hemo, nf, permeability)
        for column, values in zip(score_columns, batch_scores):
            pooled[column].extend(_as_score_list(values))

        print(f" Batch {b+1}/{num_batches}: {batch_valid}/{bs} valid, "
              f"time={elapsed:.1f}s")

    validity = total_valid / total_generated * 100.0 if total_generated > 0 else 0.0

    # De-duplicate while keeping first-seen order, so the list handed to the
    # diversity evaluator is reproducible from run to run.
    all_unique = list(dict.fromkeys(all_valid_seqs))
    num_unique = len(all_unique)
    uniqueness = num_unique / total_valid * 100.0 if total_valid > 0 else 0.0
    if num_unique > 1:
        diversity = Evaluator('diversity')(all_unique)
    else:
        diversity = 0.0

    metrics = {
        'Validity (%)': validity,
        'Uniqueness (%)': uniqueness,
        'Diversity': diversity,
    }
    for column in score_columns:
        values = pooled[column]
        metrics[column] = np.mean(values) if values else 0.0
        metrics[f'{column} Std'] = np.std(values) if values else 0.0
    metrics['Sampling Time (s)'] = total_time
    metrics['Num Generated'] = total_generated
    metrics['Num Valid'] = total_valid
    metrics['Num Unique'] = num_unique

    return metrics
|
|
|
|
def main():
    """Command-line entry point: sample from a finetuned checkpoint and save metrics.

    Parses the CLI options, loads the finetuned policy model, builds the
    tokenizer, reward model and analyzer, runs ``evaluate_checkpoint``,
    prints the metrics and writes them to
    ``eval_metrics_<tag>_<target>.csv`` in the output directory.

    Raises:
        ValueError: If ``--prot_name`` is not a key of ``PROTEINS`` and no
            ``--prot_seq`` is given.
    """
    default_target = 'glast'

    parser = argparse.ArgumentParser(description="Evaluate a finetuned peptide checkpoint")
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
    parser.add_argument('--pretrained_ckpt', type=str,
                        default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt'),
                        help='Path to the pretrained base model checkpoint')
    parser.add_argument('--num_samples', type=int, default=500,
                        help='Number of peptides to sample')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size for sampling')
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--total_num_steps', type=int, default=256)
    parser.add_argument('--num_remasking', type=int, default=3)
    parser.add_argument('--quality_threshold', type=float, default=0.5,
                        help='Threshold for insertion quality filtering during sampling')
    parser.add_argument('--unmask_quality_threshold', type=float, default=None,
                        help='If set, gate unmasking/remasking by confidence: remask '
                             'ALL clean tokens whose unmasking confidence is below this '
                             'threshold, regardless of the schedule budget. If unset '
                             '(default), remasking is purely schedule-driven (count-based).')
    # The default is resolved after parsing so that a run with --prot_seq but
    # no explicit name is not labelled as the built-in default target.
    parser.add_argument('--prot_name', type=str, default=None,
                        help='Target protein name (must be one of: ' + ', '.join(PROTEINS.keys()) + '). '
                             f'Defaults to {default_target}, or to "custom" when --prot_seq is given; '
                             'with --prot_seq it is only used as a label.')
    parser.add_argument('--prot_seq', type=str, default=None,
                        help='Custom protein sequence (overrides --prot_name)')
    parser.add_argument('--disable_planner', action='store_true',
                        help='If set, disable remasking during evaluation')
    parser.add_argument('--disable_insertion_planner', action='store_true',
                        help='If set, disable insertion quality filtering during evaluation')
    parser.add_argument('--disable_unmasking_planner', action='store_true',
                        help='If set, disable unmasking confidence planner during evaluation')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results CSV. Defaults to checkpoint directory.')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed, use_cuda=True)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Derive the sampler mode and the output-file tag from the same pair of
    # switches so the file name always matches the mode that was run
    # (both planners off is "none"/"no_planner", however it was requested).
    no_insertion = args.disable_planner or args.disable_insertion_planner
    no_unmasking = args.disable_planner or args.disable_unmasking_planner
    mode_and_tag = {
        (True, True): ("none", "no_planner"),
        (True, False): ("unmasking_only", "no_insertion_planner"),
        (False, True): ("insertion_only", "no_unmasking_planner"),
        (False, False): ("both", "with_planner"),
    }
    quality_mode, tag = mode_and_tag[(no_insertion, no_unmasking)]
    if args.unmask_quality_threshold is not None:
        tag += f"_ut{args.unmask_quality_threshold:g}"

    print(f"Loading checkpoint: {args.checkpoint_path}")
    print(f"Pretrained base: {args.pretrained_ckpt}")
    print(f"Quality mode: {quality_mode}")

    policy_model, _train_args, _config = load_finetuned_model(
        args.checkpoint_path, args.pretrained_ckpt, device=device
    )

    tokenizer_dir = os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer')
    tokenizer = SMILES_SPE_Tokenizer(
        os.path.join(tokenizer_dir, 'new_vocab.txt'),
        os.path.join(tokenizer_dir, 'new_splits.txt')
    )

    # Resolve the binding target: an explicit sequence wins; otherwise the
    # name must be one of the built-in proteins.
    if args.prot_seq is not None:
        prot = args.prot_seq
        prot_name = args.prot_name if args.prot_name is not None else 'custom'
    else:
        prot_name = args.prot_name if args.prot_name is not None else default_target
        if prot_name not in PROTEINS:
            raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
        prot = PROTEINS[prot_name]

    score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
    reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=device)
    analyzer = PeptideAnalyzer()

    print(f"\nSampling {args.num_samples} peptides (quality_mode={quality_mode}, target={prot_name})...")

    metrics = evaluate_checkpoint(
        policy_model=policy_model,
        tokenizer=tokenizer,
        reward_model=reward_model,
        analyzer=analyzer,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        max_length=args.max_length,
        total_num_steps=args.total_num_steps,
        quality_mode=quality_mode,
        num_remasking=args.num_remasking,
        quality_threshold=args.quality_threshold,
        unmask_quality_threshold=args.unmask_quality_threshold,
        device=device,
    )

    print("\n" + "=" * 60)
    print(" De Novo Peptide Generation Results")
    print("=" * 60)
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:<30s}: {v:.4f}")
        else:
            print(f" {k:<30s}: {v}")
    print("=" * 60)

    # dirname() of a bare file name is '', which os.makedirs rejects; fall
    # back to the current directory in that case.
    output_dir = args.output_dir or os.path.dirname(args.checkpoint_path) or '.'
    os.makedirs(output_dir, exist_ok=True)

    # Record the thresholds alongside the metrics so the CSV is self-describing.
    metrics['unmask_quality_threshold'] = args.unmask_quality_threshold
    metrics['quality_threshold'] = args.quality_threshold
    metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}_{prot_name}.csv')
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    print(f"Metrics saved to: {metrics_path}")
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
|
|