from typing import Any, Optional

import lightning.pytorch as pl
import numpy as np
import sklearn.metrics as skm
import torch
import torch.nn.functional as F
import torchvision.models.video as tvmv
from torch import nn, optim
| |
|
| |
|
class SyntaxLightningModule(pl.LightningModule):
    """
    Full model: 3D-ResNet backbone + RNN/Transformer head for the SYNTAX score.

    Input is a batch of video *sets*, shape (batch, N_videos, C, T, H, W).
    Each video is embedded independently by an r3d_18 backbone and the
    per-video embeddings are aggregated by the selected head.

    Head variants (`variant`):
        - mean_out: mean over raw backbone outputs (backbone keeps its FC)
        - mean: mean of embeddings + FC
        - lstm_mean / lstm_last: LSTM (mean of outputs / last hidden state)
        - gru_mean / gru_last: GRU (mean of outputs / last hidden state)
        - bert_mean / bert_cls / bert_cls2: Transformer encoder
          (mean pooling / zero CLS position / learned CLS token)

    The head emits `num_classes` values; column 0 is a classification logit
    and the remaining column(s) are the regression output, as consumed by
    `training_step` / `validation_step`.
    """

    SUPPORTED_VARIANTS = [
        "mean_out", "mean", "lstm_mean", "lstm_last",
        "gru_mean", "gru_last", "bert_mean", "bert_cls", "bert_cls2",
    ]

    def __init__(
        self,
        num_classes: int,
        lr: float,
        variant: str,
        weight_decay: float = 0.0,
        max_epochs: Optional[int] = None,
        weight_path: Optional[str] = None,
        pl_weight_path: Optional[str] = None,
        pt_weights_format: bool = False,
        sigma_a: float = 0.0,
        sigma_b: float = 1.0,
        **kwargs,
    ):
        """
        Args:
            num_classes: head output size (1 clf logit + regression columns).
            lr: peak learning rate for Adam / OneCycleLR.
            variant: one of SUPPORTED_VARIANTS.
            weight_decay: Adam weight decay.
            max_epochs: when set, a OneCycleLR schedule over this many epochs
                is attached (stepped once per epoch by Lightning's default).
            weight_path: Lightning .ckpt with backbone weights; when given,
                the backbone is frozen and only the head is trained.
            pl_weight_path: checkpoint with weights for the full model.
            pt_weights_format: `pl_weight_path` is a raw state_dict (.pt)
                rather than a Lightning checkpoint.
            sigma_a, sigma_b: heteroscedastic MSE weighting,
                sigma = sigma_a * target + sigma_b.

        Raises:
            ValueError: if `variant` is not supported.
        """
        super().__init__()
        self.save_hyperparameters()

        if variant not in self.SUPPORTED_VARIANTS:
            raise ValueError(f"variant must be one of {self.SUPPORTED_VARIANTS}")

        self.num_classes = num_classes
        self.variant = variant
        self.lr = lr
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs
        self.sigma_a = sigma_a
        self.sigma_b = sigma_b
        # BUGFIX: configure_optimizers reads self.weight_path, but it was
        # never stored on the instance (save_hyperparameters() does not
        # create attributes), which raised AttributeError.
        self.weight_path = weight_path

        self.model = tvmv.r3d_18(weights=tvmv.R3D_18_Weights.DEFAULT)
        in_features = self.model.fc.in_features

        if variant != "mean_out":
            # Backbone becomes a pure feature extractor; the head predicts.
            self.model.fc = nn.Identity()
        else:
            # BUGFIX/consistency: output size was hard-coded to 2; use
            # num_classes like every other head (callers pass 2 here).
            self.model.fc = nn.Linear(in_features, num_classes, bias=True)

        if weight_path is not None:
            print(f"Loading backbone weights from {weight_path}")
            self.load_weights_backbone(weight_path, self.model)

        self._init_head(in_features, num_classes)

        if pl_weight_path is not None:
            print(f"Loading full model weights from {pl_weight_path} (pt_format={pt_weights_format})")
            self.load_full_model(pl_weight_path, pt_weights_format)

        # Per-sample losses; reductions are applied manually in the steps.
        self.loss_clf = nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = nn.MSELoss(reduction="none")

        # Validation buffers: labels, probabilities, rounded predictions,
        # regression targets, regression predictions.
        self.y_val, self.p_val, self.r_val = [], [], []
        self.ty_val, self.tp_val = [], []

    def _init_head(self, in_features: int, num_classes: int):
        """Build the aggregation head for the configured variant."""
        if self.variant == "mean_out":
            # The backbone FC already produces the final outputs.
            return

        elif self.variant in ("gru_mean", "gru_last"):
            self.rnn = nn.GRU(in_features, in_features // 4, batch_first=True)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features // 4, num_classes, bias=True)

        elif self.variant in ("lstm_mean", "lstm_last"):
            # proj_size makes the LSTM emit num_classes directly (no FC).
            self.lstm = nn.LSTM(
                input_size=in_features,
                hidden_size=in_features // 4,
                proj_size=num_classes,
                batch_first=True,
            )

        elif self.variant == "mean":
            self.fc = nn.Linear(in_features, num_classes, bias=True)

        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=in_features,
                nhead=4,
                batch_first=True,
                dim_feedforward=in_features // 4,
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features, num_classes, bias=True)
            if self.variant == "bert_cls2":
                # Learned CLS token prepended to the sequence in forward().
                self.cls = nn.Parameter(torch.randn(1, 1, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, N_videos, C, T, H, W)
        → per-video embeddings (batch, N_videos, embed_dim) → head
        → (batch, num_classes)
        """
        batch_size, seq_len, *video_shape = x.shape
        # Fold the video dimension into the batch for one backbone pass.
        x = torch.flatten(x, start_dim=0, end_dim=1)
        x = self.model(x)
        x = torch.unflatten(x, 0, (batch_size, seq_len))

        if self.variant == "mean_out":
            x = torch.mean(x, dim=1)

        elif self.variant in ("gru_mean", "gru_last"):
            all_outs, h_n = self.rnn(x)
            # BUGFIX: h_n is (num_layers, batch, hidden); take the final
            # layer so "last" yields (batch, hidden) like the "mean" path
            # (previously the raw h_n with a spurious leading dim was used).
            x = torch.mean(all_outs, dim=1) if self.variant == "gru_mean" else h_n[-1]
            x = self.dropout(x)
            x = self.fc(x)

        elif self.variant in ("lstm_mean", "lstm_last"):
            all_outs, (h_n, _) = self.lstm(x)
            # BUGFIX: same (num_layers, batch, proj) → (batch, proj) fix.
            x = torch.mean(all_outs, dim=1) if self.variant == "lstm_mean" else h_n[-1]

        elif self.variant == "mean":
            x = torch.mean(x, dim=1)
            x = self.fc(x)

        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            if self.variant == "bert_cls":
                # Prepend an all-zero CLS position.
                x = F.pad(x, (0, 0, 1, 0), "constant", 0)
            elif self.variant == "bert_cls2":
                bs = x.size(0)
                x = torch.cat([self.cls.expand(bs, -1, -1), x], dim=1)
            x = self.encoder(x)
            x = torch.mean(x, dim=1) if self.variant == "bert_mean" else x[:, 0, :]
            x = self.dropout(x)
            x = self.fc(x)

        return x

    def training_step(self, batch, batch_idx):
        """Weighted BCE + heteroscedastic MSE; logs losses and batch metrics."""
        x, y, target, path = batch
        y_hat = self(x)
        # Column 0: classification logit; remaining columns: regression.
        yp_clf, yp_reg = y_hat[:, 0:1], y_hat[:, 1:]

        # Down-weight the negative class in the BCE term.
        weights_clf = torch.where(y > 0, 1.0, 0.45)
        clf_loss = (self.loss_clf(yp_clf, y) * weights_clf).mean()

        # Heteroscedastic MSE: per-sample scale sigma = a * target + b.
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()

        loss = clf_loss + 0.5 * reg_loss

        y_pred = torch.sigmoid(yp_clf)
        y_bin = torch.round(y.detach().cpu()).int()
        y_pred_bin = torch.round(y_pred.detach().cpu()).int()

        self.log("train_clf_loss", clf_loss, prog_bar=True, sync_dist=True)
        self.log("train_val_loss", reg_loss, prog_bar=True, sync_dist=True)
        self.log("train_full_loss", loss, prog_bar=True, sync_dist=True)
        self.log("train_f1", skm.f1_score(y_bin, y_pred_bin, zero_division=0),
                 prog_bar=True, sync_dist=True)
        self.log("train_acc", skm.accuracy_score(y_bin, y_pred_bin),
                 prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate per-sample predictions and return the combined loss."""
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf, yp_reg = y_hat[:, 0:1], y_hat[:, 1:]

        # NOTE(review): the int()/float() conversions below assume a
        # validation batch size of 1 — confirm against the val dataloader.
        y_pred = torch.sigmoid(yp_clf)
        self.y_val.append(int(y[..., 0].cpu()))
        self.p_val.append(float(y_pred[..., 0].cpu()))
        self.r_val.append(round(float(y_pred[..., 0].cpu())))
        self.ty_val.append(float(target[..., 0].cpu()))
        self.tp_val.append(float(yp_reg[..., 0].cpu()))

        # Same loss as training, but without the class re-weighting.
        clf_loss = self.loss_clf(yp_clf, y).mean()
        reg_loss_raw = self.loss_reg(yp_reg, target)
        sigma = self.sigma_a * target + self.sigma_b
        reg_loss = (reg_loss_raw / (sigma ** 2)).mean()
        loss = clf_loss + 0.5 * reg_loss

        return loss

    def on_validation_epoch_end(self):
        """Compute and log epoch-level validation metrics from the buffers."""
        try:
            auc = skm.roc_auc_score(self.y_val, self.p_val)
            f1 = skm.f1_score(self.y_val, self.r_val, zero_division=0)
            acc = skm.accuracy_score(self.y_val, self.r_val)
            # NOTE(review): MAE is computed on the binary labels/predictions
            # (equals 1 - accuracy), not on the regression values — confirm
            # this is intended rather than ty_val/tp_val.
            mae = skm.mean_absolute_error(self.y_val, self.r_val)
            rmse = skm.root_mean_squared_error(self.ty_val, self.tp_val)

            self.log("val_auc", auc, prog_bar=True, sync_dist=True)
            self.log("val_f1", f1, prog_bar=True, sync_dist=True)
            self.log("val_acc", acc, prog_bar=True, sync_dist=True)
            self.log("val_mae", mae, prog_bar=True, sync_dist=True)
            self.log("val_rmse", rmse, prog_bar=True, sync_dist=True)

        except ValueError as err:
            # roc_auc_score raises when only one class is present.
            print(err)
            print("Y_VAL", self.y_val[:10], "...")
            print("P_VAL", self.p_val[:10], "...")

        # Reset buffers for the next epoch (plain loop instead of a
        # side-effect list comprehension).
        for buf in (self.y_val, self.p_val, self.r_val, self.ty_val, self.tp_val):
            buf.clear()

    def on_train_epoch_end(self):
        """Log the current learning rate once per epoch."""
        lr = self.optimizers().optimizer.param_groups[0]["lr"]
        self.log("lr", lr, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        """Adam (+ optional OneCycleLR).

        When the backbone was initialized from `weight_path`, it is frozen
        and only the head parameters are optimized.
        """
        if self.weight_path:
            head_modules = self._get_trainable_head_modules()
            for param in self.parameters():
                param.requires_grad = False
            params = []
            for module in head_modules.values():
                for param in module.parameters():
                    param.requires_grad = True
                    params.append(param)
            if self.variant == "bert_cls2":
                # BUGFIX: self.cls is an nn.Parameter, not a Module — it has
                # no .parameters() and must be handled explicitly (it used to
                # be appended to the module list, raising AttributeError).
                self.cls.requires_grad = True
                params.append(self.cls)
        else:
            for param in self.parameters():
                param.requires_grad = True
            params = self.parameters()

        optimizer = optim.Adam(params, lr=self.lr, weight_decay=self.weight_decay)
        if self.max_epochs:
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=self.lr, total_steps=self.max_epochs
            )
            return [optimizer], [scheduler]
        return optimizer

    def _get_trainable_head_modules(self):
        """Return {checkpoint_prefix: module} for the trainable head modules.

        Keys match the attribute names used in a Lightning state_dict so
        load_full_model can reuse them. BUGFIX: prefixes were previously
        derived from class names ("gru", "linear", "transformerencoder"),
        which never matched the actual checkpoint keys ("rnn.", "fc.",
        "encoder.") and silently loaded nothing. The bert_cls2 CLS token is
        an nn.Parameter and is handled separately by its users.
        """
        if self.variant == "mean_out":
            # Lives inside the backbone, hence the "model." prefix.
            return {"model.fc": self.model.fc}
        elif self.variant in ("gru_mean", "gru_last"):
            return {"rnn": self.rnn, "fc": self.fc}
        elif self.variant in ("lstm_mean", "lstm_last"):
            return {"lstm": self.lstm}
        elif self.variant == "mean":
            return {"fc": self.fc}
        elif self.variant in ("bert_mean", "bert_cls", "bert_cls2"):
            return {"encoder": self.encoder, "fc": self.fc}
        return {}

    def load_weights_backbone(self, weight_path: str, model):
        """Load backbone weights from a Lightning .ckpt."""
        ckpt = torch.load(weight_path, map_location="cpu", weights_only=False)
        state_dict = ckpt["state_dict"]
        # BUGFIX: strip the "model." prefix only from keys that actually
        # start with it (str.replace could mangle keys containing "model."
        # mid-string) and drop non-backbone keys.
        new_state_dict = {
            k.removeprefix("model."): v
            for k, v in state_dict.items()
            if k.startswith("model.")
        }
        model.load_state_dict(new_state_dict, strict=False)

    def load_full_model(self, pl_weight_path: str, pt_weights_format: bool):
        """Load the full model from a .ckpt (Lightning) or a raw .pt state_dict."""
        if pt_weights_format:
            # Raw state_dict saved with torch.save.
            state_dict = torch.load(pl_weight_path, map_location="cpu", weights_only=False)
        else:
            # Lightning checkpoint: weights live under "state_dict".
            ckpt = torch.load(pl_weight_path, map_location="cpu", weights_only=False)
            state_dict = ckpt["state_dict"]

        # Backbone (covers "model.*", including the mean_out head's model.fc).
        self.load_weights(state_dict, self.model, "model")

        # Head modules, addressed by their real checkpoint prefixes.
        for prefix, module in self._get_trainable_head_modules().items():
            if prefix.startswith("model."):
                continue  # already loaded with the backbone above
            self.load_weights(state_dict, module, prefix)

        if self.variant == "bert_cls2":
            # The CLS token is a bare Parameter, copied directly.
            if "cls" in state_dict:
                self.cls.data.copy_(state_dict["cls"])

    def load_weights(self, state_dict, module, prefix: str):
        """Load into `module` all `state_dict` entries under `prefix.`."""
        dotted = f"{prefix}."
        # BUGFIX: match the dotted prefix so e.g. "fc" cannot also swallow
        # keys of a hypothetical "fc2.*".
        module_state = {
            k.removeprefix(dotted): v
            for k, v in state_dict.items()
            if k.startswith(dotted)
        }
        missing, unexpected = module.load_state_dict(module_state, strict=False)
        if missing:
            print(f"Missing keys for {prefix}: {len(missing)}")
        if unexpected:
            print(f"Unexpected keys for {prefix}: {len(unexpected)}")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Return labels, binarized/probabilistic clf predictions and regression."""
        x, y, target, path = batch
        y_hat = self(x)
        yp_clf, yp_reg = y_hat[:, 0:1], y_hat[:, 1:]
        y_prob = torch.sigmoid(yp_clf)
        return {
            "y": y,
            "y_pred": torch.round(y_prob),
            "y_prob": y_prob,
            "y_reg": yp_reg,
            "target": target,
        }
| |
|