| | |
| | |
| |
|
| | import argparse |
| | import os |
| | import numpy as np |
| | import pandas as pd |
| | import torch |
| | import tabm |
| | from sklearn.metrics import precision_recall_curve, auc |
| |
|
def normalize_rt(s: pd.Series) -> pd.Series:
    """Canonicalize response-type labels for comparison.

    Every value is cast to text, stripped of surrounding whitespace and
    upper-cased, so comparisons against labels such as 'CD8' or 'NEGATIVE'
    ignore case and padding. The index of ``s`` is preserved.
    """
    as_text = s.astype(str)
    trimmed = as_text.str.strip()
    return trimmed.str.upper()
| |
|
def compute_patient_metrics(df_p: pd.DataFrame, y_prob: np.ndarray) -> tuple:
    """Rank one patient's peptides by predicted probability and score the ranking.

    Args:
        df_p: rows of a single patient; must contain the columns
            'response_type', 'mutant_seq' and 'gene'.
        y_prob: predicted probability for each row of ``df_p``, in row order.

    Returns:
        A 14-tuple, in this order:
          0. predictions sorted in descending order (numpy array),
          1. the sorted copy of ``df_p`` with added 'ML_pred' and 'response'
             (1 if response_type normalizes to 'CD8', else 0) columns and a
             fresh 0..n-1 index, so a row's label equals its 0-based rank,
          2-7. (nr_correct, nr_tested) for the top 20, 50 and 100, where
             nr_correct counts CD8 rows and nr_tested counts CD8 plus
             'NEGATIVE' rows within the top k,
          8. number of CD8 rows,
          9. their 0-based ranks (ascending numpy array),
          10. rank score sum(exp(-0.005 * rank)) over the CD8 rows,
          11-13. comma-joined 1-based ranks, mutant sequences and genes of
             the CD8 rows, in rank order (empty strings when there are none).
    """
    X_r = df_p.copy()
    X_r["ML_pred"] = y_prob
    X_r["response"] = (normalize_rt(X_r["response_type"]) == "CD8").astype(int)

    # Highest probability first; after reset_index a row's label equals its 0-based rank.
    X_r = X_r.sort_values(by=["ML_pred"], ascending=False).reset_index(drop=True)
    n_rows = len(X_r)

    # Ascending ranks of the CD8-positive rows and of the tested-negative rows.
    idx_pos = np.flatnonzero(X_r["response"].to_numpy() == 1)
    idx_tested = np.flatnonzero((normalize_rt(X_r["response_type"]) == "NEGATIVE").to_numpy())

    def topk_counts(k: int) -> tuple[int, int]:
        """Return (#CD8 in the top k, #CD8 + #NEGATIVE in the top k)."""
        k_eff = min(k, n_rows)
        nr_correct = int(np.count_nonzero(idx_pos < k_eff))
        nr_negative = int(np.count_nonzero(idx_tested < k_eff))
        return nr_correct, nr_correct + nr_negative

    nr_correct20, nr_tested20 = topk_counts(20)
    nr_correct50, nr_tested50 = topk_counts(50)
    nr_correct100, nr_tested100 = topk_counts(100)

    nr_immuno = int(idx_pos.size)

    # Rank-weighted score: each CD8 peptide contributes exp(-alpha * rank).
    alpha = 0.005
    score = float(np.sum(np.exp(-alpha * idx_pos)))

    if nr_immuno > 0:
        # idx_pos is already ascending and equals the row labels, so the
        # selected rows come out in rank order without any re-sorting.
        ranks_str = ",".join(str(int(r) + 1) for r in idx_pos)
        mut_seqs_str = ",".join(str(s) for s in X_r.loc[idx_pos, "mutant_seq"].to_numpy())
        genes_str = ",".join(str(g) for g in X_r.loc[idx_pos, "gene"].to_numpy())
    else:
        ranks_str = mut_seqs_str = genes_str = ""

    return (X_r["ML_pred"].to_numpy(), X_r,
            nr_correct20, nr_tested20,
            nr_correct50, nr_tested50,
            nr_correct100, nr_tested100,
            nr_immuno, idx_pos, score,
            ranks_str, mut_seqs_str, genes_str)
