import os
import json
from datetime import datetime

import tqdm
import torch
import numpy as np
import click
import lightning.pytorch as pl
import sklearn.metrics as skm
from torch.utils.data import DataLoader
from torchvision.transforms import transforms as T
from torchvision.transforms._transforms_video import ToTensorVideo
from pytorchvideo.transforms import Normalize

# Project-local imports from sibling packages (absolute imports, resolved via
# the working directory / PYTHONPATH).
from full_model.rnn_dataset import SyntaxDataset
from full_model.rnn_model import SyntaxLightningModule
from metrics_visualization import visualize_final_syntax_plotly_multi

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")

# Metric names in the order compute_metrics returns them; also used as JSON keys.
METRIC_NAMES = ("R2", "MAE", "Pearson", "MAPE", "Mean_Recall")

# Number of fold checkpoints per artery.
N_FOLDS = 5

# Checkpoint paths (backbone + RNN head) for each artery, one per fold
# (fold00 .. fold04). Paths are relative to the current working directory.
MODEL_PATHS = {
    artery: [
        f"full_model/checkpoints/{artery}BinSyntax_R3D_fold{fold:02d}_lstm_mean_post_best.pt"
        for fold in range(N_FOLDS)
    ]
    for artery in ("left", "right")
}


def safe_sample_std(values):
    """Return the sample standard deviation (ddof=1) of ``values``.

    Returns 0.0 for zero or one value, where the sample std is undefined.
    """
    arr = np.array(values, dtype=float)
    if arr.size <= 1:
        return 0.0
    return float(arr.std(ddof=1))


def compute_metrics(y_true, y_pred, thr=22.0):
    """Compute regression metrics plus a thresholded mean recall.

    Args:
        y_true: Ground-truth scores.
        y_pred: Predicted scores, same length as ``y_true``.
        thr: A score ``>= thr`` counts as the positive class (1) when
            computing ``mean_recall``.

    Returns:
        Tuple ``(r2, mae, pearson, mape, mean_recall)`` of floats, in the
        order of ``METRIC_NAMES``. ``pearson`` is 0.0 for fewer than two
        samples (and NaN if either input is constant). ``mean_recall`` is the
        average of the per-class recalls for classes 0 and 1, or 0.0 when the
        true and predicted labels together contain only one class.

    Note:
        sklearn's MAPE divides by ``max(|y_true|, eps)``, so samples whose
        true score is 0 can dominate it.
    """
    y_true_arr = np.array(y_true, dtype=float)
    y_pred_arr = np.array(y_pred, dtype=float)

    r2 = float(skm.r2_score(y_true_arr, y_pred_arr))
    mae = float(skm.mean_absolute_error(y_true_arr, y_pred_arr))
    pearson = float(np.corrcoef(y_true_arr, y_pred_arr)[0, 1]) if len(y_true_arr) > 1 else 0.0
    mape = float(skm.mean_absolute_percentage_error(y_true_arr, y_pred_arr))

    y_true_bin = (y_true_arr >= thr).astype(int)
    y_pred_bin = (y_pred_arr >= thr).astype(int)
    unique_classes = np.unique(np.concatenate([y_true_bin, y_pred_bin]))
    mean_recall = (
        float(np.mean(skm.recall_score(y_true_bin, y_pred_bin, average=None, labels=[0, 1])))
        if len(unique_classes) > 1
        else 0.0
    )

    return r2, mae, pearson, mape, mean_recall


def _load_scaling_params(dataset_root, scaling_file):
    """Load the per-fold scaling coefficients JSON; return {} (identity) if it is missing."""
    scaling_path = os.path.join(dataset_root, scaling_file)
    if not os.path.exists(scaling_path):
        print(f"⚠️ Scaling file not found: {scaling_path}")
        return {}
    with open(scaling_path, "r", encoding="utf-8") as f:
        scaling_params_dict = json.load(f)
    print(f"Loaded scaling from {scaling_path}")
    return scaling_params_dict


