| | import os |
| | import json |
| | import torch |
| | import numpy as np |
| | import click |
| | import lightning.pytorch as pl |
| | from lightning.pytorch.loggers import TensorBoardLogger |
| | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor |
| |
|
| | from pytorchvideo.transforms import Normalize, Permute, RandAugment |
| | from torch.utils.data import DataLoader |
| | from torchvision.transforms import transforms as T |
| | from torchvision.transforms._transforms_video import ToTensorVideo |
| | from torchvision.transforms import InterpolationMode |
| |
|
| | from rnn_dataset import SyntaxDataset |
| | from rnn_model import SyntaxLightningModule |
| |
|
# Trade a little float32 matmul precision for speed (enables TF32-class
# kernels on supporting GPUs — see torch docs for "medium").
torch.set_float32_matmul_precision("medium")
| |
|
| |
|
"""
Train an RNN head on top of a pretrained backbone for the SYNTAX score.

Stages:
1) pretrain — only the head is trained (backbone frozen);
2) full — fine-tuning of the whole model (backbone + head).
"""
| |
|
| |
|
def get_transforms(video_size, imagenet_mean, imagenet_std, train=True):
    """Build the video transform pipeline.

    Training adds RandAugment, a random horizontal flip and a resize with a
    randomly chosen interpolation (bilinear or bicubic); evaluation only
    resizes (bicubic) and normalizes.
    """
    normalize = Normalize(mean=imagenet_mean, std=imagenet_std)
    if not train:
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            normalize,
        ])
    # Randomly pick one of two resize interpolations per clip.
    random_resize = T.RandomChoice([
        T.Resize(size=video_size, interpolation=mode, antialias=True)
        for mode in (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
    ])
    return T.Compose([
        ToTensorVideo(),
        # Permute swaps the first two axes around RandAugment and back —
        # presumably (C, T, H, W) <-> (T, C, H, W); confirm against
        # pytorchvideo's RandAugment input convention.
        Permute(dims=[1, 0, 2, 3]),
        RandAugment(magnitude=10, num_layers=2),
        T.RandomHorizontalFlip(),
        Permute(dims=[1, 0, 2, 3]),
        random_resize,
        normalize,
    ])
| |
|
| |
|
def make_dataloader(dataset, batch_size, num_workers):
    """Wrap *dataset* in a DataLoader.

    Shuffles only when the dataset is not in inference mode; incomplete
    trailing batches are dropped and memory is pinned for faster
    host-to-GPU transfer.
    """
    shuffle = not dataset.inference
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
    )
| |
|
| |
|
def make_model(num_classes, video_shape, lr, variant, weight_decay, max_epochs,
               weight_path=None, pl_weight_path=None, pt_weights_format=False):
    """Instantiate a SyntaxLightningModule.

    NOTE(review): ``video_shape`` is accepted for interface compatibility
    but is not forwarded to the module.
    """
    module_kwargs = dict(
        num_classes=num_classes,
        lr=lr,
        variant=variant,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        weight_path=weight_path,
        pl_weight_path=pl_weight_path,
        pt_weights_format=pt_weights_format,
    )
    return SyntaxLightningModule(**module_kwargs)
| |
|
| |
|
def make_callbacks(artery: str, fold: int, phase: str):
    """Build Lightning callbacks for one training phase.

    Returns a LearningRateMonitor plus a ModelCheckpoint that tracks
    ``val_mae`` — top-1 checkpoint for the "pre" phase, top-3 for "full".

    Raises:
        ValueError: if *phase* is neither "pre" nor "full".

    NOTE(review): ``artery`` and ``fold`` are currently unused here, and the
    checkpoint filename interpolates ``val_rmse`` while ``val_mae`` is
    monitored — confirm both metrics are actually logged by the module.
    """
    # The two phases differed only in save_top_k; table instead of branches.
    top_k_by_phase = {"pre": 1, "full": 3}
    if phase not in top_k_by_phase:
        # Include the offending value (original f-string had no placeholder).
        raise ValueError(f"phase must be 'pre' or 'full', got {phase!r}")
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    checkpoint = ModelCheckpoint(
        monitor="val_mae",
        save_top_k=top_k_by_phase[phase],
        mode="min",
        filename="model-{epoch:02d}-{val_rmse:.3f}",
        save_last=True,
    )
    return [lr_monitor, checkpoint]
