add files for inference
Browse files- inference/metrics_visualization.py +393 -0
- inference/rnn_apply.py +305 -0
inference/metrics_visualization.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Визуализация предсказаний SYNTAX:
|
| 3 |
+
- точки (SYNTAX GT vs предсказания модели) для нескольких датасетов;
|
| 4 |
+
- зоны риска (низкий / высокий риск);
|
| 5 |
+
- области ±σ и ±2σ вокруг диагонали;
|
| 6 |
+
- логистические тренды для каждого датасета.
|
| 7 |
+
|
| 8 |
+
Скрипт не зависит от PyTorch/Lightning и используется на этапе инференса.
|
| 9 |
+
Сохранение осуществляется в папку `visualizations/` внутри проекта.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import numpy as np
|
| 14 |
+
import plotly.graph_objects as go
|
| 15 |
+
from scipy.optimize import curve_fit # type: ignore
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def visualize_final_syntax_plotly_multi(
|
| 19 |
+
datasets,
|
| 20 |
+
r2_values,
|
| 21 |
+
gt_row,
|
| 22 |
+
postfix=None,
|
| 23 |
+
threshold=22.0,
|
| 24 |
+
recall_values=None,
|
| 25 |
+
backbone=False,
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
Единая визуализация SYNTAX: точки, зоны риска и логистические тренды.
|
| 29 |
+
|
| 30 |
+
Параметры
|
| 31 |
+
---------
|
| 32 |
+
datasets : dict[str, tuple[list[float], list[float]]]
|
| 33 |
+
Словарь {имя_датасета: (syntax_true_list, syntax_pred_list)}.
|
| 34 |
+
r2_values : dict[str, float]
|
| 35 |
+
Словарь R^2 по датасетам.
|
| 36 |
+
gt_row : str
|
| 37 |
+
Строка, попадающая в заголовок (например, "ENSEMBLE" или "BOTH").
|
| 38 |
+
postfix : str | None
|
| 39 |
+
Суффикс для имени сохраняемого файла.
|
| 40 |
+
threshold : float
|
| 41 |
+
Порог SYNTAX (обычно 22.0) для разделения зон риска.
|
| 42 |
+
recall_values : dict[str, float] | None
|
| 43 |
+
Словарь Recall по датасетам (может быть None).
|
| 44 |
+
backbone : bool
|
| 45 |
+
Если True, сохраняет в `visualizations/backbone`, иначе в `visualizations/`.
|
| 46 |
+
"""
|
| 47 |
+
# ========== КОНСТАНТЫ ДЛЯ НАСТРОЙКИ ==========
|
| 48 |
+
DATA_MIN = 0.0
|
| 49 |
+
DATA_MAX = 60.0
|
| 50 |
+
|
| 51 |
+
PADDING = 0.5
|
| 52 |
+
|
| 53 |
+
SIGMA_SLOPE = 0.15
|
| 54 |
+
SIGMA_BASE = 1.4
|
| 55 |
+
|
| 56 |
+
PLOT_WIDTH = 980
|
| 57 |
+
PLOT_HEIGHT = 980
|
| 58 |
+
|
| 59 |
+
BASE_FONT_SIZE = 16
|
| 60 |
+
TITLE_FONT_SIZE = 22
|
| 61 |
+
AXIS_LABEL_FONT_SIZE = BASE_FONT_SIZE
|
| 62 |
+
AXIS_TICK_FONT_SIZE = 15
|
| 63 |
+
LEGEND_FONT_SIZE = 14
|
| 64 |
+
|
| 65 |
+
MARKER_SIZE = 11
|
| 66 |
+
MARKER_LINE_WIDTH = 1.1
|
| 67 |
+
LINE_WIDTH = 2
|
| 68 |
+
TREND_LINE_WIDTH = 3
|
| 69 |
+
|
| 70 |
+
PLOT_BG_COLOR = "rgba(235,238,245,1)"
|
| 71 |
+
PAPER_BG_COLOR = "white"
|
| 72 |
+
LEGEND_BG_COLOR = "rgba(255,255,255,0.94)"
|
| 73 |
+
GRID_COLOR = "rgba(100,116,139,0.18)"
|
| 74 |
+
|
| 75 |
+
MARGIN_LEFT = 70
|
| 76 |
+
MARGIN_RIGHT = 24
|
| 77 |
+
MARGIN_TOP = 78
|
| 78 |
+
MARGIN_BOTTOM = 70
|
| 79 |
+
|
| 80 |
+
LEGEND_X = 0.04
|
| 81 |
+
LEGEND_Y = 0.99
|
| 82 |
+
|
| 83 |
+
COLORS = ["#1E88E5", "#8E24AA", "#A0D137", "#EA1D1D", "#06EE0D", "#FB8C00"]
|
| 84 |
+
SYMBOLS = ["circle", "x", "square", "diamond", "triangle-up", "star"]
|
| 85 |
+
|
| 86 |
+
SIGMA_POINTS = 400
|
| 87 |
+
TREND_POINTS = 500
|
| 88 |
+
|
| 89 |
+
# ========== ВСПОМОГАТЕЛЬНЫЕ ФУНКЦИИ ==========
|
| 90 |
+
|
| 91 |
+
def _logistic_time(t, R0, Rmax, t50, k):
|
| 92 |
+
"""Логистическая функция по времени/оценке SYNTAX."""
|
| 93 |
+
t = np.asarray(t, dtype=float)
|
| 94 |
+
t_safe = np.where(t <= 0, 1e-3, t)
|
| 95 |
+
return R0 + (Rmax - R0) / (1.0 + (t50 / t_safe) ** k)
|
| 96 |
+
|
| 97 |
+
def _fit_logistic(x, y, domain=(DATA_MIN, DATA_MAX), n=TREND_POINTS):
|
| 98 |
+
"""
|
| 99 |
+
Аппроксимация логистической кривой.
|
| 100 |
+
Возвращает X, Y или (None, None), если фит не удался.
|
| 101 |
+
"""
|
| 102 |
+
x = np.asarray(x, dtype=float)
|
| 103 |
+
y = np.asarray(y, dtype=float)
|
| 104 |
+
m = np.isfinite(x) & np.isfinite(y)
|
| 105 |
+
if m.sum() < 4:
|
| 106 |
+
return None, None
|
| 107 |
+
|
| 108 |
+
x_m, y_m = x[m], y[m]
|
| 109 |
+
x_min = max(float(np.min(x_m)), float(domain[0]))
|
| 110 |
+
x_max = min(float(np.max(x_m)), float(domain[1]))
|
| 111 |
+
if not np.isfinite(x_min) or not np.isfinite(x_max) or x_max <= x_min:
|
| 112 |
+
return None, None
|
| 113 |
+
|
| 114 |
+
x_pos = x_m[x_m > 0]
|
| 115 |
+
if x_pos.size == 0:
|
| 116 |
+
return None, None
|
| 117 |
+
|
| 118 |
+
R0_init = float(np.percentile(y_m, 10))
|
| 119 |
+
Rmax_init = float(np.percentile(y_m, 90))
|
| 120 |
+
t50_init = float(np.median(x_pos))
|
| 121 |
+
k_init = 1.0
|
| 122 |
+
|
| 123 |
+
lower = [-10.0, 0.0, 1e-3, 0.01]
|
| 124 |
+
upper = [60.0, 80.0, 60.0, 10.0]
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
popt, _ = curve_fit(
|
| 128 |
+
_logistic_time,
|
| 129 |
+
x_m,
|
| 130 |
+
y_m,
|
| 131 |
+
p0=[R0_init, Rmax_init, t50_init, k_init],
|
| 132 |
+
bounds=(lower, upper),
|
| 133 |
+
maxfev=20000,
|
| 134 |
+
)
|
| 135 |
+
except Exception:
|
| 136 |
+
return None, None
|
| 137 |
+
|
| 138 |
+
X = np.linspace(x_min, x_max, n)
|
| 139 |
+
Y = _logistic_time(X, *popt)
|
| 140 |
+
return X, Y
|
| 141 |
+
|
| 142 |
+
# ========== ОСНОВНОЙ КОД ==========
|
| 143 |
+
fig = go.Figure()
|
| 144 |
+
|
| 145 |
+
line_min = DATA_MIN - PADDING
|
| 146 |
+
line_max = DATA_MAX + PADDING
|
| 147 |
+
domain = (line_min, line_max)
|
| 148 |
+
|
| 149 |
+
base_font = dict(
|
| 150 |
+
family="Inter, Roboto, Helvetica Neue, Arial, sans-serif",
|
| 151 |
+
size=BASE_FONT_SIZE,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# ---------- Пороги и линии (legendrank=0) ----------
|
| 155 |
+
fig.add_trace(
|
| 156 |
+
go.Scatter(
|
| 157 |
+
x=[line_min, threshold, threshold, line_min],
|
| 158 |
+
y=[line_min, line_min, threshold, threshold],
|
| 159 |
+
fill="toself",
|
| 160 |
+
fillcolor="rgba(255, 82, 82, 0.12)",
|
| 161 |
+
line=dict(color="rgba(0,0,0,0)"),
|
| 162 |
+
name="Low-risk zone",
|
| 163 |
+
legendgroup="zones",
|
| 164 |
+
legendgrouptitle_text="Пороги и линии",
|
| 165 |
+
showlegend=True,
|
| 166 |
+
hoverinfo="skip",
|
| 167 |
+
legendrank=0,
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
fig.add_trace(
|
| 171 |
+
go.Scatter(
|
| 172 |
+
x=[threshold, line_max, line_max, threshold],
|
| 173 |
+
y=[threshold, threshold, line_max, line_max],
|
| 174 |
+
fill="toself",
|
| 175 |
+
fillcolor="rgba(76, 175, 80, 0.14)",
|
| 176 |
+
line=dict(color="rgba(0,0,0,0)"),
|
| 177 |
+
name="High-risk zone",
|
| 178 |
+
legendgroup="zones",
|
| 179 |
+
showlegend=True,
|
| 180 |
+
hoverinfo="skip",
|
| 181 |
+
legendrank=0,
|
| 182 |
+
)
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
fig.add_trace(
|
| 186 |
+
go.Scatter(
|
| 187 |
+
x=[threshold, threshold, None, line_min, line_max],
|
| 188 |
+
y=[line_min, line_max, None, threshold, threshold],
|
| 189 |
+
mode="lines",
|
| 190 |
+
name=rf"$\mathrm{{SYNTAX}}={threshold}$",
|
| 191 |
+
legendgroup="zones",
|
| 192 |
+
showlegend=True,
|
| 193 |
+
line=dict(color="rgba(46,125,50,0.85)", width=LINE_WIDTH, dash="dash"),
|
| 194 |
+
legendrank=0,
|
| 195 |
+
hoverinfo="skip",
|
| 196 |
+
)
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
x_vals = np.linspace(line_min, line_max, SIGMA_POINTS)
|
| 200 |
+
sigma_upper = x_vals + SIGMA_BASE + SIGMA_SLOPE * x_vals
|
| 201 |
+
sigma_lower = x_vals - SIGMA_BASE - SIGMA_SLOPE * x_vals
|
| 202 |
+
two_sigma_upper = x_vals + 2 * SIGMA_BASE + 2 * SIGMA_SLOPE * x_vals
|
| 203 |
+
two_sigma_lower = x_vals - 2 * SIGMA_BASE - 2 * SIGMA_SLOPE * x_vals
|
| 204 |
+
|
| 205 |
+
fig.add_trace(
|
| 206 |
+
go.Scatter(
|
| 207 |
+
x=np.concatenate([x_vals, x_vals[::-1]]),
|
| 208 |
+
y=np.concatenate([two_sigma_lower, two_sigma_upper[::-1]]),
|
| 209 |
+
fill="toself",
|
| 210 |
+
fillcolor="rgba(255,193,7,0.18)",
|
| 211 |
+
line=dict(color="rgba(0,0,0,0)"),
|
| 212 |
+
name=r"$\pm 2\sigma$",
|
| 213 |
+
legendgroup="zones",
|
| 214 |
+
showlegend=True,
|
| 215 |
+
hoverinfo="skip",
|
| 216 |
+
legendrank=0,
|
| 217 |
+
)
|
| 218 |
+
)
|
| 219 |
+
fig.add_trace(
|
| 220 |
+
go.Scatter(
|
| 221 |
+
x=np.concatenate([x_vals, x_vals[::-1]]),
|
| 222 |
+
y=np.concatenate([sigma_lower, sigma_upper[::-1]]),
|
| 223 |
+
fill="toself",
|
| 224 |
+
fillcolor="rgba(255,152,0,0.30)",
|
| 225 |
+
line=dict(color="rgba(0,0,0,0)"),
|
| 226 |
+
name=r"$\pm \sigma$",
|
| 227 |
+
legendgroup="zones",
|
| 228 |
+
showlegend=True,
|
| 229 |
+
hoverinfo="skip",
|
| 230 |
+
legendrank=0,
|
| 231 |
+
)
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
fig.add_trace(
|
| 235 |
+
go.Scatter(
|
| 236 |
+
x=[line_min, line_max],
|
| 237 |
+
y=[line_min, line_max],
|
| 238 |
+
mode="lines",
|
| 239 |
+
name=r"$y=x$",
|
| 240 |
+
legendgroup="zones",
|
| 241 |
+
showlegend=True,
|
| 242 |
+
line=dict(color="rgba(30,30,30,0.85)", width=LINE_WIDTH),
|
| 243 |
+
legendrank=0,
|
| 244 |
+
)
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
# ---------- Датасеты (legendrank=20) ----------
|
| 248 |
+
first_dataset = True
|
| 249 |
+
for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
|
| 250 |
+
x = np.array(syntax_true, dtype=float)
|
| 251 |
+
y = np.array(syntax_pred, dtype=float)
|
| 252 |
+
if x.size == 0 or y.size == 0:
|
| 253 |
+
continue
|
| 254 |
+
|
| 255 |
+
r2 = r2_values.get(label, None)
|
| 256 |
+
recall = recall_values.get(label, None) if recall_values else None
|
| 257 |
+
hover_lines = [f"<b>{label}</b>"]
|
| 258 |
+
if r2 is not None:
|
| 259 |
+
hover_lines.append(f"R² = {r2:.3f}")
|
| 260 |
+
if recall is not None:
|
| 261 |
+
hover_lines.append(f"Recall = {recall:.3f}")
|
| 262 |
+
hovertemplate = (
|
| 263 |
+
"<br>".join(hover_lines)
|
| 264 |
+
+ "<br>GT: %{x:.3f}<br>Pred: %{y:.3f}<extra></extra>"
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
fig.add_trace(
|
| 268 |
+
go.Scatter(
|
| 269 |
+
x=x,
|
| 270 |
+
y=y,
|
| 271 |
+
mode="markers",
|
| 272 |
+
name=label,
|
| 273 |
+
legendgroup="datasets",
|
| 274 |
+
legendgrouptitle_text=("Датасеты" if first_dataset else None),
|
| 275 |
+
showlegend=True,
|
| 276 |
+
marker=dict(
|
| 277 |
+
color=COLORS[i % len(COLORS)],
|
| 278 |
+
size=MARKER_SIZE,
|
| 279 |
+
opacity=0.96,
|
| 280 |
+
symbol=SYMBOLS[i % len(SYMBOLS)],
|
| 281 |
+
line=dict(
|
| 282 |
+
width=MARKER_LINE_WIDTH, color="rgba(255,255,255,0.95)"
|
| 283 |
+
),
|
| 284 |
+
),
|
| 285 |
+
hovertemplate=hovertemplate,
|
| 286 |
+
legendrank=20,
|
| 287 |
+
)
|
| 288 |
+
)
|
| 289 |
+
first_dataset = False
|
| 290 |
+
|
| 291 |
+
# ---------- Тренды: логистические (legendrank=30) ----------
|
| 292 |
+
first_trend = True
|
| 293 |
+
for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
|
| 294 |
+
x = np.array(syntax_true, dtype=float)
|
| 295 |
+
y = np.array(syntax_pred, dtype=float)
|
| 296 |
+
if x.size == 0 or y.size == 0:
|
| 297 |
+
continue
|
| 298 |
+
|
| 299 |
+
Xc, Yc = _fit_logistic(x, y, domain=domain)
|
| 300 |
+
if Xc is not None:
|
| 301 |
+
fig.add_trace(
|
| 302 |
+
go.Scatter(
|
| 303 |
+
x=Xc,
|
| 304 |
+
y=Yc,
|
| 305 |
+
mode="lines",
|
| 306 |
+
name=label, # без коротких alias, полное имя датасета
|
| 307 |
+
legendgroup="trends",
|
| 308 |
+
legendgrouptitle_text=(
|
| 309 |
+
"Тренды (логистические)" if first_trend else None
|
| 310 |
+
),
|
| 311 |
+
showlegend=True,
|
| 312 |
+
line=dict(
|
| 313 |
+
color=COLORS[i % len(COLORS)], width=TREND_LINE_WIDTH
|
| 314 |
+
),
|
| 315 |
+
hoverinfo="skip",
|
| 316 |
+
legendrank=30,
|
| 317 |
+
)
|
| 318 |
+
)
|
| 319 |
+
first_trend = False
|
| 320 |
+
|
| 321 |
+
# ---------- оформление ----------
|
| 322 |
+
title_text = f"SYNTAX predictions ({gt_row})"
|
| 323 |
+
if postfix:
|
| 324 |
+
title_text += f" {postfix}"
|
| 325 |
+
|
| 326 |
+
fig.update_layout(
|
| 327 |
+
title=dict(
|
| 328 |
+
text=title_text,
|
| 329 |
+
x=0.5,
|
| 330 |
+
xanchor="center",
|
| 331 |
+
font=dict(
|
| 332 |
+
size=TITLE_FONT_SIZE,
|
| 333 |
+
family=base_font["family"],
|
| 334 |
+
color="rgba(15,23,42,1)",
|
| 335 |
+
),
|
| 336 |
+
),
|
| 337 |
+
font=base_font,
|
| 338 |
+
xaxis_title=r"$\mathrm{SYNTAX\ GT}$",
|
| 339 |
+
yaxis_title=r"$\mathrm{SYNTAX\ predictions}$",
|
| 340 |
+
width=PLOT_WIDTH,
|
| 341 |
+
height=PLOT_HEIGHT,
|
| 342 |
+
plot_bgcolor=PLOT_BG_COLOR,
|
| 343 |
+
paper_bgcolor=PAPER_BG_COLOR,
|
| 344 |
+
legend=dict(
|
| 345 |
+
x=LEGEND_X,
|
| 346 |
+
y=LEGEND_Y,
|
| 347 |
+
bgcolor=LEGEND_BG_COLOR,
|
| 348 |
+
bordercolor="#CBD5E1",
|
| 349 |
+
borderwidth=1,
|
| 350 |
+
font=dict(size=LEGEND_FONT_SIZE, family=base_font["family"]),
|
| 351 |
+
tracegroupgap=8,
|
| 352 |
+
itemclick="toggle",
|
| 353 |
+
itemdoubleclick="toggleothers",
|
| 354 |
+
groupclick="toggleitem",
|
| 355 |
+
),
|
| 356 |
+
xaxis=dict(
|
| 357 |
+
showgrid=True,
|
| 358 |
+
gridcolor=GRID_COLOR,
|
| 359 |
+
gridwidth=1,
|
| 360 |
+
zeroline=False,
|
| 361 |
+
tickfont=dict(size=AXIS_TICK_FONT_SIZE),
|
| 362 |
+
range=[line_min, line_max],
|
| 363 |
+
constrain="domain",
|
| 364 |
+
),
|
| 365 |
+
yaxis=dict(
|
| 366 |
+
showgrid=True,
|
| 367 |
+
gridcolor=GRID_COLOR,
|
| 368 |
+
gridwidth=1,
|
| 369 |
+
zeroline=False,
|
| 370 |
+
tickfont=dict(size=AXIS_TICK_FONT_SIZE),
|
| 371 |
+
range=[line_min, line_max],
|
| 372 |
+
scaleanchor="x",
|
| 373 |
+
scaleratio=1,
|
| 374 |
+
constrain="domain",
|
| 375 |
+
),
|
| 376 |
+
margin=dict(
|
| 377 |
+
l=MARGIN_LEFT,
|
| 378 |
+
r=MARGIN_RIGHT,
|
| 379 |
+
t=MARGIN_TOP,
|
| 380 |
+
b=MARGIN_BOTTOM,
|
| 381 |
+
),
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
# ---------- сохранение ----------
|
| 385 |
+
save_dir = "visualizations"
|
| 386 |
+
if backbone:
|
| 387 |
+
save_dir = os.path.join(save_dir, "backbone")
|
| 388 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 389 |
+
|
| 390 |
+
postfix_html = f"{postfix}" if postfix else "syntax"
|
| 391 |
+
save_path_html = os.path.join(save_dir, f"{postfix_html}.html")
|
| 392 |
+
fig.write_html(save_path_html, include_mathjax="cdn")
|
| 393 |
+
print(f"Saved visualization with logistic trends: {save_path_html}")
|
inference/rnn_apply.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import tqdm
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import click
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import lightning.pytorch as pl
|
| 9 |
+
import sklearn.metrics as skm
|
| 10 |
+
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from torchvision.transforms import transforms as T
|
| 13 |
+
from torchvision.transforms._transforms_video import ToTensorVideo
|
| 14 |
+
from pytorchvideo.transforms import Normalize
|
| 15 |
+
|
| 16 |
+
# Импорты из соседних папок (относительные пути)
|
| 17 |
+
from full_model.rnn_dataset import SyntaxDataset
|
| 18 |
+
from full_model.rnn_model import SyntaxLightningModule
|
| 19 |
+
from metrics_visualization import visualize_final_syntax_plotly_multi
|
| 20 |
+
|
| 21 |
+
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
print(f"DEVICE: {DEVICE}")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def safe_sample_std(values):
|
| 26 |
+
"""Sample std (ddof=1). Если значение одно/пусто — 0.0."""
|
| 27 |
+
arr = np.array(values, dtype=float)
|
| 28 |
+
if arr.size <= 1:
|
| 29 |
+
return 0.0
|
| 30 |
+
return float(arr.std(ddof=1))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def compute_metrics(y_true, y_pred, thr=22.0):
|
| 34 |
+
"""R2, MAE, Pearson, MAPE, Mean_Recall."""
|
| 35 |
+
y_true_arr = np.array(y_true, dtype=float)
|
| 36 |
+
y_pred_arr = np.array(y_pred, dtype=float)
|
| 37 |
+
|
| 38 |
+
r2 = float(skm.r2_score(y_true_arr, y_pred_arr))
|
| 39 |
+
mae = float(skm.mean_absolute_error(y_true_arr, y_pred_arr))
|
| 40 |
+
|
| 41 |
+
pearson = float(np.corrcoef(y_true_arr, y_pred_arr)[0, 1]) if len(y_true_arr) > 1 else 0.0
|
| 42 |
+
mape = float(skm.mean_absolute_percentage_error(y_true_arr, y_pred_arr))
|
| 43 |
+
|
| 44 |
+
y_true_bin = (y_true_arr >= thr).astype(int)
|
| 45 |
+
y_pred_bin = (y_pred_arr >= thr).astype(int)
|
| 46 |
+
unique_classes = np.unique(np.concatenate([y_true_bin, y_pred_bin]))
|
| 47 |
+
mean_recall = float(np.mean(skm.recall_score(y_true_bin, y_pred_bin, average=None, labels=[0, 1]))) \
|
| 48 |
+
if len(unique_classes) > 1 else 0.0
|
| 49 |
+
|
| 50 |
+
return r2, mae, pearson, mape, mean_recall
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@click.command()
|
| 54 |
+
@click.option("-d", "--dataset-paths", multiple=True,
|
| 55 |
+
help="JSON с метаданными датасетов (относительно dataset_root).")
|
| 56 |
+
@click.option("-n", "--dataset-names", multiple=True,
|
| 57 |
+
help="Имена датасетов для метрик/графиков.")
|
| 58 |
+
@click.option("-p", "--postfixes", multiple=True,
|
| 59 |
+
help="Суффиксы для файлов предсказаний.")
|
| 60 |
+
@click.option("-r", "--dataset-root", type=click.Path(exists=True),
|
| 61 |
+
help="Корень датасета (где лежат JSON и DICOM).")
|
| 62 |
+
@click.option("-v", "--video-size", type=click.Tuple([int, int]),
|
| 63 |
+
help="Размер видео (H, W).")
|
| 64 |
+
@click.option("--frames-per-clip",
|
| 65 |
+
help="Количество кадров в клипе.")
|
| 66 |
+
@click.option("--num-workers",
|
| 67 |
+
help="Число DataLoader workers.")
|
| 68 |
+
@click.option("--seed",
|
| 69 |
+
help="Random seed.")
|
| 70 |
+
@click.option("--pt-weights-format", is_flag=True,
|
| 71 |
+
help="True → модели в .pt (torch.save), False → .ckpt (Lightning).")
|
| 72 |
+
@click.option("--use-scaling", is_flag=True,
|
| 73 |
+
help="Применить a*x+b scaling из JSON.")
|
| 74 |
+
@click.option("--scaling-file",
|
| 75 |
+
help="JSON с коэффициентами scaling (относительно dataset_root).")
|
| 76 |
+
@click.option("-e", "--ensemble-name",
|
| 77 |
+
help="Имя ансамбля в metrics.json.")
|
| 78 |
+
@click.option("-m", "--metrics-file",
|
| 79 |
+
help="JSON с метриками экспериментов.")
|
| 80 |
+
def main(dataset_paths, dataset_names, postfixes, dataset_root, video_size,
|
| 81 |
+
frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
|
| 82 |
+
scaling_file, ensemble_name, metrics_file):
|
| 83 |
+
|
| 84 |
+
pl.seed_everything(seed)
|
| 85 |
+
postfix_plotly = "Ensemble"
|
| 86 |
+
|
| 87 |
+
# Пути к моделям (backbone + RNN-head)
|
| 88 |
+
model_paths = {
|
| 89 |
+
"left": [
|
| 90 |
+
"full_model/checkpoints/leftBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
|
| 91 |
+
"full_model/checkpoints/leftBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
|
| 92 |
+
"full_model/checkpoints/leftBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
|
| 93 |
+
"full_model/checkpoints/leftBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
|
| 94 |
+
"full_model/checkpoints/leftBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
|
| 95 |
+
],
|
| 96 |
+
"right": [
|
| 97 |
+
"full_model/checkpoints/rightBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
|
| 98 |
+
"full_model/checkpoints/rightBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
|
| 99 |
+
"full_model/checkpoints/rightBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
|
| 100 |
+
"full_model/checkpoints/rightBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
|
| 101 |
+
"full_model/checkpoints/rightBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
|
| 102 |
+
]
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
# Scaling параметры
|
| 106 |
+
scaling_params_dict = {}
|
| 107 |
+
if use_scaling:
|
| 108 |
+
postfix_plotly += "_scaled"
|
| 109 |
+
ensemble_name += "_scaled"
|
| 110 |
+
scaling_path = os.path.join(dataset_root, scaling_file)
|
| 111 |
+
if os.path.exists(scaling_path):
|
| 112 |
+
with open(scaling_path, "r") as f:
|
| 113 |
+
scaling_params_dict = json.load(f)
|
| 114 |
+
print(f"Loaded scaling from {scaling_path}")
|
| 115 |
+
else:
|
| 116 |
+
print(f"⚠️ Scaling file not found: {scaling_path}")
|
| 117 |
+
|
| 118 |
+
# Результаты ансамбля
|
| 119 |
+
ensemble_results = {
|
| 120 |
+
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
| 121 |
+
"use_scaling": use_scaling,
|
| 122 |
+
"pt_weights_format": pt_weights_format,
|
| 123 |
+
"datasets": {}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
all_datasets, all_r2, all_recalls = {}, {}, {}
|
| 127 |
+
|
| 128 |
+
for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
|
| 129 |
+
# Относительные пути
|
| 130 |
+
abs_dataset_path = os.path.join(dataset_root, dataset_path)
|
| 131 |
+
results_file = os.path.join(dataset_root, "coeffs", f"{postfix}.json")
|
| 132 |
+
|
| 133 |
+
# Загрузка/вычисление предсказаний
|
| 134 |
+
if os.path.exists(results_file):
|
| 135 |
+
print(f"[{postfix}] Loading from {results_file}")
|
| 136 |
+
with open(results_file, "r") as f:
|
| 137 |
+
data = json.load(f)
|
| 138 |
+
syntax_true = data["syntax_true"]
|
| 139 |
+
left_preds_all = data["left_preds"]
|
| 140 |
+
right_preds_all = data["right_preds"]
|
| 141 |
+
else:
|
| 142 |
+
print(f"[{postfix}] Computing predictions...")
|
| 143 |
+
left_preds_all, left_sids = run_artery(
|
| 144 |
+
abs_dataset_path, "left", model_paths["left"],
|
| 145 |
+
video_size, frames_per_clip, num_workers, pt_weights_format
|
| 146 |
+
)
|
| 147 |
+
right_preds_all, right_sids = run_artery(
|
| 148 |
+
abs_dataset_path, "right", model_paths["right"],
|
| 149 |
+
video_size, frames_per_clip, num_workers, pt_weights_format
|
| 150 |
+
)
|
| 151 |
+
assert left_sids == right_sids
|
| 152 |
+
|
| 153 |
+
with open(abs_dataset_path, "r") as f:
|
| 154 |
+
dataset = json.load(f)
|
| 155 |
+
syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]
|
| 156 |
+
|
| 157 |
+
os.makedirs(os.path.dirname(results_file), exist_ok=True)
|
| 158 |
+
save_data = {
|
| 159 |
+
"syntax_true": syntax_true,
|
| 160 |
+
"left_preds": left_preds_all,
|
| 161 |
+
"right_preds": right_preds_all
|
| 162 |
+
}
|
| 163 |
+
with open(results_file, "w") as f:
|
| 164 |
+
json.dump(save_data, f)
|
| 165 |
+
print(f"[{postfix}] Saved to {results_file}")
|
| 166 |
+
|
| 167 |
+
# Scaling (fold-wise для left/right)
|
| 168 |
+
if use_scaling:
|
| 169 |
+
left_scaled_all, right_scaled_all = [], []
|
| 170 |
+
for pred_list in left_preds_all:
|
| 171 |
+
scaled = [scaling_params_dict.get(f"fold{i}", (1.0, 0.0))[0] * val +
|
| 172 |
+
scaling_params_dict.get(f"fold{i}", (1.0, 0.0))[1]
|
| 173 |
+
for i, val in enumerate(pred_list)]
|
| 174 |
+
left_scaled_all.append(scaled)
|
| 175 |
+
for pred_list in right_preds_all:
|
| 176 |
+
scaled = [scaling_params_dict.get(f"fold{i}", (1.0, 0.0))[0] * val +
|
| 177 |
+
scaling_params_dict.get(f"fold{i}", (1.0, 0.0))[1]
|
| 178 |
+
for i, val in enumerate(pred_list)]
|
| 179 |
+
right_scaled_all.append(scaled)
|
| 180 |
+
else:
|
| 181 |
+
left_scaled_all, right_scaled_all = left_preds_all, right_preds_all
|
| 182 |
+
|
| 183 |
+
# Ансамбль: mean по фолдам + left+right
|
| 184 |
+
syntax_pred = [max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
|
| 185 |
+
for l_list, r_list in zip(left_scaled_all, right_scaled_all)]
|
| 186 |
+
|
| 187 |
+
# Метрики ансамбля
|
| 188 |
+
r2, mae, pearson, mape, mean_recall = compute_metrics(syntax_true, syntax_pred)
|
| 189 |
+
print(f"[{postfix}] ENSEMBLE: R2={r2:.4f}, Pearson={pearson:.4f}, "
|
| 190 |
+
f"MAE={mae:.4f}, MAPE={mape:.4f}, Recall={mean_recall:.4f}")
|
| 191 |
+
|
| 192 |
+
# STD по фолдам
|
| 193 |
+
n_folds = len(left_scaled_all[0]) if left_scaled_all else 0
|
| 194 |
+
fold_metrics = {metric: [] for metric in ["R2", "MAE", "Pearson", "MAPE", "Mean_Recall"]}
|
| 195 |
+
for k in range(n_folds):
|
| 196 |
+
pred_k = [max(0.0, l_list[k] + r_list[k])
|
| 197 |
+
for l_list, r_list in zip(left_scaled_all, right_scaled_all)]
|
| 198 |
+
fold_r2, fold_mae, fold_pearson, fold_mape, fold_recall = compute_metrics(syntax_true, pred_k)
|
| 199 |
+
for metric, value in zip(fold_metrics.keys(),
|
| 200 |
+
[fold_r2, fold_mae, fold_pearson, fold_mape, fold_recall]):
|
| 201 |
+
fold_metrics[metric].append(value)
|
| 202 |
+
|
| 203 |
+
fold_summary = {k: {"mean": float(np.mean(v)), "std": safe_sample_std(v), "values": v}
|
| 204 |
+
for k, v in fold_metrics.items()}
|
| 205 |
+
|
| 206 |
+
# Визуализация и сохранение
|
| 207 |
+
all_datasets[dataset_name] = (syntax_true, syntax_pred)
|
| 208 |
+
all_r2[dataset_name] = r2
|
| 209 |
+
all_recalls[dataset_name] = mean_recall
|
| 210 |
+
|
| 211 |
+
ensemble_results["datasets"][dataset_name] = {
|
| 212 |
+
# Ансамбль
|
| 213 |
+
"R2": round(r2, 4), "MAE": round(mae, 4),
|
| 214 |
+
"Pearson": round(pearson, 4), "MAPE": round(mape, 4),
|
| 215 |
+
"Mean_Recall": round(mean_recall, 4), "N_samples": len(syntax_true),
|
| 216 |
+
# По фолдам (mean±std)
|
| 217 |
+
**{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
|
| 218 |
+
**{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
|
| 219 |
+
**{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()}
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
# Сохранение метрик
|
| 223 |
+
metrics_path = os.path.join(dataset_root, metrics_file)
|
| 224 |
+
full_history = {}
|
| 225 |
+
if os.path.exists(metrics_path):
|
| 226 |
+
try:
|
| 227 |
+
with open(metrics_path, "r") as f:
|
| 228 |
+
full_history = json.load(f)
|
| 229 |
+
except json.JSONDecodeError:
|
| 230 |
+
print("⚠️ Metrics file corrupted. Creating new.")
|
| 231 |
+
|
| 232 |
+
full_history[ensemble_name] = ensemble_results
|
| 233 |
+
with open(metrics_path, "w") as f:
|
| 234 |
+
json.dump(full_history, f, indent=4)
|
| 235 |
+
print(f"✅ Metrics saved: {metrics_path}")
|
| 236 |
+
|
| 237 |
+
# Визуализация
|
| 238 |
+
visualize_final_syntax_plotly_multi(
|
| 239 |
+
datasets=all_datasets, r2_values=all_r2, recall_values=all_recalls,
|
| 240 |
+
gt_row="ENSEMBLE", postfix=postfix_plotly
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
|
| 245 |
+
num_workers, pt_weights_format=False):
|
| 246 |
+
"""Инференс для одной артерии (5 фолдов)."""
|
| 247 |
+
imagenet_mean = [0.485, 0.456, 0.406]
|
| 248 |
+
imagenet_std = [0.229, 0.224, 0.225]
|
| 249 |
+
test_transform = T.Compose([
|
| 250 |
+
ToTensorVideo(),
|
| 251 |
+
T.Resize(size=video_size, antialias=True),
|
| 252 |
+
Normalize(mean=imagenet_mean, std=imagenet_std),
|
| 253 |
+
])
|
| 254 |
+
|
| 255 |
+
val_set = SyntaxDataset(
|
| 256 |
+
root=os.path.dirname(dataset_path),
|
| 257 |
+
meta=dataset_path,
|
| 258 |
+
train=False,
|
| 259 |
+
length=frames_per_clip,
|
| 260 |
+
label="", # inference mode
|
| 261 |
+
artery=artery,
|
| 262 |
+
inference=True,
|
| 263 |
+
transform=test_transform
|
| 264 |
+
)
|
| 265 |
+
val_loader = DataLoader(val_set, batch_size=1, num_workers=num_workers,
|
| 266 |
+
shuffle=False, pin_memory=True)
|
| 267 |
+
print(f"{artery} artery: {len(val_loader)} samples")
|
| 268 |
+
|
| 269 |
+
models = []
|
| 270 |
+
for path in model_paths:
|
| 271 |
+
if not os.path.exists(path):
|
| 272 |
+
print(f"⚠️ Model not found: {path}")
|
| 273 |
+
continue
|
| 274 |
+
model = SyntaxLightningModule(
|
| 275 |
+
num_classes=2, lr=1e-5, variant="lstm_mean",
|
| 276 |
+
weight_decay=0.001, max_epochs=1,
|
| 277 |
+
pl_weight_path=path, pt_weights_format=pt_weights_format
|
| 278 |
+
)
|
| 279 |
+
model.to(DEVICE)
|
| 280 |
+
model.eval()
|
| 281 |
+
models.append(model)
|
| 282 |
+
if not models:
|
| 283 |
+
raise RuntimeError(f"No models loaded for {artery}")
|
| 284 |
+
|
| 285 |
+
preds_all, sids = [], []
|
| 286 |
+
with torch.no_grad():
|
| 287 |
+
for x, [y], [t], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
|
| 288 |
+
if len(x.shape) == 1: # пустое видео
|
| 289 |
+
val_syntax_list = [0.0] * len(models)
|
| 290 |
+
else:
|
| 291 |
+
x = x.to(DEVICE)
|
| 292 |
+
val_syntax_list = []
|
| 293 |
+
for model in models:
|
| 294 |
+
pred = model(x)
|
| 295 |
+
_, val_log = pred # регрессионный logit
|
| 296 |
+
val = float(torch.exp(val_log).cpu()) - 1
|
| 297 |
+
val_syntax_list.append(val)
|
| 298 |
+
preds_all.append(val_syntax_list)
|
| 299 |
+
sids.append(sid[0]) # study_uid
|
| 300 |
+
|
| 301 |
+
return preds_all, sids
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
if __name__ == "__main__":
|
| 305 |
+
main()
|