from typing import Any

import torch
import torch.nn.functional as F
from torch import nn, optim
import lightning.pytorch as pl
import torchvision.models.video as tvmv
import sklearn.metrics as skm
import numpy as np


class SyntaxLightningModule(pl.LightningModule):
    """
    Full model: 3D-ResNet backbone + RNN/Transformer head for SYNTAX score.

    Head variants (``variant``):
      - mean_out: mean over backbone outputs
      - mean: mean of embeddings + FC
      - lstm_mean/lstm_last: LSTM (mean/last)
      - gru_mean/gru_last: GRU (mean/last)
      - bert_mean/bert_cls/bert_cls2: Transformer encoder
    """

    SUPPORTED_VARIANTS = [
        "mean_out", "mean",
        "lstm_mean", "lstm_last",
        "gru_mean", "gru_last",
        "bert_mean", "bert_cls", "bert_cls2",
    ]

    def __init__(
        self,
        num_classes: int,
        lr: float,
        variant: str,
        weight_decay: float = 0.0,
        max_epochs: int | None = None,
        weight_path: str | None = None,       # path to a backbone checkpoint (.ckpt)
        pl_weight_path: str | None = None,    # path to a full-model checkpoint (.ckpt or .pt)
        pt_weights_format: bool = False,      # True -> .pt format (torch.save), False -> Lightning .ckpt
        sigma_a: float = 0.0,
        sigma_b: float = 1.0,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        # Validate the requested head variant early.
        if variant not in self.SUPPORTED_VARIANTS:
            raise ValueError(f"variant must be one of {self.SUPPORTED_VARIANTS}")

        self.num_classes = num_classes
        self.variant = variant
        self.lr = lr
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.sigma_a = sigma_a
        self.sigma_b = sigma_b
        # BUG FIX: configure_optimizers() reads self.weight_path, but the
        # original never stored it on the instance (only in hparams), which
        # raised AttributeError at training start.
        self.weight_path = weight_path

        # Backbone: 3D-ResNet
        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)
        in_features = self.model.fc.in_features

        # For most heads, replace fc with Identity (expose embeddings);
        # "mean_out" keeps a 2-logit fc inside the backbone instead.
        if variant != "mean_out":
            self.model.fc = nn.Identity()
        else:
            self.model.fc = nn.Linear(in_features, 2, bias=True)

        # Optionally load backbone weights.
        if weight_path is not None:
            print(f"Loading backbone weights from {weight_path}")
            self.load_weights_backbone(weight_path, self.model)

        # Build the head for the chosen variant.
        self._init_head(in_features, num_classes)

        # Optionally load a full-model checkpoint (backbone + head).
        if pl_weight_path is not None:
            print(f"Loading full model weights from {pl_weight_path} (pt_format={pt_weights_format})")
            self.load_full_model(pl_weight_path, pt_weights_format)

        # Losses: per-element so training_step can reweight them.
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Validation metric buffers:
        # y_val/p_val/r_val — classification label / probability / rounded prediction,
        # ty_val/tp_val — regression target / prediction.
        self.y_val, self.p_val, self.r_val = [], [], []
        self.ty_val, self.tp_val = [], []

    def _init_head(self, in_features: int, num_classes: int) -> None:
        """Create the head submodules for the configured variant."""
        if self.variant == "mean_out":
            return  # uses self.model.fc
        elif self.variant in ("gru_mean", "gru_last"):
            self.rnn = nn.GRU(in_features, in_features // 4, batch_first=True)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features // 4, num_classes, bias=True)
        elif self.variant in ("lstm_mean", "lstm_last"):
            # proj_size makes the LSTM emit num_classes directly — no extra fc.
            self.lstm = nn.LSTM(
                input_size=in_features,
                hidden_size=in_features // 4,
                proj_size=num_classes,
                batch_first=True,
            )
        elif self.variant == "mean":
            self.fc = nn.Linear(in_features, num_classes, bias=True)
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=in_features,
                nhead=4,
                batch_first=True,
                dim_feedforward=in_features // 4,
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features, num_classes, bias=True)
            if self.variant == "bert_cls2":
                # Learned CLS token prepended to the sequence.
                self.cls = nn.Parameter(torch.randn(1, 1, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, N_videos, C, T, H, W)
          -> backbone per video -> (batch, N_videos, embed_dim)
          -> head -> (batch, num_classes)
        """
        batch_size, seq_len, *video_shape = x.shape
        x = torch.flatten(x, start_dim=0, end_dim=1)        # (batch*seq, C, T, H, W)
        x = self.model(x)                                   # (batch*seq, embed_dim)
        x = torch.unflatten(x, 0, (batch_size, seq_len))    # (batch, seq, embed_dim)

        # Head
        if self.variant == "mean_out":
            x = torch.mean(x, dim=1)  # mean over the sequence of backbone logits
        elif self.variant in ("gru_mean", "gru_last"):
            all_outs, last_out = self.rnn(x)
            x = torch.mean(all_outs, dim=1) if self.variant == "gru_mean" else last_out
            x = self.dropout(x)
            x = self.fc(x)
        elif self.variant in ("lstm_mean", "lstm_last"):
            all_outs, (last_out, _) = self.lstm(x)
            x = torch.mean(all_outs, dim=1) if self.variant == "lstm_mean" else last_out
        elif self.variant == "mean":
            x = torch.mean(x, dim=1)
            x = self.fc(x)
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            if self.variant == "bert_cls":
                x = F.pad(x, (0, 0, 1, 0), "constant", 0)  # prepend zero CLS
            elif self.variant == "bert_cls2":
                bs = x.size(0)
                x = torch.cat([self.cls.expand(bs, -1, -1), x], dim=1)
            x = self.encoder(x)
            x = torch.mean(x, dim=1) if self.variant == "bert_mean" else x[:, 0, :]
            x = self.dropout(x)
            x = self.fc(x)
        return x

    def training_step(self, batch, batch_idx):
        """Joint BCE (classification) + variance-scaled MSE (regression) loss."""
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf, yp_reg = y_hat[:, 0:1], y_hat[:, 1:]

        # BCE with down-weighting of negative examples.
        weights_clf = torch.where(y > 0, 1.0, 0.45)
        clf_loss = (self.loss_clf(yp_clf, y) * weights_clf).mean()

        # Regression loss scaled by a target-dependent variance model.
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        # Logging
        y_pred = torch.sigmoid(yp_clf)
        y_bin = torch.round(y.detach().cpu()).int()
        y_pred_bin = torch.round(y_pred.detach().cpu()).int()
        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        self.log("train_val_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_full_loss", loss, prog_bar=True, sync_dist=True)
        self.log("train_f1", skm.f1_score(y_bin, y_pred_bin, zero_division=0), prog_bar=True, sync_dist=True)
        self.log("train_acc", skm.accuracy_score(y_bin, y_pred_bin), prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate per-sample predictions; return the same loss as training."""
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf, yp_reg = y_hat[:, 0:1], y_hat[:, 1:]

        # Accumulate for epoch-level metrics (assumes batch_size == 1 here —
        # int()/float() on the slices require single-element tensors).
        y_pred = torch.sigmoid(yp_clf)
        self.y_val.append(int(y[..., 0].cpu()))
        self.p_val.append(float(y_pred[..., 0].cpu()))
        self.r_val.append(round(float(y_pred[..., 0].cpu())))
        self.ty_val.append(float(target[..., 0].cpu()))
        self.tp_val.append(float(yp_reg[..., 0].cpu()))

        # Same loss composition as training (without the negative down-weight).
        clf_loss = self.loss_clf(yp_clf, y).mean()
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()
        loss = clf_loss + 0.5 * reg_loss
        return loss

    def on_validation_epoch_end(self):
        """Compute epoch-level metrics from the accumulated buffers, then clear them."""
        try:
            auc = skm.roc_auc_score(self.y_val, self.p_val)
            f1 = skm.f1_score(self.y_val, self.r_val, zero_division=0)
            acc = skm.accuracy_score(self.y_val, self.r_val)
            # BUG FIX: MAE was computed on the binary classification buffers
            # (y_val/r_val), which merely duplicates 1 - accuracy. Use the
            # regression target/prediction pair, consistent with val_rmse.
            mae = skm.mean_absolute_error(self.ty_val, self.tp_val)
            rmse = skm.root_mean_squared_error(self.ty_val, self.tp_val)
            self.log("val_auc", auc, prog_bar=True, sync_dist=True)
            self.log("val_f1", f1, prog_bar=True, sync_dist=True)
            self.log("val_acc", acc, prog_bar=True, sync_dist=True)
            self.log("val_mae", mae, prog_bar=True, sync_dist=True)
            self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)
        except ValueError as err:
            # e.g. roc_auc_score with a single class present — log and continue.
            print(err)
            print("Y_VAL", self.y_val[:10], "...")
            print("P_VAL", self.p_val[:10], "...")
        finally:
            # Clear the buffers even if metric computation failed.
            for buf in (self.y_val, self.p_val, self.r_val, self.ty_val, self.tp_val):
                buf.clear()

    def on_train_epoch_end(self):
        """Log the current learning rate once per epoch."""
        lr = self.optimizers().optimizer.param_groups[0]["lr"]
        self.log("lr", lr, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        """
        Pretrain mode (weight_path set): freeze everything, train only the head.
        Full mode: fine-tune all parameters. Optional OneCycleLR over max_epochs.
        """
        if self.weight_path:
            trainable = self._get_trainable_head_modules()
            for param in self.parameters():
                param.requires_grad = False
            params = []
            for module in trainable:
                # BUG FIX: for bert_cls2 the list contains the CLS nn.Parameter,
                # which is a Tensor and has no .parameters() — handle it directly.
                if isinstance(module, nn.Parameter):
                    module.requires_grad = True
                    params.append(module)
                else:
                    for param in module.parameters():
                        param.requires_grad = True
                        params.append(param)
        else:
            # Full fine-tune: everything trainable.
            for param in self.parameters():
                param.requires_grad = True
            params = self.parameters()

        optimizer = optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)
        if self.max_epochs:
            # Stepped once per epoch by Lightning's default scheduler interval,
            # hence total_steps == max_epochs.
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=self.lr, total_steps=self.max_epochs
            )
            return [optimizer], [scheduler]
        return optimizer

    def _get_trainable_head_modules(self):
        """Return the head modules (and, for bert_cls2, the CLS Parameter) to train."""
        if self.variant == "mean_out":
            return [self.model.fc]
        elif self.variant in ("gru_mean", "gru_last"):
            return [self.rnn, self.fc]
        elif self.variant in ("lstm_mean", "lstm_last"):
            return [self.lstm]
        elif self.variant == "mean":
            return [self.fc]
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            modules = [self.encoder, self.fc]
            if self.variant == "bert_cls2":
                modules.append(self.cls)
            return modules
        return []

    def load_weights_backbone(self, weight_path: str, model) -> None:
        """Load backbone weights from a Lightning .ckpt (keys prefixed "model.")."""
        ckpt = torch.load(weight_path, map_location="cpu", weights_only=False)
        state_dict = ckpt["state_dict"]
        # BUG FIX: str.replace("model.", "") replaced ALL occurrences of the
        # substring (mangling keys like "submodel.…") and kept non-backbone
        # keys. Filter on the prefix and strip it exactly once instead.
        new_state_dict = {
            k[len("model."):]: v
            for k, v in state_dict.items()
            if k.startswith("model.")
        }
        model.load_state_dict(new_state_dict, strict=False)

    def load_full_model(self, pl_weight_path: str, pt_weights_format: bool) -> None:
        """Load a full-model checkpoint (.ckpt or raw .pt state dict)."""
        if pt_weights_format:
            # .pt format (plain torch.save of a state dict)
            state_dict = torch.load(pl_weight_path, map_location="cpu", weights_only=False)
        else:
            # Lightning .ckpt
            ckpt = torch.load(pl_weight_path, map_location="cpu", weights_only=False)
            state_dict = ckpt["state_dict"]

        # Backbone
        self.load_weights(state_dict, self.model, "model")

        # Head. BUG FIX: the original derived the key prefix from the module's
        # class name (e.g. "gru", "transformerencoder", "linear"), which never
        # matches the attribute-based checkpoint keys ("rnn.", "encoder.",
        # "fc."). Resolve each module to its attribute name instead.
        child_names = {id(mod): name for name, mod in self.named_children()}
        for module in self._get_trainable_head_modules():
            if isinstance(module, nn.Parameter):
                continue  # the CLS Parameter is restored below
            prefix = child_names.get(id(module), module.__class__.__name__.lower())
            self.load_weights(state_dict, module, prefix)

        if self.variant == "bert_cls2" and "cls" in state_dict:
            self.cls.data.copy_(state_dict["cls"])

    def load_weights(self, state_dict, module, prefix: str) -> None:
        """Load a module's weights from state_dict entries under "{prefix}."."""
        # BUG FIX: match the dotted prefix so e.g. prefix "fc" cannot also
        # capture keys of an unrelated "fc2" module.
        dotted = f"{prefix}."
        module_state = {
            k[len(dotted):]: v
            for k, v in state_dict.items()
            if k.startswith(dotted)
        }
        missing, unexpected = module.load_state_dict(module_state, strict=False)
        if missing:
            print(f"Missing keys for {prefix}: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys for {prefix}: {len(unexpected)}")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Return classification probability/label and regression prediction."""
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf, yp_reg = y_hat[:, 0:1], y_hat[:, 1:]
        y_prob = torch.sigmoid(yp_clf)
        return {
            "y": y,
            "y_pred": torch.round(y_prob),
            "y_prob": y_prob,
            "y_reg": yp_reg,
            "target": target,
        }