| |
|
| |
|
def predict_in_batches(model, X_all, device, batch_size=1024):
    """Run ``model`` over ``X_all`` in mini-batches and return class-1 probabilities.

    The model output is averaged over dim 1 (for TabM, presumably the k
    ensemble members) and then soft-maxed over dim 1 (the two classes).

    Args:
        model: torch module returning logits of shape (batch, m, 2).
        X_all: 2-D feature matrix; a torch tensor is used as-is, anything
            else is converted to a float32 tensor. Rows stay on their
            current device and are moved to ``device`` one batch at a time.
        device: torch device (or device string) to run inference on.
        batch_size: rows per forward pass; must be positive.

    Returns:
        1-D numpy array with P(class 1) for every row of ``X_all``; empty
        when ``X_all`` has no rows.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if not torch.is_tensor(X_all):
        X_all = torch.as_tensor(X_all, dtype=torch.float32)
    device = torch.device(device)
    use_cuda = device.type == "cuda"
    n_rows = len(X_all)

    model.eval()
    chunks = []
    with torch.inference_mode():
        for start in range(0, n_rows, batch_size):
            batch_X = X_all[start:start + batch_size].to(device)
            logits = model(batch_X).mean(1)
            chunks.append(torch.softmax(logits, dim=1)[:, 1].cpu())
            del batch_X, logits
            # Free cached GPU blocks between batches to keep peak memory low.
            if use_cuda:
                torch.cuda.empty_cache()

    if not chunks:
        # torch.cat refuses an empty list; an empty input yields an empty result.
        return np.empty(0, dtype=np.float32)
    return torch.cat(chunks, dim=0).numpy()
| |
|
def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface."""
    ap = argparse.ArgumentParser(description="TabM model evaluation, output format consistent with TestVotingClassifier")
    ap.add_argument("--model_file", type=str, required=False, help="TabM model file, e.g. tabm_results/tabm_model.pth (mutually exclusive with --model_files/--model_glob, choose one of three)")
    ap.add_argument("--model_files", type=str, nargs='*', default=None, help="Multiple model files for equal-weighted average prediction")
    ap.add_argument("--model_glob", type=str, default=None, help="Use wildcards to match multiple model files (e.g. tabm_results/tabm_hyperopt_best_rep*.pth)")
    ap.add_argument("--data_file", type=str, required=True, help="Input TSV: TestVoting_selection_neopep.tsv")
    ap.add_argument("--output_file", type=str, required=True, help="Main result output file (header consistent with original)")
    ap.add_argument("--tesla_file", type=str, default=None, help="TESLA score output file (for neopep task)")
    ap.add_argument("--output_xlsx", type=str, default=None, help="Main result Excel output path (optional)")
    ap.add_argument("--tesla_xlsx", type=str, default=None, help="TESLA result Excel output path (optional)")
    ap.add_argument("--dataset_name", type=str, default=None, help="If no dataset column exists, use this value as Dataset column in TESLA")
    ap.add_argument("--skip_no_cd8", action="store_true", help="Skip patients without CD8")
    ap.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"],
                    help="Device selection: auto/cuda/cpu")
    ap.add_argument("--batch_size", type=int, default=1024,
                    help="Batch size to avoid GPU memory overflow (default 1024)")
    args = ap.parse_args()
    # The batch size exists to avoid GPU OOM, so it is honored exactly
    # (no silent minimum); only non-positive values are rejected.
    if args.batch_size <= 0:
        ap.error("--batch_size must be a positive integer")
    return args


