# syntax-model / inference / rnn_apply.py
# (Hugging Face upload header, kept as a comment: MesserMMP — "add files for inference", commit e4194f4)
import os
import json
import tqdm
import torch
import numpy as np
import click
from datetime import datetime
import lightning.pytorch as pl
import sklearn.metrics as skm
from torch.utils.data import DataLoader
from torchvision.transforms import transforms as T
from torchvision.transforms._transforms_video import ToTensorVideo
from pytorchvideo.transforms import Normalize
# Imports from sibling packages (relative paths)
from full_model.rnn_dataset import SyntaxDataset
from full_model.rnn_model import SyntaxLightningModule
from metrics_visualization import visualize_final_syntax_plotly_multi
# Pick GPU 0 when CUDA is available, otherwise fall back to CPU.
# All models and input tensors below are moved to this device.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")
def safe_sample_std(values):
    """Return the sample standard deviation (ddof=1) of *values*.

    The sample std is undefined for fewer than two observations, so an
    empty or single-element input yields 0.0 instead of NaN/raising.
    """
    data = np.asarray(values, dtype=float)
    return float(data.std(ddof=1)) if data.size > 1 else 0.0
def compute_metrics(y_true, y_pred, thr=22.0):
    """Compute R2, MAE, Pearson r, MAPE and mean per-class recall.

    Recall is computed over a binary split at *thr* (score >= thr -> class 1).
    Degenerate cases fall back to 0.0: Pearson when there are fewer than two
    samples, mean recall when true and predicted labels collapse to one class.
    """
    t = np.asarray(y_true, dtype=float)
    p = np.asarray(y_pred, dtype=float)

    r2 = float(skm.r2_score(t, p))
    mae = float(skm.mean_absolute_error(t, p))
    mape = float(skm.mean_absolute_percentage_error(t, p))

    if len(t) > 1:
        pearson = float(np.corrcoef(t, p)[0, 1])
    else:
        pearson = 0.0

    # Binarize at the clinical threshold and average recall over both classes.
    t_bin = (t >= thr).astype(int)
    p_bin = (p >= thr).astype(int)
    observed = np.unique(np.concatenate([t_bin, p_bin]))
    if len(observed) > 1:
        per_class = skm.recall_score(t_bin, p_bin, average=None, labels=[0, 1])
        mean_recall = float(np.mean(per_class))
    else:
        mean_recall = 0.0

    return r2, mae, pearson, mape, mean_recall
@click.command()
@click.option("-d", "--dataset-paths", multiple=True,
              help="JSON с метаданными датасетов (относительно dataset_root).")
@click.option("-n", "--dataset-names", multiple=True,
              help="Имена датасетов для метрик/графиков.")
@click.option("-p", "--postfixes", multiple=True,
              help="Суффиксы для файлов предсказаний.")
@click.option("-r", "--dataset-root", type=click.Path(exists=True),
              help="Корень датасета (где лежат JSON и DICOM).")
@click.option("-v", "--video-size", type=click.Tuple([int, int]),
              help="Размер видео (H, W).")
@click.option("--frames-per-clip", type=int,
              help="Количество кадров в клипе.")
@click.option("--num-workers", type=int,
              help="Число DataLoader workers.")
@click.option("--seed", type=int,
              help="Random seed.")
@click.option("--pt-weights-format", is_flag=True,
              help="True → модели в .pt (torch.save), False → .ckpt (Lightning).")
@click.option("--use-scaling", is_flag=True,
              help="Применить a*x+b scaling из JSON.")
@click.option("--scaling-file",
              help="JSON с коэффициентами scaling (относительно dataset_root).")
@click.option("-e", "--ensemble-name",
              help="Имя ансамбля в metrics.json.")
