# syntax-model / full_model / rnn_model.py
# Author: MesserMMP — commit c2d9714 ("add full model files")
from typing import Any
import torch
import torch.nn.functional as F
from torch import nn, optim
import lightning.pytorch as pl
import torchvision.models.video as tvmv
import sklearn.metrics as skm
import numpy as np
class SyntaxLightningModule(pl.LightningModule):
    """
    Full model: 3D-ResNet (r3d_18) backbone + sequence head for the SYNTAX score.

    Each sample is a sequence of videos. The backbone embeds every video
    independently, and the head aggregates the per-video embeddings into
    ``num_classes`` outputs. Column 0 of the output is used as a
    classification logit and columns ``1:`` as regression predictions
    (see ``training_step``).

    Head variants (``variant``):
      - mean_out: mean over the backbone's own fc outputs
      - mean: mean of embeddings + FC
      - lstm_mean / lstm_last: LSTM (mean over time / last hidden state)
      - gru_mean / gru_last: GRU (mean over time / last hidden state)
      - bert_mean / bert_cls / bert_cls2: Transformer encoder
        (mean pooling / prepended zero token / prepended learnable token)
    """

    SUPPORTED_VARIANTS = [
        "mean_out", "mean", "lstm_mean", "lstm_last",
        "gru_mean", "gru_last", "bert_mean", "bert_cls", "bert_cls2",
    ]

    def __init__(
        self,
        num_classes: int,
        lr: float,
        variant: str,
        weight_decay: float = 0.0,
        max_epochs: int = None,
        weight_path: str = None,  # path to a backbone checkpoint (.ckpt)
        pl_weight_path: str = None,  # path to a full-model checkpoint (.ckpt or .pt)
        pt_weights_format: bool = False,  # True -> .pt (torch.save), False -> Lightning .ckpt
        sigma_a: float = 0.0,
        sigma_b: float = 1.0,
        **kwargs,
    ):
        """
        Args:
            num_classes: number of head outputs (column 0 = classification
                logit, remaining columns = regression outputs).
            lr: Adam learning rate; also the OneCycleLR peak when
                ``max_epochs`` is set.
            variant: head variant, one of ``SUPPORTED_VARIANTS``.
            weight_decay: Adam weight decay.
            max_epochs: if given, a OneCycleLR schedule with
                ``total_steps=max_epochs`` is configured.
            weight_path: backbone checkpoint (Lightning .ckpt). When set,
                ``configure_optimizers`` trains only the head.
            pl_weight_path: full-model checkpoint to load after the head is built.
            pt_weights_format: True -> ``pl_weight_path`` holds a state_dict
                saved with ``torch.save``; False -> a Lightning .ckpt whose
                ``"state_dict"`` entry is used.
            sigma_a, sigma_b: the squared regression error is divided by
                ``(sigma_a * target + sigma_b) ** 2``.
            **kwargs: extra hyperparameters; not used by this class beyond
                ``save_hyperparameters``.

        Raises:
            ValueError: if ``variant`` is not in ``SUPPORTED_VARIANTS``.
        """
        super().__init__()
        self.save_hyperparameters()

        if variant not in self.SUPPORTED_VARIANTS:
            raise ValueError(f"variant must be one of {self.SUPPORTED_VARIANTS}")

        self.num_classes = num_classes
        self.variant = variant
        self.lr = lr
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        # Read by configure_optimizers to choose head-only vs. full training;
        # save_hyperparameters only exposes it as self.hparams.weight_path.
        self.weight_path = weight_path
        self.sigma_a = sigma_a
        self.sigma_b = sigma_b

        # Backbone: 3D-ResNet-18 initialised with torchvision's DEFAULT weights.
        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)
        in_features = self.model.fc.in_features
        if variant != "mean_out":
            # Expose raw embeddings; the head built below maps them to outputs.
            self.model.fc = nn.Identity()
        else:
            # mean_out averages the backbone's own fc outputs, so fc produces
            # num_classes values, like every other head.
            self.model.fc = nn.Linear(in_features, num_classes, bias=True)

        if weight_path is not None:
            print(f"Loading backbone weights from {weight_path}")
            self.load_weights_backbone(weight_path, self.model)

        self._init_head(in_features, num_classes)

        if pl_weight_path is not None:
            print(f"Loading full model weights from {pl_weight_path} (pt_format={pt_weights_format})")
            self.load_full_model(pl_weight_path, pt_weights_format)

        # Element-wise losses, so per-sample weights can be applied before reduction.
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Per-epoch validation buffers: labels, probabilities, rounded
        # predictions, regression targets, regression predictions.
        self.y_val, self.p_val, self.r_val = [], [], []
        self.ty_val, self.tp_val = [], []

    def _init_head(self, in_features: int, num_classes: int):
        """Create the head modules required by ``self.variant``."""
        if self.variant == "mean_out":
            return  # the head is self.model.fc
        if self.variant in ("gru_mean", "gru_last"):
            self.rnn = nn.GRU(in_features, in_features // 4, batch_first=True)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features // 4, num_classes, bias=True)
        elif self.variant in ("lstm_mean", "lstm_last"):
            # proj_size=num_classes: the LSTM outputs the final predictions directly.
            self.lstm = nn.LSTM(
                input_size=in_features,
                hidden_size=in_features // 4,
                proj_size=num_classes,
                batch_first=True,
            )
        elif self.variant == "mean":
            self.fc = nn.Linear(in_features, num_classes, bias=True)
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=in_features,
                nhead=4,
                batch_first=True,
                dim_feedforward=in_features // 4,
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features, num_classes, bias=True)
            if self.variant == "bert_cls2":
                # Learnable token prepended to every sequence in forward().
                self.cls = nn.Parameter(torch.randn(1, 1, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, N_videos, C, T, H, W)
        -> backbone per video: (batch, N_videos, embed_dim) -> head -> (batch, num_classes)
        """
        batch_size, seq_len = x.shape[:2]
        x = torch.flatten(x, start_dim=0, end_dim=1)  # (batch*seq, C, T, H, W)
        x = self.model(x)  # (batch*seq, embed_dim)
        x = torch.unflatten(x, 0, (batch_size, seq_len))  # (batch, seq, embed_dim)

        if self.variant == "mean_out":
            x = torch.mean(x, dim=1)
        elif self.variant in ("gru_mean", "gru_last"):
            all_outs, h_n = self.rnn(x)
            # h_n is (num_layers, batch, hidden) even with batch_first=True;
            # h_n[-1] is the last layer's final state, shaped (batch, hidden).
            x = torch.mean(all_outs, dim=1) if self.variant == "gru_mean" else h_n[-1]
            x = self.dropout(x)
            x = self.fc(x)
        elif self.variant in ("lstm_mean", "lstm_last"):
            all_outs, (h_n, _) = self.lstm(x)
            # Same layout as GRU: take the last layer -> (batch, num_classes).
            x = torch.mean(all_outs, dim=1) if self.variant == "lstm_mean" else h_n[-1]
        elif self.variant == "mean":
            x = torch.mean(x, dim=1)
            x = self.fc(x)
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            if self.variant == "bert_cls":
                # Prepend an all-zero token along the sequence dimension.
                x = F.pad(x, (0, 0, 1, 0), "constant", 0)
            elif self.variant == "bert_cls2":
                x = torch.cat([self.cls.expand(x.size(0), -1, -1), x], dim=1)
            x = self.encoder(x)
            x = torch.mean(x, dim=1) if self.variant == "bert_mean" else x[:, 0, :]
            x = self.dropout(x)
            x = self.fc(x)
        return x

    def _compute_loss(self, yp_clf, yp_reg, y, target, clf_weights=None):
        """
        Return ``(clf_loss, reg_loss, total_loss)``.

        clf_loss is the mean BCE-with-logits, optionally multiplied element-wise
        by ``clf_weights`` before averaging. reg_loss is the mean squared error
        divided by ``(sigma_a * target + sigma_b) ** 2``.
        total_loss is ``clf_loss + 0.5 * reg_loss``.
        """
        clf_loss = self.loss_clf(yp_clf, y)
        if clf_weights is not None:
            clf_loss = clf_loss * clf_weights
        clf_loss = clf_loss.mean()

        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (self.loss_reg(yp_reg, target) / (sigma ** 2)).mean()
        return clf_loss, reg_loss, clf_loss + 0.5 * reg_loss

    def training_step(self, batch, batch_idx):
        x, y, target, _path = batch
        y_hat = self(x)
        yp_clf, yp_reg = y_hat[:, 0:1], y_hat[:, 1:]

        # BCE with negative examples (y <= 0) down-weighted to 0.45.
        weights_clf = torch.where(y > 0, 1.0, 0.45)
        clf_loss, reg_loss, loss = self._compute_loss(yp_clf, yp_reg, y, target, weights_clf)

        y_pred = torch.sigmoid(yp_clf)
        y_bin = torch.round(y.detach().cpu()).int().flatten()
        y_pred_bin = torch.round(y_pred.detach().cpu()).int().flatten()

        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        # NOTE: "train_val_loss" holds the regression loss (key kept for compatibility).
        self.log("train_val_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_full_loss", loss, prog_bar=True, sync_dist=True)
        self.log("train_f1", skm.f1_score(y_bin, y_pred_bin, zero_division=0),
                 prog_bar=True, sync_dist=True)
        self.log("train_acc", skm.accuracy_score(y_bin, y_pred_bin),
                 prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, target, _path = batch
        y_hat = self(x)
        yp_clf, yp_reg = y_hat[:, 0:1], y_hat[:, 1:]

        # Accumulate every sample of the batch (works for any batch size).
        probs = torch.sigmoid(yp_clf)[..., 0].detach().cpu().tolist()
        self.y_val.extend(y[..., 0].detach().cpu().long().tolist())
        self.p_val.extend(probs)
        self.r_val.extend(round(p) for p in probs)
        self.ty_val.extend(target[..., 0].detach().cpu().tolist())
        self.tp_val.extend(yp_reg[..., 0].detach().cpu().tolist())

        # Same loss form as training, but without down-weighting negatives.
        _, _, loss = self._compute_loss(yp_clf, yp_reg, y, target)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Log epoch-level validation metrics from the buffers, then reset them."""
        try:
            if self.y_val:
                self._log_validation_metrics()
        finally:
            for buf in (self.y_val, self.p_val, self.r_val, self.ty_val, self.tp_val):
                buf.clear()

    def _log_validation_metrics(self):
        """Log AUC/F1/accuracy (classification) and MAE/RMSE (regression)."""
        try:
            auc = skm.roc_auc_score(self.y_val, self.p_val)
        except ValueError as err:
            # e.g. only one class present; keep logging the remaining metrics.
            print(err)
            print("Y_VAL", self.y_val[:10], "...")
            print("P_VAL", self.p_val[:10], "...")
        else:
            self.log("val_auc", auc, prog_bar=True, sync_dist=True)

        try:
            metrics = {
                "val_f1": skm.f1_score(self.y_val, self.r_val, zero_division=0),
                "val_acc": skm.accuracy_score(self.y_val, self.r_val),
                # Regression error, computed on the same pairs as val_rmse.
                "val_mae": skm.mean_absolute_error(self.ty_val, self.tp_val),
                "val_rmse": skm.root_mean_squared_error(self.ty_val, self.tp_val),
            }
        except ValueError as err:
            print(err)
            return
        for name, value in metrics.items():
            self.log(name, value, prog_bar=True, sync_dist=True)

    def on_train_epoch_end(self):
        lr = self.optimizers().optimizer.param_groups[0]["lr"]
        self.log("lr", lr, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        """
        Adam over either the head only (when ``weight_path`` was given) or all
        parameters. If ``max_epochs`` is set, also return a OneCycleLR scheduler
        with ``total_steps=max_epochs``. That value assumes one scheduler step
        per epoch, Lightning's default interval for schedulers returned this way.
        """
        if self.weight_path:
            # Backbone pretrained separately: freeze everything except the head.
            params = self._get_trainable_head_params()
            for param in self.parameters():
                param.requires_grad = False
            for param in params:
                param.requires_grad = True
        else:
            for param in self.parameters():
                param.requires_grad = True
            params = self.parameters()

        optimizer = optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)
        if self.max_epochs:
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=self.lr, total_steps=self.max_epochs
            )
            return [optimizer], [scheduler]
        return optimizer

    def _head_modules_by_prefix(self):
        """Map state_dict key prefix -> head module for the current variant."""
        if self.variant == "mean_out":
            return {"model.fc": self.model.fc}
        if self.variant in ("gru_mean", "gru_last"):
            return {"rnn": self.rnn, "fc": self.fc}
        if self.variant in ("lstm_mean", "lstm_last"):
            return {"lstm": self.lstm}
        if self.variant == "mean":
            return {"fc": self.fc}
        if self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            return {"encoder": self.encoder, "fc": self.fc}
        return {}

    def _get_trainable_head_modules(self):
        """Return the head's nn.Module objects (the bert_cls2 token is a Parameter, not included)."""
        return list(self._head_modules_by_prefix().values())

    def _get_trainable_head_params(self):
        """Return all trainable head parameters, including the bert_cls2 token."""
        params = [p for module in self._get_trainable_head_modules() for p in module.parameters()]
        if self.variant == "bert_cls2":
            params.append(self.cls)
        return params

    def load_weights_backbone(self, weight_path: str, model):
        """
        Load backbone weights from a Lightning .ckpt into ``model``.

        Every occurrence of ``"model."`` is removed from the keys, and loading
        is non-strict, so keys that do not match are ignored.
        """
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        ckpt = torch.load(weight_path, map_location="cpu", weights_only=False)
        state_dict = ckpt["state_dict"]
        new_state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
        model.load_state_dict(new_state_dict, strict=False)

    def load_full_model(self, pl_weight_path: str, pt_weights_format: bool):
        """
        Load backbone and head weights from a full-model checkpoint.

        With ``pt_weights_format`` the file itself is the state_dict; otherwise
        its ``"state_dict"`` entry (Lightning .ckpt) is used. Keys are matched
        by this module's attribute names (``model.*``, ``rnn.*``, ``lstm.*``,
        ``encoder.*``, ``fc.*``, ``cls``).
        """
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        checkpoint = torch.load(pl_weight_path, map_location="cpu", weights_only=False)
        state_dict = checkpoint if pt_weights_format else checkpoint["state_dict"]

        # Backbone (also covers model.fc for mean_out).
        self.load_weights(state_dict, self.model, "model")

        # Head modules, keyed by their attribute name in this module.
        for prefix, module in self._head_modules_by_prefix().items():
            if prefix.startswith("model."):
                continue  # already loaded together with the backbone
            self.load_weights(state_dict, module, prefix)

        if self.variant == "bert_cls2" and "cls" in state_dict:
            self.cls.data.copy_(state_dict["cls"])

    def load_weights(self, state_dict, module, prefix: str):
        """
        Load the entries of ``state_dict`` that start with ``prefix + "."`` into
        ``module``, with the prefix stripped. Loading is non-strict; the numbers
        of missing/unexpected keys are printed.
        """
        head = f"{prefix}."
        module_state = {
            key[len(head):]: value
            for key, value in state_dict.items()
            if key.startswith(head)
        }
        missing, unexpected = module.load_state_dict(module_state, strict=False)
        if missing:
            print(f"Missing keys for {prefix}: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys for {prefix}: {len(unexpected)}")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        x, y, target, _path = batch
        y_hat = self(x)
        yp_clf, yp_reg = y_hat[:, 0:1], y_hat[:, 1:]
        y_prob = torch.sigmoid(yp_clf)
        return {
            "y": y,
            "y_pred": torch.round(y_prob),
            "y_prob": y_prob,
            "y_reg": yp_reg,
            "target": target,
        }