def _select_device(choice: str) -> torch.device:
    """Resolve the --device option ('auto', 'cuda' or 'cpu') and report the choice.

    Raises:
        RuntimeError: if 'cuda' was requested but no GPU is available.
    """
    if choice == "auto":
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
            print(f"🚀 Auto-selected GPU: {torch.cuda.get_device_name(0)}")
            print(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        else:
            device = torch.device('cpu')
            print("⚠️ No GPU detected, using CPU")
    elif choice == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA specified but no GPU detected")
        device = torch.device('cuda:0')
        print(f"🚀 Force using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device('cpu')
        print("Using CPU")
    return device


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain ``path`` if it does not exist yet."""
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)


def _resolve_model_paths(args: argparse.Namespace) -> list[str]:
    """Collect checkpoint paths from --model_files and --model_glob, else --model_file.

    Raises:
        FileNotFoundError: if no path was given/matched, or a listed path is missing.
    """
    import glob

    model_paths: list[str] = []
    if args.model_files:
        model_paths.extend(args.model_files)
    if args.model_glob:
        model_paths.extend(sorted(glob.glob(args.model_glob)))
    if model_paths and args.model_file:
        print("⚠️ --model_file is ignored because --model_files/--model_glob were given")
    if not model_paths and args.model_file:
        model_paths = [args.model_file]
    if not model_paths:
        raise FileNotFoundError("No model files found, please check!")
    # Fail before any (slow) prediction work if some later checkpoint is missing.
    for path in model_paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Model file not existed: {path}")
    return model_paths


def _load_checkpoint(path: str) -> dict:
    """Load a checkpoint onto the CPU.

    SECURITY: weights_only=False unpickles arbitrary Python objects (needed
    here because ckpt['args'] is read via attribute access, presumably a
    pickled Namespace). Only load checkpoints from trusted sources.
    """
    return torch.load(path, map_location='cpu', weights_only=False)


def _stringify(v) -> str:
    """repr() a value for logging, falling back to str() and then a placeholder."""
    try:
        return repr(v)
    except Exception:  # arbitrary __repr__ may raise anything; logging must not abort
        try:
            return str(v)
        except Exception:
            return "<unprintable>"


def _print_dict(title: str, d: dict) -> None:
    """Print a titled, key-sorted listing of ``d``."""
    print(title)
    for key in sorted(d, key=str):
        print(f"- {key}: {_stringify(d[key])}")
    print("=" * len(title))


def _print_checkpoint_info(ckpt: dict) -> None:
    """Print saved hyperparameters and optional metadata sections of a checkpoint."""
    model_args = ckpt['args']
    print("===== Saved Hyperparameters from checkpoint['args'] =====")
    if isinstance(model_args, dict):
        hp_items = sorted(model_args.items(), key=lambda kv: str(kv[0]))
    elif hasattr(model_args, "__dict__"):
        hp_items = sorted(vars(model_args).items(), key=lambda kv: str(kv[0]))
    else:
        hp_items = []
        print("⚠️ Unable to enumerate contents of model_args")
    for key, val in hp_items:
        print(f"- {key}: {_stringify(val)}")
    print("=========================================================")

    for section in ("training_args", "best_params", "full_args"):
        if isinstance(ckpt.get(section), dict):
            _print_dict(f"===== checkpoint['{section}'] =====", ckpt[section])

    ufi = ckpt.get("used_feature_idx")
    if ufi is not None:
        try:
            lines = ["===== used_feature_idx =====",
                     f"- length: {len(ufi)}",
                     f"- head: {list(ufi[:10])}",
                     "=" * 25]
        except (TypeError, KeyError, IndexError):
            print("===== used_feature_idx =====\n<unprintable>\n============================")
        else:
            print("\n".join(lines))


def _print_environment() -> None:
    """Print library and device versions (diagnostics only; failures are reported, not raised)."""
    try:
        lines = ["===== Environment =====",
                 f"- torch: {torch.__version__}",
                 f"- cuda available: {torch.cuda.is_available()}"]
        if torch.cuda.is_available():
            lines.append(f"- device: {torch.cuda.get_device_name(0)}")
            lines.append(f"- cuda version: {torch.version.cuda}")
        lines.append(f"- tabm: {getattr(tabm, '__version__', 'unknown')}")
        lines.append("========================")
    except Exception as err:  # diagnostics must never abort an evaluation run
        print(f"⚠️ Could not report environment: {err}")
        return
    print("\n".join(lines))


def _select_features(ckpt: dict, X_all_np: np.ndarray) -> np.ndarray:
    """Return the feature columns recorded in ckpt['used_feature_idx'] (all columns if absent).

    Indices outside the current feature range are dropped; if the index list
    cannot be parsed or selects nothing, all columns are used. Each fallback
    prints a warning instead of failing silently.
    """
    ufi = ckpt.get("used_feature_idx")
    if ufi is None:
        return X_all_np
    try:
        ufi_arr = np.asarray(ufi, dtype=int).ravel()
    except (TypeError, ValueError, OverflowError) as err:
        print(f"⚠️ Could not parse used_feature_idx ({err}); using all {X_all_np.shape[1]} features")
        return X_all_np
    in_range = (ufi_arr >= 0) & (ufi_arr < X_all_np.shape[1])
    if not in_range.all():
        print(f"⚠️ Dropping {int((~in_range).sum())} out-of-range entries from used_feature_idx")
    ufi_arr = ufi_arr[in_range]
    if len(ufi_arr) == 0:
        print("⚠️ used_feature_idx selects no valid feature; using all features")
        return X_all_np
    return X_all_np[:, ufi_arr]


def _build_model(m_args, X_tensor_cpu: torch.Tensor):
    """Instantiate an (untrained) TabM model matching the saved hyperparameters."""
    n_features = X_tensor_cpu.shape[1]
    num_embeddings = None
    if getattr(m_args, 'use_embeddings', False):
        if m_args.embedding_type == 'linear':
            import rtdl_num_embeddings
            num_embeddings = rtdl_num_embeddings.LinearReLUEmbeddings(n_features)
        elif m_args.embedding_type == 'periodic':
            import rtdl_num_embeddings
            num_embeddings = rtdl_num_embeddings.PeriodicEmbeddings(n_features, lite=False)
        elif m_args.embedding_type == 'piecewise':
            import rtdl_num_embeddings
            # NOTE(review): bins are computed from the evaluation data here;
            # presumably load_state_dict restores the training bins afterwards —
            # confirm the bin counts (n_bins=48) always match the checkpoint.
            num_embeddings = rtdl_num_embeddings.PiecewiseLinearEmbeddings(
                rtdl_num_embeddings.compute_bins(X_tensor_cpu, n_bins=48),
                d_embedding=16,
                activation=False,
                version='B',
            )
    return tabm.TabM.make(
        n_num_features=n_features,
        cat_cardinalities=[],
        d_out=2,
        k=m_args.k,
        n_blocks=m_args.n_blocks,
        d_block=m_args.d_block,
        num_embeddings=num_embeddings,
        arch_type=getattr(m_args, 'arch_type', 'tabm'),
    )


def _predict_with_checkpoint(ckpt: dict, X_all_np: np.ndarray, device: torch.device,
                             batch_size: int) -> np.ndarray:
    """Rebuild the model stored in ``ckpt`` and return P(class 1) for every row."""
    X_tensor_cpu = torch.as_tensor(_select_features(ckpt, X_all_np), dtype=torch.float32)
    model = _build_model(ckpt['args'], X_tensor_cpu)
    model.load_state_dict(ckpt['model_state_dict'])
    model.to(device)
    probs = predict_in_batches(model, X_tensor_cpu, device, batch_size=batch_size)
    # Release this model's GPU memory before the next checkpoint is loaded.
    del model
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    return probs


def _tesla_metrics(X_sorted: pd.DataFrame, y_pred_sorted: np.ndarray,
                   nr_correct20: int, nr_tested20: int,
                   nr_correct100: int, nr_immuno: int) -> tuple:
    """Return (TTIF, FR, AUPRC) for one patient.

    TTIF is the CD8 fraction among tested peptides in the top 20, FR the
    fraction of CD8 peptides ranked in the top 100, and AUPRC is computed over
    rows whose response_type is not 'not_tested' (case/whitespace-insensitive,
    like every other response_type check). Each metric is 0.0 when undefined.
    """
    ttif = (nr_correct20 / nr_tested20) if nr_tested20 > 0 else 0.0
    fr = (nr_correct100 / nr_immuno) if nr_immuno > 0 else 0.0
    tested = (normalize_rt(X_sorted['response_type']) != 'NOT_TESTED').to_numpy()
    y_pred_tesla = np.asarray(y_pred_sorted)[tested]
    y_tesla = X_sorted['response'].to_numpy()[tested]
    if y_tesla.size > 0 and y_tesla.any():
        precision, recall, _ = precision_recall_curve(y_tesla, y_pred_tesla)
        auprc = float(auc(recall, precision))
    else:
        # No positives (or nothing tested): the PR curve is undefined, and
        # sklearn's result is version-dependent (0.5 or NaN) or an error.
        auprc = 0.0
    return ttif, fr, auprc


def main():
    """Evaluate TabM checkpoint(s) on a TestVoting TSV and write per-patient results.

    Predictions of all selected checkpoints are averaged with equal weight.
    One row per patient is appended to --output_file (the header is written
    only when the file is new or empty); TESLA metrics and Excel copies are
    written when the corresponding options are given.
    """
    args = _parse_args()
    device = _select_device(args.device)
    print(f" Batch size: {args.batch_size}")

    df = pd.read_csv(args.data_file, sep="\t", header=0, low_memory=False)
    # Patient rows are mapped to predictions by index label below, so make the
    # labels equal to row positions explicitly.
    df = df.reset_index(drop=True)
    print(f"📈 Data shape: {df.shape}")

    required_cols = ["patient", "response_type", "gene", "mutant_seq"]
    for c in required_cols:
        if c not in df.columns:
            raise KeyError(f"Missing required column: {c}")

    # Every other column is a numeric feature; unparsable values become 0.0.
    # NOTE(review): an optional 'dataset' column also lands here — confirm this
    # matches the feature layout used during training.
    feature_cols = [c for c in df.columns if c not in required_cols]
    X_all = df[feature_cols].apply(pd.to_numeric, errors='coerce').fillna(0.0).to_numpy()
    print(f" Number of features: {X_all.shape[1]}")

    model_paths = _resolve_model_paths(args)
    first_ckpt = _load_checkpoint(model_paths[0])
    _print_checkpoint_info(first_ckpt)
    _print_environment()

    n_models = len(model_paths)
    print(f"🔗 Loading {n_models} models for equal-weighted average prediction...")
    y_prob_sum = np.zeros(len(X_all), dtype=np.float64)
    for i, mp in enumerate(model_paths):
        print(f" -> {mp}")
        # Reuse the already-loaded first checkpoint instead of reading it twice.
        ckpt = first_ckpt if i == 0 else _load_checkpoint(mp)
        y_prob_sum += _predict_with_checkpoint(ckpt, X_all, device, args.batch_size)
    y_prob_all = y_prob_sum / float(n_models)
    print(f"✅ Prediction completed, total {len(y_prob_all)} samples; number of models={n_models}")

    want_tesla = bool(args.tesla_file or args.tesla_xlsx)
    rows_main = []
    rows_tesla = []

    _ensure_parent_dir(args.output_file)
    if args.tesla_file:
        _ensure_parent_dir(args.tesla_file)

    need_header = (not os.path.exists(args.output_file)) or (os.path.getsize(args.output_file) == 0)
    with open(args.output_file, "a", encoding="utf-8") as f:
        if need_header:
            f.write("Patient\tNr_correct_top20\tNr_tested_top20\tNr_correct_top50\tNr_tested_top50\t"
                    "Nr_correct_top100\tNr_tested_top100\tNr_immunogenic\tNr_peptides\tClf_score\t"
                    "CD8_ranks\tCD8_mut_seqs\tCD8_genes\n")

        for patient, df_p in df.groupby("patient", sort=False):
            has_cd8 = (normalize_rt(df_p["response_type"]) == "CD8").any()
            if args.skip_no_cd8 and not has_cd8:
                continue

            y_prob = y_prob_all[df_p.index.to_numpy()]

            (y_pred_sorted, X_sorted,
             nr_correct20, nr_tested20,
             nr_correct50, nr_tested50,
             nr_correct100, nr_tested100,
             nr_immuno, _ranks, score,
             ranks_str, mut_seqs_str, genes_str) = compute_patient_metrics(df_p, y_prob)

            f.write(f"{patient}\t{nr_correct20}\t{nr_tested20}\t{nr_correct50}\t{nr_tested50}\t"
                    f"{nr_correct100}\t{nr_tested100}\t{nr_immuno}\t{len(df_p)}\t{score:.6f}\t"
                    f"{ranks_str}\t{mut_seqs_str}\t{genes_str}\n")

            rows_main.append({
                "Patient": patient,
                "Nr_correct_top20": nr_correct20,
                "Nr_tested_top20": nr_tested20,
                "Nr_correct_top50": nr_correct50,
                "Nr_tested_top50": nr_tested50,
                "Nr_correct_top100": nr_correct100,
                "Nr_tested_top100": nr_tested100,
                "Nr_immunogenic": nr_immuno,
                "Nr_peptides": len(df_p),
                "Clf_score": score,
                "CD8_ranks": ranks_str,
                "CD8_mut_seqs": mut_seqs_str,
                "CD8_genes": genes_str,
            })

            if not want_tesla:
                continue

            if "dataset" in df_p.columns:
                dataset_val = str(df_p["dataset"].iloc[0])
            else:
                dataset_val = args.dataset_name if args.dataset_name is not None else ""
            ttif, fr, auprc = _tesla_metrics(X_sorted, y_pred_sorted,
                                             nr_correct20, nr_tested20,
                                             nr_correct100, nr_immuno)

            if args.tesla_file:
                new_tesla = (not os.path.exists(args.tesla_file)) or (os.path.getsize(args.tesla_file) == 0)
                with open(args.tesla_file, "a", encoding="utf-8") as tf:
                    if new_tesla:
                        tf.write("Dataset\tPatient\tTTIF\tFR\tAUPRC\n")
                    tf.write(f"{dataset_val}\t{patient}\t{ttif:.3f}\t{fr:.3f}\t{auprc:.3f}\n")

            rows_tesla.append({
                "Dataset": dataset_val,
                "Patient": patient,
                "TTIF": ttif,
                "FR": fr,
                "AUPRC": auprc,
            })

    if args.output_xlsx and rows_main:
        _ensure_parent_dir(args.output_xlsx)
        pd.DataFrame(rows_main).to_excel(args.output_xlsx, index=False)
    if args.tesla_xlsx and rows_tesla:
        _ensure_parent_dir(args.tesla_xlsx)
        pd.DataFrame(rows_tesla).to_excel(args.tesla_xlsx, index=False)

    print(f" Evaluation completed! Processed {len(rows_main)} patients")


if __name__ == "__main__":
    main()