| | import os |
| | import json |
| | import tqdm |
| | import torch |
| | import numpy as np |
| | import click |
| | from datetime import datetime |
| | import lightning.pytorch as pl |
| | import sklearn.metrics as skm |
| |
|
| | from torch.utils.data import DataLoader |
| | from torchvision.transforms import transforms as T |
| | from torchvision.transforms._transforms_video import ToTensorVideo |
| | from pytorchvideo.transforms import Normalize |
| |
|
| | |
| | from full_model.rnn_dataset import SyntaxDataset |
| | from full_model.rnn_model import SyntaxLightningModule |
| | from metrics_visualization import visualize_final_syntax_plotly_multi |
| |
|
# Pin all inference to GPU 0 when CUDA is available, else fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")
| |
|
| |
|
def safe_sample_std(values):
    """Return the sample standard deviation (ddof=1) of *values*.

    Empty or single-element input yields 0.0 instead of NaN/raising.
    """
    data = np.asarray(values, dtype=float)
    return float(data.std(ddof=1)) if data.size > 1 else 0.0
| |
|
| |
|
def compute_metrics(y_true, y_pred, thr=22.0):
    """Compute R2, MAE, Pearson correlation, MAPE and mean per-class recall.

    Recall is evaluated on a binary split at ``thr`` (score >= thr is the
    positive class). The mean recall falls back to 0.0 when only one class
    occurs in both truth and prediction; Pearson falls back to 0.0 for
    fewer than two samples.

    Returns a ``(r2, mae, pearson, mape, mean_recall)`` tuple of floats.
    """
    gt = np.asarray(y_true, dtype=float)
    pr = np.asarray(y_pred, dtype=float)

    r2 = float(skm.r2_score(gt, pr))
    mae = float(skm.mean_absolute_error(gt, pr))
    mape = float(skm.mean_absolute_percentage_error(gt, pr))

    if len(gt) > 1:
        pearson = float(np.corrcoef(gt, pr)[0, 1])
    else:
        pearson = 0.0

    # Binarize at the threshold for the classification-style recall metric.
    gt_bin = (gt >= thr).astype(int)
    pr_bin = (pr >= thr).astype(int)
    observed = np.unique(np.concatenate([gt_bin, pr_bin]))
    if len(observed) > 1:
        per_class = skm.recall_score(gt_bin, pr_bin, average=None, labels=[0, 1])
        mean_recall = float(np.mean(per_class))
    else:
        mean_recall = 0.0

    return r2, mae, pearson, mape, mean_recall
| |
|
| |
|
@click.command()
@click.option("-d", "--dataset-paths", multiple=True,
              help="JSON с метаданными датасетов (относительно dataset_root).")
@click.option("-n", "--dataset-names", multiple=True,
              help="Имена датасетов для метрик/графиков.")
@click.option("-p", "--postfixes", multiple=True,
              help="Суффиксы для файлов предсказаний.")
@click.option("-r", "--dataset-root", type=click.Path(exists=True),
              help="Корень датасета (где лежат JSON и DICOM).")
@click.option("-v", "--video-size", type=click.Tuple([int, int]),
              help="Размер видео (H, W).")
@click.option("--frames-per-clip", type=int,
              help="Количество кадров в клипе.")
@click.option("--num-workers", type=int,
              help="Число DataLoader workers.")
@click.option("--seed", type=int,
              help="Random seed.")
@click.option("--pt-weights-format", is_flag=True,
              help="True → модели в .pt (torch.save), False → .ckpt (Lightning).")
@click.option("--use-scaling", is_flag=True,
              help="Применить a*x+b scaling из JSON.")
@click.option("--scaling-file",
              help="JSON с коэффициентами scaling (относительно dataset_root).")
@click.option("-e", "--ensemble-name",
              help="Имя ансамбля в metrics.json.")