def _check_prediction_alignment(postfix, syntax_true, left_preds_all, right_preds_all):
    """Raise ValueError unless labels and left/right predictions line up per sample and per fold."""
    if not len(syntax_true) == len(left_preds_all) == len(right_preds_all):
        raise ValueError(
            f"[{postfix}] Sample count mismatch: {len(syntax_true)} labels, "
            f"{len(left_preds_all)} left / {len(right_preds_all)} right predictions."
        )
    for idx, (l_list, r_list) in enumerate(zip(left_preds_all, right_preds_all)):
        # run_artery skips missing checkpoints, so the two arteries can end up
        # with different fold counts; pairing them by position would be wrong.
        if len(l_list) != len(r_list):
            raise ValueError(
                f"[{postfix}] Sample {idx}: {len(l_list)} left vs {len(r_list)} right "
                "fold predictions (is a checkpoint missing for one artery?)."
            )


def _load_or_compute_predictions(abs_dataset_path, results_file, postfix, video_size,
                                 frames_per_clip, num_workers, pt_weights_format):
    """Return (syntax_true, left_preds, right_preds) from the cache file, or run inference and cache it."""
    if os.path.exists(results_file):
        print(f"[{postfix}] Loading from {results_file}")
        with open(results_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        syntax_true = data["syntax_true"]
        left_preds_all = data["left_preds"]
        right_preds_all = data["right_preds"]
        _check_prediction_alignment(postfix, syntax_true, left_preds_all, right_preds_all)
        return syntax_true, left_preds_all, right_preds_all

    # These options are only needed when there is no cache to fall back on.
    if video_size is None or frames_per_clip is None:
        raise click.UsageError(
            "--video-size and --frames-per-clip are required to compute predictions "
            f"(no cached file {results_file})."
        )

    print(f"[{postfix}] Computing predictions...")
    left_preds_all, left_sids = run_artery(
        abs_dataset_path, "left", MODEL_PATHS["left"],
        video_size, frames_per_clip, num_workers, pt_weights_format,
    )
    right_preds_all, right_sids = run_artery(
        abs_dataset_path, "right", MODEL_PATHS["right"],
        video_size, frames_per_clip, num_workers, pt_weights_format,
    )
    if left_sids != right_sids:
        raise ValueError(f"[{postfix}] Left and right artery sample ids differ.")

    with open(abs_dataset_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)
    syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]

    # Validate before caching so a bad result is not persisted.
    _check_prediction_alignment(postfix, syntax_true, left_preds_all, right_preds_all)

    os.makedirs(os.path.dirname(results_file), exist_ok=True)
    save_data = {
        "syntax_true": syntax_true,
        "left_preds": left_preds_all,
        "right_preds": right_preds_all,
    }
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(save_data, f)
    print(f"[{postfix}] Saved to {results_file}")
    return syntax_true, left_preds_all, right_preds_all


def _apply_fold_scaling(preds_all, scaling_params):
    """Apply ``a * x + b`` to every prediction, with (a, b) chosen per fold position.

    ``scaling_params`` maps ``"fold{i}"`` to an ``(a, b)`` pair; folds without
    an entry use the identity ``(1.0, 0.0)``.
    """
    scaled_all = []
    for pred_list in preds_all:
        scaled = []
        # NOTE: i is the position in the per-sample list (order of the loaded
        # models); it matches the checkpoint's fold number only if no
        # checkpoint was skipped in run_artery.
        for i, val in enumerate(pred_list):
            coeffs = scaling_params.get(f"fold{i}", (1.0, 0.0))
            scaled.append(coeffs[0] * val + coeffs[1])
        scaled_all.append(scaled)
    return scaled_all


def _ensemble_predictions(left_scaled_all, right_scaled_all):
    """Per sample: mean over folds of (left + right) prediction, clipped below at 0."""
    return [
        max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
        for l_list, r_list in zip(left_scaled_all, right_scaled_all)
    ]


def _fold_summary(syntax_true, left_scaled_all, right_scaled_all):
    """Score each fold on its own (left + right, clipped at 0) and summarize as mean/std/values."""
    n_folds = len(left_scaled_all[0]) if left_scaled_all else 0
    fold_metrics = {name: [] for name in METRIC_NAMES}
    for k in range(n_folds):
        pred_k = [
            max(0.0, l_list[k] + r_list[k])
            for l_list, r_list in zip(left_scaled_all, right_scaled_all)
        ]
        for name, value in zip(METRIC_NAMES, compute_metrics(syntax_true, pred_k)):
            fold_metrics[name].append(value)
    return {
        name: {"mean": float(np.mean(v)), "std": safe_sample_std(v), "values": v}
        for name, v in fold_metrics.items()
    }


