# A2D2 / a2d2_pep / evaluate_peptide_table.py
# Commit 8019be0 ("initial commit") — Sophia (18 kB)
"""
Evaluate a finetuned peptide model checkpoint by sampling sequences
and computing metrics for the De Novo Peptide Generation table:
Validity (%), Affinity (↑), Solubility (↑), Hemolysis (↑),
Nonfouling (↑), Permeability (↑), Sampling Time (↓)
"""
import os
import sys
import argparse
import time
import torch
import numpy as np
import pandas as pd
# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)
from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
from lightning_modules import AnyOrderInsertionFlowModule
from inference_quality import sample_peptides_eval
from pep_scoring.scoring_functions import ScoringFunctions
from pep_utils.analyzer import PeptideAnalyzer
from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from finetune_quality import PeptideFinetuner
from pep_utils.utils import str2bool, set_seed
from tdc import Evaluator
# Target protein amino-acid sequences (one-letter codes), keyed by the short
# name accepted by the --prot_name CLI option. main() passes the selected
# sequence to ScoringFunctions(prot_seqs=[...]); presumably it is consumed by
# the 'binding_affinity1' scorer — TODO confirm.
PROTEINS = {
    'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
    'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
    'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
    'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
    'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
    'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
    'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
    'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
    'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
}
def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
    """Load a finetuned policy model from a Lightning checkpoint.

    The model config is taken from the *pretrained* base checkpoint. The
    schedule-related ``config.training`` options are then overridden from the
    ``args`` object stored in the finetuned checkpoint's ``hyper_parameters``.
    A fresh ``AnyOrderInsertionFlowModuleFT`` is built and receives every
    ``policy_model.*`` tensor of the finetuned ``state_dict``.

    Args:
        checkpoint_path: Path to the finetuned Lightning checkpoint.
        pretrained_ckpt_path: Path to the pretrained base checkpoint. Its config
            must live under ``hyper_parameters['config']`` or ``config``.
        device: Device (string or ``torch.device``) to move the model to.

    Returns:
        Tuple ``(policy_model, args, config)``:
        - the model, in eval mode on ``device``;
        - the training args recovered from the checkpoint (``None`` if absent);
        - the modified config.

    Raises:
        ValueError: If the base checkpoint holds no config, or the finetuned
            checkpoint holds no ``policy_model.*`` weights.
    """
    from omegaconf import OmegaConf, DictConfig

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects (needed for the saved args/config). Only load trusted files.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    hparams = ckpt.get('hyper_parameters', {})
    args = hparams.get('args', None)

    def _arg(name, default):
        # Saved training args may be an argparse.Namespace, a plain dict, or
        # None depending on how hyperparameters were saved; support all three.
        if isinstance(args, dict):
            return args.get(name, default)
        return getattr(args, name, default)

    # The model config comes from the pretrained base checkpoint.
    base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in base_ckpt:
        config = base_ckpt['hyper_parameters']['config']
    elif 'config' in base_ckpt:
        config = base_ckpt['config']
    else:
        raise ValueError("Cannot find config in base checkpoint")
    # Only the config is needed from here on; release the base weights early.
    del base_ckpt

    if not OmegaConf.is_config(config):
        config = DictConfig(config)
    # Lift struct mode temporarily so new keys may be added to config.training.
    OmegaConf.set_struct(config, False)
    for key, default in (
        ('use_adaptive_schedule', True),
        ('schedule_hidden_dim', 256),
        ('schedule_num_layers', 2),
        ('schedule_loss_weight', 0.1),
        ('freeze_base_model', False),
        ('schedule_warmup_epochs', 0),
    ):
        setattr(config.training, key, _arg(key, default))
    OmegaConf.set_struct(config, True)

    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=pretrained_ckpt_path,
        insertion_planner=not _arg('disable_planner', False),
    )

    # Keep only the policy sub-module's weights, stripping their prefix.
    prefix = 'policy_model.'
    policy_state = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }
    if not policy_state:
        # With strict=False an empty dict would "load" successfully and the
        # un-finetuned base model would be evaluated instead.
        raise ValueError(
            f"No '{prefix}*' weights found in the state_dict of {checkpoint_path}; "
            "refusing to silently evaluate the un-finetuned base model."
        )

    # strict=False tolerates keys absent on either side. Report them rather than
    # hiding them. (NOTE(review): presumably needed for optional heads such as
    # the schedule network — confirm.)
    load_result = policy_model.load_state_dict(policy_state, strict=False)
    missing = list(getattr(load_result, 'missing_keys', []))
    unexpected = list(getattr(load_result, 'unexpected_keys', []))
    if missing or unexpected:
        print(f"Warning: loading {checkpoint_path}: {len(missing)} missing and "
              f"{len(unexpected)} unexpected keys (strict=False).")

    policy_model = policy_model.to(device)
    policy_model.eval()
    return policy_model, args, config