@click.option("-m", "--metrics-file",
              help="JSON с метриками экспериментов.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, ensemble_name, metrics_file):
    """Evaluate a 5-fold left+right artery ensemble on one or more datasets.

    For each (dataset_path, dataset_name, postfix) triple: load cached
    per-fold predictions if present (otherwise run inference for both
    arteries and cache them), optionally apply per-fold a*x+b calibration,
    combine left+right predictions, and report/persist R2, MAE, Pearson,
    MAPE and mean recall, both for the fold-averaged ensemble and per fold.

    Note: --frames-per-clip, --num-workers and --seed now declare type=int;
    without it click passed raw strings, which crashed DataLoader
    (num_workers must be int) and seed_everything.
    """
    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"

    # Per-artery checkpoint paths, one per cross-validation fold.
    # List position defines the fold index used for the scaling lookup below.
    model_paths = {
        "left": [
            "full_model/checkpoints/leftBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
        ],
        "right": [
            "full_model/checkpoints/rightBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
        ]
    }

    # Optional per-fold linear calibration, presumably stored as
    # {"foldK": [a, b]} and applied as a*x+b — TODO confirm file schema.
    scaling_params_dict = {}
    if use_scaling:
        postfix_plotly += "_scaled"
        ensemble_name += "_scaled"
        scaling_path = os.path.join(dataset_root, scaling_file)
        if os.path.exists(scaling_path):
            with open(scaling_path, "r") as f:
                scaling_params_dict = json.load(f)
            print(f"Loaded scaling from {scaling_path}")
        else:
            print(f"⚠️ Scaling file not found: {scaling_path}")

    def _scale(pred_list):
        # Apply the fold-specific a*x+b calibration; identity when the
        # fold has no coefficients in the scaling file.
        scaled = []
        for i, val in enumerate(pred_list):
            a, b = scaling_params_dict.get(f"fold{i}", (1.0, 0.0))
            scaled.append(a * val + b)
        return scaled

    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "datasets": {}
    }

    all_datasets, all_r2, all_recalls = {}, {}, {}

    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        results_file = os.path.join(dataset_root, "coeffs", f"{postfix}.json")

        # Reuse cached per-fold predictions when available; otherwise run
        # inference for both arteries and cache the raw (unscaled) results.
        if os.path.exists(results_file):
            print(f"[{postfix}] Loading from {results_file}")
            with open(results_file, "r") as f:
                data = json.load(f)
            syntax_true = data["syntax_true"]
            left_preds_all = data["left_preds"]
            right_preds_all = data["right_preds"]
        else:
            print(f"[{postfix}] Computing predictions...")
            left_preds_all, left_sids = run_artery(
                abs_dataset_path, "left", model_paths["left"],
                video_size, frames_per_clip, num_workers, pt_weights_format
            )
            right_preds_all, right_sids = run_artery(
                abs_dataset_path, "right", model_paths["right"],
                video_size, frames_per_clip, num_workers, pt_weights_format
            )
            # Both arteries must have been evaluated on the same studies,
            # in the same order, for the element-wise sums below to hold.
            assert left_sids == right_sids, "left/right study IDs mismatch"

            with open(abs_dataset_path, "r") as f:
                dataset = json.load(f)
            # Prefer the averaged score; fall back to the single "syntax" key.
            syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]

            os.makedirs(os.path.dirname(results_file), exist_ok=True)
            save_data = {
                "syntax_true": syntax_true,
                "left_preds": left_preds_all,
                "right_preds": right_preds_all
            }
            with open(results_file, "w") as f:
                json.dump(save_data, f)
            print(f"[{postfix}] Saved to {results_file}")

        if use_scaling:
            left_scaled_all = [_scale(p) for p in left_preds_all]
            right_scaled_all = [_scale(p) for p in right_preds_all]
        else:
            left_scaled_all, right_scaled_all = left_preds_all, right_preds_all

        # Ensemble prediction per sample: sum left+right per fold, average
        # across folds, clamp at zero (the score cannot be negative).
        syntax_pred = [max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
                       for l_list, r_list in zip(left_scaled_all, right_scaled_all)]

        r2, mae, pearson, mape, mean_recall = compute_metrics(syntax_true, syntax_pred)
        print(f"[{postfix}] ENSEMBLE: R2={r2:.4f}, Pearson={pearson:.4f}, "
              f"MAE={mae:.4f}, MAPE={mape:.4f}, Recall={mean_recall:.4f}")

        # Per-fold metrics: each fold's own left+right sum scored separately.
        n_folds = len(left_scaled_all[0]) if left_scaled_all else 0
        fold_metrics = {metric: [] for metric in ["R2", "MAE", "Pearson", "MAPE", "Mean_Recall"]}
        for k in range(n_folds):
            pred_k = [max(0.0, l_list[k] + r_list[k])
                      for l_list, r_list in zip(left_scaled_all, right_scaled_all)]
            fold_values = compute_metrics(syntax_true, pred_k)
            for metric, value in zip(fold_metrics.keys(), fold_values):
                fold_metrics[metric].append(value)

        fold_summary = {k: {"mean": float(np.mean(v)), "std": safe_sample_std(v), "values": v}
                        for k, v in fold_metrics.items()}

        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_r2[dataset_name] = r2
        all_recalls[dataset_name] = mean_recall

        ensemble_results["datasets"][dataset_name] = {
            # Ensemble-level metrics.
            "R2": round(r2, 4), "MAE": round(mae, 4),
            "Pearson": round(pearson, 4), "MAPE": round(mape, 4),
            "Mean_Recall": round(mean_recall, 4), "N_samples": len(syntax_true),
            # Per-fold summary: mean/std across folds plus raw values.
            **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
            **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
            **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()}
        }

    # Merge this run into the (possibly pre-existing) metrics history file.
    metrics_path = os.path.join(dataset_root, metrics_file)
    full_history = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r") as f:
                full_history = json.load(f)
        except json.JSONDecodeError:
            print("⚠️ Metrics file corrupted. Creating new.")

    full_history[ensemble_name] = ensemble_results
    with open(metrics_path, "w") as f:
        json.dump(full_history, f, indent=4)
    print(f"✅ Metrics saved: {metrics_path}")

    visualize_final_syntax_plotly_multi(
        datasets=all_datasets, r2_values=all_r2, recall_values=all_recalls,
        gt_row="ENSEMBLE", postfix=postfix_plotly
    )