def _dataset_record(ensemble_metrics, n_samples, fold_summary):
    """Build the per-dataset metrics dict: ensemble metrics, sample count, then fold mean/std/values."""
    record = {name: round(value, 4) for name, value in zip(METRIC_NAMES, ensemble_metrics)}
    record["N_samples"] = n_samples
    record.update({f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()})
    record.update({f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()})
    record.update({f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()})
    return record


def _save_metrics(metrics_path, ensemble_name, ensemble_results):
    """Merge ``ensemble_results`` under ``ensemble_name`` into the metrics JSON history file."""
    full_history = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r", encoding="utf-8") as f:
                full_history = json.load(f)
        except json.JSONDecodeError:
            print("⚠️ Metrics file corrupted. Creating new.")

    full_history[ensemble_name] = ensemble_results
    os.makedirs(os.path.dirname(metrics_path) or ".", exist_ok=True)
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(full_history, f, indent=4)
    print(f"✅ Metrics saved: {metrics_path}")


@click.command()
@click.option("-d", "--dataset-paths", multiple=True, help="JSON с метаданными датасетов (относительно dataset_root).")
@click.option("-n", "--dataset-names", multiple=True, help="Имена датасетов для метрик/графиков.")
@click.option("-p", "--postfixes", multiple=True, help="Суффиксы для файлов предсказаний.")
@click.option("-r", "--dataset-root", type=click.Path(exists=True), required=True, help="Корень датасета (где лежат JSON и DICOM).")
@click.option("-v", "--video-size", type=click.Tuple([int, int]), help="Размер видео (H, W).")
@click.option("--frames-per-clip", type=int, help="Количество кадров в клипе.")
@click.option("--num-workers", type=int, default=0, help="Число DataLoader workers.")
@click.option("--seed", type=int, help="Random seed.")
@click.option("--pt-weights-format", is_flag=True, help="True → модели в .pt (torch.save), False → .ckpt (Lightning).")
@click.option("--use-scaling", is_flag=True, help="Применить a*x+b scaling из JSON.")
@click.option("--scaling-file", help="JSON с коэффициентами scaling (относительно dataset_root).")
@click.option("-e", "--ensemble-name", help="Имя ансамбля в metrics.json.")
@click.option("-m", "--metrics-file", required=True, help="JSON с метриками экспериментов.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, ensemble_name, metrics_file):
    """Evaluate the left+right fold ensemble on each dataset, save metrics and plot them.

    For every (dataset path, name, postfix) triple, predictions are loaded from
    ``<dataset_root>/coeffs/<postfix>.json`` or computed with run_artery and
    cached there. Optional per-fold scaling is applied, then ensemble and
    per-fold metrics are computed. Results are merged into
    ``<dataset_root>/<metrics_file>`` under ``ensemble_name`` (suffixed with
    ``_scaled`` when scaling is used).
    """
    if not len(dataset_paths) == len(dataset_names) == len(postfixes):
        raise click.UsageError(
            "--dataset-paths, --dataset-names and --postfixes must be given the same number "
            f"of times (got {len(dataset_paths)}, {len(dataset_names)}, {len(postfixes)})."
        )
    if use_scaling and (scaling_file is None or ensemble_name is None):
        raise click.UsageError("--use-scaling requires --scaling-file and --ensemble-name.")

    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"

    # Scaling parameters (an empty dict means identity scaling).
    scaling_params_dict = {}
    if use_scaling:
        postfix_plotly += "_scaled"
        ensemble_name += "_scaled"
        scaling_params_dict = _load_scaling_params(dataset_root, scaling_file)

    # Ensemble results for this run.
    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "datasets": {},
    }
    all_datasets, all_r2, all_recalls = {}, {}, {}

    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        results_file = os.path.join(dataset_root, "coeffs", f"{postfix}.json")

        syntax_true, left_preds_all, right_preds_all = _load_or_compute_predictions(
            abs_dataset_path, results_file, postfix, video_size,
            frames_per_clip, num_workers, pt_weights_format,
        )

        # Fold-wise scaling, applied separately to the left and right predictions.
        if use_scaling:
            left_scaled_all = _apply_fold_scaling(left_preds_all, scaling_params_dict)
            right_scaled_all = _apply_fold_scaling(right_preds_all, scaling_params_dict)
        else:
            left_scaled_all, right_scaled_all = left_preds_all, right_preds_all

        # Ensemble: mean over folds of left + right.
        syntax_pred = _ensemble_predictions(left_scaled_all, right_scaled_all)

        ensemble_metrics = compute_metrics(syntax_true, syntax_pred)
        r2, mae, pearson, mape, mean_recall = ensemble_metrics
        print(f"[{postfix}] ENSEMBLE: R2={r2:.4f}, Pearson={pearson:.4f}, "
              f"MAE={mae:.4f}, MAPE={mape:.4f}, Recall={mean_recall:.4f}")

        # Spread of the individual folds' metrics.
        fold_summary = _fold_summary(syntax_true, left_scaled_all, right_scaled_all)

        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_r2[dataset_name] = r2
        all_recalls[dataset_name] = mean_recall
        ensemble_results["datasets"][dataset_name] = _dataset_record(
            ensemble_metrics, len(syntax_true), fold_summary
        )

    _save_metrics(os.path.join(dataset_root, metrics_file), ensemble_name, ensemble_results)

    visualize_final_syntax_plotly_multi(
        datasets=all_datasets,
        r2_values=all_r2,
        recall_values=all_recalls,
        gt_row="ENSEMBLE",
        postfix=postfix_plotly,
    )