@torch.no_grad()
def evaluate_checkpoint(policy_model, tokenizer, reward_model, analyzer,
                        num_samples=1000, batch_size=50, max_length=512,
                        total_num_steps=256, quality_mode="both", num_remasking=3,
                        quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'):
    """Sample ``num_samples`` peptides in batches and compute the table metrics.

    Sampling and scoring are delegated to ``sample_peptides_eval``. This
    function only aggregates its per-batch outputs.

    Args:
        policy_model: Model to sample from. Its ``interpolant`` supplies the
            mask and pad token ids.
        tokenizer, reward_model, analyzer: Forwarded to ``sample_peptides_eval``.
        num_samples: Total number of peptides to generate.
        batch_size: Peptides per sampling call; the last batch may be smaller.
        max_length, quality_mode, num_remasking, quality_threshold,
        unmask_quality_threshold: Forwarded unchanged to ``sample_peptides_eval``.
        total_num_steps: Forwarded to ``sample_peptides_eval`` as ``steps``.
        device: Unused; kept for call-site compatibility.

    Returns:
        A dict with these keys:
        - 'Validity (%)', 'Uniqueness (%)', 'Diversity';
        - the mean and the population std ('<name> Std') of 'Affinity',
          'Solubility', 'Hemolysis', 'Nonfouling' and 'Permeability' (0.0 when
          no scores were returned);
        - 'Sampling Time (s)', the summed wall-clock time of the sampling calls;
        - 'Num Generated', 'Num Valid', 'Num Unique'.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    def _to_flat_list(scores):
        # Normalise a list / ndarray / torch tensor / scalar of scores into a
        # flat list of floats (None -> no scores).
        if scores is None:
            return []
        if torch.is_tensor(scores):
            scores = scores.detach().cpu().numpy()
        return np.asarray(scores, dtype=float).ravel().tolist()

    score_names = ('Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability')
    all_scores = {name: [] for name in score_names}
    all_valid_seqs = []
    total_valid = 0
    total_generated = 0
    total_time = 0.0

    num_batches = (num_samples + batch_size - 1) // batch_size
    remaining = num_samples
    for b in range(num_batches):
        bs = min(batch_size, remaining)
        remaining -= bs

        # perf_counter is the monotonic, high-resolution clock meant for timing.
        t_start = time.perf_counter()
        result = sample_peptides_eval(
            model=policy_model,
            reward_model=reward_model,
            analyzer=analyzer,
            tokenizer=tokenizer,
            steps=total_num_steps,
            mask=policy_model.interpolant.mask_token,
            pad=policy_model.interpolant.pad_token,
            batch_size=bs,
            max_length=max_length,
            quality_mode=quality_mode,
            num_remasking=num_remasking,
            quality_threshold=quality_threshold,
            unmask_quality_threshold=unmask_quality_threshold,
            return_valid=True,
        )
        elapsed = time.perf_counter() - t_start

        # Unpack: validSequences, affinity, sol, hemo, nf, permeability, valid_fraction
        valid_seqs, affinity, sol, hemo, nf, permeability, _valid_fraction = result
        batch_valid = len(valid_seqs)
        total_valid += batch_valid
        total_generated += bs
        total_time += elapsed
        all_valid_seqs.extend(valid_seqs)

        batch_scores = [_to_flat_list(s) for s in (affinity, sol, hemo, nf, permeability)]
        # Record a batch's scores only when affinity is non-empty, so all five
        # score lists stay aligned with each other.
        if batch_scores[0]:
            for name, values in zip(score_names, batch_scores):
                all_scores[name].extend(values)

        print(f" Batch {b+1}/{num_batches}: {batch_valid}/{bs} valid, "
              f"time={elapsed:.1f}s")

    validity = total_valid / total_generated * 100.0 if total_generated > 0 else 0.0

    # Uniqueness: % of valid sequences that are distinct. Diversity: TDC's
    # 'diversity' oracle over the distinct sequences (same convention as
    # evaluate_mol_table.py). dict.fromkeys dedups in first-seen order, so the
    # oracle's input order is deterministic across runs (unlike set()).
    all_unique = list(dict.fromkeys(all_valid_seqs))
    num_unique = len(all_unique)
    uniqueness = num_unique / total_valid * 100.0 if total_valid > 0 else 0.0
    diversity = Evaluator('diversity')(all_unique) if num_unique > 1 else 0.0

    # Insertion order fixes the CSV column order downstream.
    metrics = {
        'Validity (%)': validity,
        'Uniqueness (%)': uniqueness,
        'Diversity': diversity,
    }
    for name in score_names:
        values = all_scores[name]
        metrics[name] = np.mean(values) if values else 0.0
        metrics[f'{name} Std'] = np.std(values) if values else 0.0
    metrics.update({
        'Sampling Time (s)': total_time,
        'Num Generated': total_generated,
        'Num Valid': total_valid,
        'Num Unique': num_unique,
    })
    return metrics
def main():
    """Evaluate a finetuned peptide checkpoint from the command line.

    Parses CLI arguments and samples ``--num_samples`` peptides against the
    chosen target protein. Prints a summary table, then writes a one-row CSV
    named ``eval_metrics_<tag>_<prot_name>.csv`` to ``--output_dir``. The
    default output directory is the checkpoint's directory, or the current
    directory when the checkpoint path has no directory part.
    """
    parser = argparse.ArgumentParser(description="Evaluate a finetuned peptide checkpoint")
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
    parser.add_argument('--pretrained_ckpt', type=str,
                        default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt'),
                        help='Path to the pretrained base model checkpoint')
    parser.add_argument('--num_samples', type=int, default=500,
                        help='Number of peptides to sample')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size for sampling')
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--total_num_steps', type=int, default=256)
    parser.add_argument('--num_remasking', type=int, default=3)
    parser.add_argument('--quality_threshold', type=float, default=0.5,
                        help='Threshold for insertion quality filtering during sampling')
    parser.add_argument('--unmask_quality_threshold', type=float, default=None,
                        help='If set, gate unmasking/remasking by confidence: remask '
                             'ALL clean tokens whose unmasking confidence is below this '
                             'threshold, regardless of the schedule budget. If unset '
                             '(default), remasking is purely schedule-driven (count-based).')
    # Default is None so a custom --prot_seq is not mislabelled as "glast".
    parser.add_argument('--prot_name', type=str, default=None,
                        help='Target protein name (must be one of: ' + ', '.join(PROTEINS.keys()) + '). '
                             'Defaults to glast; with --prot_seq it is only used as a '
                             'label (default: custom).')
    parser.add_argument('--prot_seq', type=str, default=None,
                        help='Custom protein sequence (overrides --prot_name)')
    parser.add_argument('--disable_planner', action='store_true',
                        help='If set, disable remasking during evaluation')
    parser.add_argument('--disable_insertion_planner', action='store_true',
                        help='If set, disable insertion quality filtering during evaluation')
    parser.add_argument('--disable_unmasking_planner', action='store_true',
                        help='If set, disable unmasking confidence planner during evaluation')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results CSV. Defaults to checkpoint directory.')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed, use_cuda=True)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Resolve the target protein before the (slow) model load so bad names fail fast.
    if args.prot_seq is not None:
        prot = args.prot_seq
        prot_name = args.prot_name or 'custom'
    else:
        prot_name = args.prot_name or 'glast'
        if prot_name not in PROTEINS:
            raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
        prot = PROTEINS[prot_name]

    # Map flags to quality_mode
    if args.disable_planner or (args.disable_insertion_planner and args.disable_unmasking_planner):
        quality_mode = "none"
    elif args.disable_insertion_planner:
        quality_mode = "unmasking_only"
    elif args.disable_unmasking_planner:
        quality_mode = "insertion_only"
    else:
        quality_mode = "both"

    print(f"Loading checkpoint: {args.checkpoint_path}")
    print(f"Pretrained base: {args.pretrained_ckpt}")
    print(f"Quality mode: {quality_mode}")
    policy_model, _train_args, _config = load_finetuned_model(
        args.checkpoint_path, args.pretrained_ckpt, device=device
    )

    # Setup tokenizer, reward model, analyzer
    tokenizer_dir = os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer')
    tokenizer = SMILES_SPE_Tokenizer(
        os.path.join(tokenizer_dir, 'new_vocab.txt'),
        os.path.join(tokenizer_dir, 'new_splits.txt'),
    )
    score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
    reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=device)
    analyzer = PeptideAnalyzer()

    print(f"\nSampling {args.num_samples} peptides (quality_mode={quality_mode}, target={prot_name})...")
    metrics = evaluate_checkpoint(
        policy_model=policy_model,
        tokenizer=tokenizer,
        reward_model=reward_model,
        analyzer=analyzer,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        max_length=args.max_length,
        total_num_steps=args.total_num_steps,
        quality_mode=quality_mode,
        num_remasking=args.num_remasking,
        quality_threshold=args.quality_threshold,
        unmask_quality_threshold=args.unmask_quality_threshold,
        device=device,
    )

    # Print summary table
    print("\n" + "=" * 60)
    print(" De Novo Peptide Generation Results")
    print("=" * 60)
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:<30s}: {v:.4f}")
        else:
            print(f" {k:<30s}: {v}")
    print("=" * 60)

    # Save results. dirname() is '' for a bare filename, and os.makedirs('')
    # raises, so fall back to the current directory.
    output_dir = args.output_dir or os.path.dirname(args.checkpoint_path) or '.'
    os.makedirs(output_dir, exist_ok=True)
    # Derive the filename tag from quality_mode so it always matches the run.
    # Disabling both planners is labelled "no_planner", like --disable_planner.
    tag = {
        'none': 'no_planner',
        'unmasking_only': 'no_insertion_planner',
        'insertion_only': 'no_unmasking_planner',
        'both': 'with_planner',
    }[quality_mode]
    if args.unmask_quality_threshold is not None:
        tag += f"_ut{args.unmask_quality_threshold:g}"
    # Record the sweep parameter in the saved row for traceability.
    metrics['unmask_quality_threshold'] = args.unmask_quality_threshold
    metrics['quality_threshold'] = args.quality_threshold
    metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}_{prot_name}.csv')
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    print(f"Metrics saved to: {metrics_path}")


if __name__ == '__main__':
    main()