| |
|
| |
|
def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, pt_weights_format=False):
    """Run inference for one artery across all fold checkpoints.

    Returns ``(preds_all, sids)``: for each sample, a list with one
    prediction per loaded fold model, plus the matching study IDs in
    DataLoader (dataset) order.
    """
    # ImageNet normalization statistics — presumably the backbone was
    # pretrained on ImageNet; verify against the training pipeline.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",
        artery=artery,
        inference=True,
        transform=test_transform
    )
    # batch_size=1 and shuffle=False keep study IDs aligned between the
    # separate left/right runs, which the caller asserts on.
    val_loader = DataLoader(val_set, batch_size=1, num_workers=num_workers,
                            shuffle=False, pin_memory=True)
    print(f"{artery} artery: {len(val_loader)} samples")

    # Load one model per checkpoint; missing files are skipped with a warning
    # (best-effort), but zero loaded models is a hard error.
    models = []
    for path in model_paths:
        if not os.path.exists(path):
            print(f"⚠️ Model not found: {path}")
            continue
        model = SyntaxLightningModule(
            num_classes=2, lr=1e-5, variant="lstm_mean",
            weight_decay=0.001, max_epochs=1,
            pl_weight_path=path, pt_weights_format=pt_weights_format
        )
        model.to(DEVICE)
        model.eval()
        models.append(model)
    if not models:
        raise RuntimeError(f"No models loaded for {artery}")

    preds_all, sids = [], []
    with torch.no_grad():
        for x, [y], [t], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            # A 1-D batch tensor appears to be the dataset's sentinel for a
            # missing/unreadable study — TODO confirm; fall back to 0.0 per fold.
            if len(x.shape) == 1:
                val_syntax_list = [0.0] * len(models)
            else:
                x = x.to(DEVICE)
                val_syntax_list = []
                for model in models:
                    pred = model(x)
                    _, val_log = pred
                    # exp(x) - 1 suggests the model regresses log(1 + score);
                    # this inverts that transform back to score space.
                    val = float(torch.exp(val_log).cpu()) - 1
                    val_syntax_list.append(val)
            preds_all.append(val_syntax_list)
            sids.append(sid[0])

    return preds_all, sids
| |
|
| |
|
# Script entry point: click parses CLI arguments from sys.argv.
if __name__ == "__main__":
    main()
| |
|