def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, pt_weights_format=False):
    """Run inference for one artery with every available fold checkpoint.

    Args:
        dataset_path: Path to the dataset metadata JSON; its directory is used
            as the dataset root.
        artery: Artery name passed to ``SyntaxDataset`` ("left"/"right" here).
        model_paths: Checkpoint paths, one per fold. Missing files are skipped
            with a warning.
        video_size: ``(H, W)`` the frames are resized to.
        frames_per_clip: Clip length, passed to ``SyntaxDataset`` as ``length``.
        num_workers: Number of DataLoader workers.
        pt_weights_format: True for ``torch.save`` .pt checkpoints, False for
            Lightning .ckpt checkpoints.

    Returns:
        ``(preds_all, sids)``. ``preds_all[i]`` holds one prediction per loaded
        model for sample i (all 0.0 for an empty video). ``sids[i]`` is that
        sample's id (presumably the study UID).

    Raises:
        RuntimeError: If none of ``model_paths`` exists.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",  # inference mode
        artery=artery,
        inference=True,
        transform=test_transform,
    )
    # Pinned memory only helps host-to-GPU copies; on CPU it just warns.
    val_loader = DataLoader(val_set, batch_size=1, num_workers=num_workers, shuffle=False,
                            pin_memory=DEVICE.type == "cuda")
    print(f"{artery} artery: {len(val_loader)} samples")

    models = []
    for path in model_paths:
        if not os.path.exists(path):
            print(f"⚠️ Model not found: {path}")
            continue
        model = SyntaxLightningModule(
            num_classes=2, lr=1e-5, variant="lstm_mean", weight_decay=0.001,
            max_epochs=1, pl_weight_path=path, pt_weights_format=pt_weights_format,
        )
        model.to(DEVICE)
        model.eval()
        models.append(model)

    if not models:
        raise RuntimeError(f"No models loaded for {artery}")

    preds_all, sids = [], []
    with torch.no_grad():
        for x, [_y], [_t], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            if len(x.shape) == 1:
                # Empty video: every model contributes a 0.0 prediction.
                val_syntax_list = [0.0] * len(models)
            else:
                x = x.to(DEVICE)
                val_syntax_list = []
                for model in models:
                    _, val_log = model(x)
                    # Regression output; presumably log(1 + SYNTAX), so exp(.) - 1
                    # maps it back to the score scale (TODO confirm with the model).
                    val_syntax_list.append(float(torch.exp(val_log).cpu()) - 1)
            preds_all.append(val_syntax_list)
            sids.append(sid[0])  # study_uid

    return preds_all, sids


if __name__ == "__main__":
    main()