@click.option("-m", "--metrics-file",
              help="JSON с метриками экспериментов.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, ensemble_name, metrics_file):
    """Run the left+right 5-fold SYNTAX-score ensemble over the given datasets.

    For each dataset: load cached per-fold predictions (or run inference and
    cache them), optionally apply a per-fold affine scaling, form the ensemble
    prediction (mean over folds of left+right, clipped at 0), compute metrics
    with per-fold mean/std, append everything to the metrics JSON and render
    a Plotly comparison figure.

    Fixes vs. previous revision: ``--frames-per-clip``, ``--num-workers`` and
    ``--seed`` are now declared ``type=int`` — click otherwise delivers them
    as strings, which breaks ``DataLoader(num_workers=...)`` and
    ``pl.seed_everything(...)``; the left/right study-id check is an explicit
    raise instead of an ``assert`` (asserts vanish under ``python -O``).
    """
    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"

    # Model checkpoint paths (backbone + RNN head), 5 folds per artery.
    model_paths = {
        "left": [
            "full_model/checkpoints/leftBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
        ],
        "right": [
            "full_model/checkpoints/rightBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
        ]
    }

    # Optional per-fold affine scaling parameters (a, b) loaded from JSON.
    scaling_params_dict = {}
    if use_scaling:
        postfix_plotly += "_scaled"
        ensemble_name += "_scaled"
        scaling_path = os.path.join(dataset_root, scaling_file)
        if os.path.exists(scaling_path):
            with open(scaling_path, "r") as f:
                scaling_params_dict = json.load(f)
            print(f"Loaded scaling from {scaling_path}")
        else:
            # Best-effort: missing file means identity scaling below.
            print(f"⚠️ Scaling file not found: {scaling_path}")

    # Accumulated ensemble results written to the metrics JSON.
    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "datasets": {}
    }

    all_datasets, all_r2, all_recalls = {}, {}, {}
    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        # Resolve paths relative to the dataset root.
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        results_file = os.path.join(dataset_root, "coeffs", f"{postfix}.json")

        # Load cached predictions if present, otherwise run inference and cache.
        if os.path.exists(results_file):
            print(f"[{postfix}] Loading from {results_file}")
            with open(results_file, "r") as f:
                data = json.load(f)
            syntax_true = data["syntax_true"]
            left_preds_all = data["left_preds"]
            right_preds_all = data["right_preds"]
        else:
            print(f"[{postfix}] Computing predictions...")
            left_preds_all, left_sids = run_artery(
                abs_dataset_path, "left", model_paths["left"],
                video_size, frames_per_clip, num_workers, pt_weights_format
            )
            right_preds_all, right_sids = run_artery(
                abs_dataset_path, "right", model_paths["right"],
                video_size, frames_per_clip, num_workers, pt_weights_format
            )
            # Explicit check (not assert): left/right predictions must line up
            # sample-for-sample before they can be summed below.
            if left_sids != right_sids:
                raise RuntimeError(
                    f"[{postfix}] left/right study id order mismatch"
                )
            with open(abs_dataset_path, "r") as f:
                dataset = json.load(f)
            syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]
            os.makedirs(os.path.dirname(results_file), exist_ok=True)
            save_data = {
                "syntax_true": syntax_true,
                "left_preds": left_preds_all,
                "right_preds": right_preds_all
            }
            with open(results_file, "w") as f:
                json.dump(save_data, f)
            print(f"[{postfix}] Saved to {results_file}")

        # Fold-wise a*x + b scaling for left/right (identity for missing folds).
        if use_scaling:
            def _scale_folds(pred_list):
                scaled = []
                for i, val in enumerate(pred_list):
                    a, b = scaling_params_dict.get(f"fold{i}", (1.0, 0.0))
                    scaled.append(a * val + b)
                return scaled
            left_scaled_all = [_scale_folds(p) for p in left_preds_all]
            right_scaled_all = [_scale_folds(p) for p in right_preds_all]
        else:
            left_scaled_all, right_scaled_all = left_preds_all, right_preds_all

        # Ensemble: mean over folds of (left + right), clipped at zero.
        syntax_pred = [max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
                       for l_list, r_list in zip(left_scaled_all, right_scaled_all)]

        # Ensemble-level metrics.
        r2, mae, pearson, mape, mean_recall = compute_metrics(syntax_true, syntax_pred)
        print(f"[{postfix}] ENSEMBLE: R2={r2:.4f}, Pearson={pearson:.4f}, "
              f"MAE={mae:.4f}, MAPE={mape:.4f}, Recall={mean_recall:.4f}")

        # Per-fold metrics → mean/std across folds.
        n_folds = len(left_scaled_all[0]) if left_scaled_all else 0
        fold_metrics = {metric: [] for metric in ["R2", "MAE", "Pearson", "MAPE", "Mean_Recall"]}
        for k in range(n_folds):
            pred_k = [max(0.0, l_list[k] + r_list[k])
                      for l_list, r_list in zip(left_scaled_all, right_scaled_all)]
            fold_r2, fold_mae, fold_pearson, fold_mape, fold_recall = compute_metrics(syntax_true, pred_k)
            for metric, value in zip(fold_metrics.keys(),
                                     [fold_r2, fold_mae, fold_pearson, fold_mape, fold_recall]):
                fold_metrics[metric].append(value)
        fold_summary = {name: {"mean": float(np.mean(vals)),
                               "std": safe_sample_std(vals),
                               "values": vals}
                        for name, vals in fold_metrics.items()}

        # Collect for visualization and the metrics JSON.
        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_r2[dataset_name] = r2
        all_recalls[dataset_name] = mean_recall
        ensemble_results["datasets"][dataset_name] = {
            # Ensemble-level
            "R2": round(r2, 4), "MAE": round(mae, 4),
            "Pearson": round(pearson, 4), "MAPE": round(mape, 4),
            "Mean_Recall": round(mean_recall, 4), "N_samples": len(syntax_true),
            # Per-fold (mean ± std)
            **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
            **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
            **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()}
        }

    # Append this run's results to the (possibly existing) metrics history.
    metrics_path = os.path.join(dataset_root, metrics_file)
    full_history = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r") as f:
                full_history = json.load(f)
        except json.JSONDecodeError:
            # Deliberate best-effort: a corrupted history is replaced, not fatal.
            print("⚠️ Metrics file corrupted. Creating new.")
    full_history[ensemble_name] = ensemble_results
    with open(metrics_path, "w") as f:
        json.dump(full_history, f, indent=4)
    print(f"✅ Metrics saved: {metrics_path}")

    # Final comparison plot across all datasets.
    visualize_final_syntax_plotly_multi(
        datasets=all_datasets, r2_values=all_r2, recall_values=all_recalls,
        gt_row="ENSEMBLE", postfix=postfix_plotly
    )