| |
|
| |
|
def make_trainer(max_epochs, logger_name, callbacks):
    """Create a single-GPU Lightning Trainer that logs to TensorBoard.

    Logs land under ``rnn_logs/<logger_name>``; training runs in bf16
    mixed precision.
    """
    tb_logger = TensorBoardLogger(save_dir="rnn_logs", name=logger_name)
    trainer_kwargs = dict(
        max_epochs=max_epochs,
        accelerator="gpu",
        devices=1,
        # NOTE(review): DDP strategy with devices=1 — presumably kept so
        # multi-GPU runs only need `devices` changed; confirm.
        strategy="ddp_find_unused_parameters_true",
        precision="bf16-mixed",
        callbacks=callbacks,
        log_every_n_steps=10,
        logger=tb_logger,
    )
    return pl.Trainer(**trainer_kwargs)
| |
|
| |
|
@click.command()
@click.option(
    "-r", "--dataset-root", type=click.Path(exists=True), required=True,
    help="Корень датасета (где лежат folds/*.json и DICOM).",
)
@click.option("--fold", type=int, default=0, help="Номер фолда (0-4).")
@click.option("-a", "--artery", type=str, default="right", help="'left' или 'right'.")
@click.option("--variant", type=str, default="lstm_mean", help="Тип head (lstm_mean и др.).")
@click.option("-nc", "--num-classes", type=int, default=2)
@click.option("-b", "--batch-size", type=int, default=8)
@click.option("-f", "--frames-per-clip", type=int, default=32)
@click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(256, 256))
@click.option("--max-epochs", type=int, default=10)
@click.option("--num-workers", type=int, default=8)
@click.option("--fast-dev-run", is_flag=True)
@click.option("--seed", type=int, default=42)
@click.option("--backbone-ckpt", type=str, help="Путь к backbone-чекпоинту для pretrain.")
def main(
    dataset_root, fold, artery, variant, num_classes, batch_size, frames_per_clip,
    video_size, max_epochs, num_workers, fast_dev_run, seed, backbone_ckpt,
):
    """Two-stage training of the RNN head for the SYNTAX score.

    Stage 1 ("pre"): trains at doubled batch size with backbone weights
    loaded from --backbone-ckpt. Stage 2 ("full"): fine-tunes at a lower
    learning rate, warm-started from the best stage-1 checkpoint.
    """
    pl.seed_everything(seed)
    artery = artery.lower()
    # NOTE(review): computed but never used below — dead variable?
    artery_bin = {"left": 0, "right": 1}[artery]

    # NOTE(review): --fast-dev-run is parsed but never passed to the Trainer.
    print(f"Training {variant} head for {artery} artery, fold {fold}")

    # Standard ImageNet channel statistics for normalization.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Fold metadata is read from the local "rnn_folds" dir, not dataset_root.
    train_meta = os.path.join("rnn_folds", f"step2_rnn_fold{fold:02d}_train.json")
    val_meta = os.path.join("rnn_folds", f"step2_rnn_fold{fold:02d}_eval.json")

    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery=artery,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )
    val_set = SyntaxDataset(
        root=dataset_root,
        meta=val_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery=artery,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )

    # Stage 1 uses a doubled batch size; validation runs one clip at a time.
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers)
    val_loader = make_dataloader(val_set, 1, num_workers)

    # Probe one batch for the per-clip tensor shape.
    # NOTE(review): video_shape is passed to make_model but not forwarded to
    # the module — this loads a full batch for an unused value; confirm.
    x, *_ = next(iter(train_loader_pre))
    video_shape = x.shape[1:]

    # Phase 1: pretrain on top of the backbone checkpoint.
    callbacks_pre = make_callbacks(artery, fold, "pre")
    model_pre = make_model(
        num_classes, video_shape, lr=1e-4, variant=variant,
        weight_decay=0.01, max_epochs=max_epochs, weight_path=backbone_ckpt,
    )
    trainer_pre = make_trainer(max_epochs, f"{artery}_{variant}_pre_fold{fold:02d}", callbacks_pre)
    trainer_pre.fit(model_pre, train_loader_pre, val_loader)

    # Phase 2: fine-tune the full model at a lower LR, warm-started from the
    # best (lowest val_mae) checkpoint of phase 1.
    callbacks_full = make_callbacks(artery, fold, "full")
    model_full = make_model(
        num_classes, video_shape, lr=2e-5, variant=variant,
        weight_decay=0.01, max_epochs=max_epochs,
        pl_weight_path=trainer_pre.checkpoint_callback.best_model_path,
    )
    trainer_full = make_trainer(max_epochs, f"{artery}_{variant}_full_fold{fold:02d}", callbacks_full)
    trainer_full.fit(model_full, train_loader_post, val_loader)
| |
|
| |
|
if __name__ == "__main__":
    # Click supplies the CLI arguments.
    main()
| |
|