def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, pt_weights_format=False):
    """Run inference for one artery across all fold checkpoints.

    Args:
        dataset_path: absolute path to the dataset metadata JSON.
        artery: artery selector passed to SyntaxDataset ("left"/"right" here).
        model_paths: checkpoint paths, one per fold; missing files are skipped.
        video_size: target (H, W) frame size for resizing.
        frames_per_clip: number of frames per clip (SyntaxDataset ``length``).
        num_workers: DataLoader worker count.
        pt_weights_format: True -> weights saved via torch.save (.pt),
            False -> Lightning .ckpt.

    Returns:
        (preds_all, sids): ``preds_all`` is a list with one entry per sample,
        each a list of per-fold predicted scores; ``sids`` is the matching
        list of study identifiers.

    Raises:
        RuntimeError: if none of ``model_paths`` exists on disk.
    """
    # ImageNet channel statistics — matches the pretrained backbone's input norm.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",  # inference mode: no target label is read
        artery=artery,
        inference=True,
        transform=test_transform
    )
    # batch_size=1: one study per step, so sid/labels unpack as 1-element lists.
    val_loader = DataLoader(val_set, batch_size=1, num_workers=num_workers,
                            shuffle=False, pin_memory=True)
    print(f"{artery} artery: {len(val_loader)} samples")
    # Load every available fold checkpoint; missing folds shrink the ensemble.
    models = []
    for path in model_paths:
        if not os.path.exists(path):
            print(f"⚠️ Model not found: {path}")
            continue
        model = SyntaxLightningModule(
            num_classes=2, lr=1e-5, variant="lstm_mean",
            weight_decay=0.001, max_epochs=1,
            pl_weight_path=path, pt_weights_format=pt_weights_format
        )
        model.to(DEVICE)
        model.eval()
        models.append(model)
    if not models:
        raise RuntimeError(f"No models loaded for {artery}")
    preds_all, sids = [], []
    with torch.no_grad():
        for x, [y], [t], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            if len(x.shape) == 1:  # 1-D tensor marks an empty/absent video for this artery
                val_syntax_list = [0.0] * len(models)
            else:
                x = x.to(DEVICE)
                val_syntax_list = []
                for model in models:
                    pred = model(x)
                    # Second model output is the regression value; exp(.) - 1
                    # maps it back to score scale (presumably trained on
                    # log(1 + score) targets — TODO confirm against training code).
                    _, val_log = pred
                    val = float(torch.exp(val_log).cpu()) - 1
                    val_syntax_list.append(val)
            preds_all.append(val_syntax_list)
            sids.append(sid[0])  # study_uid
    return preds_all, sids
# Script entry point (click parses CLI args inside main()).
if __name__ == "__main__":
    main()