| | import logging |
| | import random |
| |
|
| | import numpy as np |
| | import torch |
| | from omegaconf import DictConfig |
| | from torch.utils.data import DataLoader, Dataset |
| | from torch.utils.data.dataloader import default_collate |
| | from torch.utils.data.distributed import DistributedSampler |
| |
|
| | from .eval.audiocaps import AudioCapsData |
| | from .eval.video_dataset import MovieGen, VGGSound |
| | from .extracted_audio import ExtractedAudio |
| | from .extracted_vgg import ExtractedVGG |
| | from .mm_dataset import MultiModalDataset |
| | from ..utils.dist_utils import local_rank |
| |
|
# Module-level logger. Calling getLogger() without a name returns the root
# logger, so messages logged here follow the root logger's configuration.
log = logging.getLogger()
| |
|
| |
|
| | |
def worker_init_fn(worker_id: int):
    """Re-seed NumPy and Python's ``random`` inside a DataLoader worker.

    The seed combines the worker's torch initial seed (reduced modulo 2**31),
    the worker id, and ``local_rank * 1000``.

    Args:
        worker_id: Index of the worker, as passed in by the DataLoader.
    """
    base_seed = torch.initial_seed() % (2**31)
    worker_seed = base_seed + worker_id + local_rank * 1000
    # Both libraries keep their own global RNG state, so seed each one.
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
| |
|
| |
|
def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
    """Build an ``ExtractedVGG`` dataset from a per-dataset config entry.

    Args:
        cfg: Global config; ``cfg.data_dim`` is forwarded to the dataset.
        data_cfg: Dataset entry providing ``tsv`` and ``memmap_dir``.

    Returns:
        The constructed ``ExtractedVGG`` dataset.
    """
    return ExtractedVGG(
        tsv_path=data_cfg.tsv,
        data_dim=cfg.data_dim,
        premade_mmap_dir=data_cfg.memmap_dir,
    )
| |
|
| |
|
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
    """Build an ``ExtractedAudio`` dataset from a per-dataset config entry.

    Args:
        cfg: Global config; ``cfg.data_dim`` is forwarded to the dataset.
        data_cfg: Dataset entry providing ``tsv`` and ``memmap_dir``.

    Returns:
        The constructed ``ExtractedAudio`` dataset.
    """
    return ExtractedAudio(
        tsv_path=data_cfg.tsv,
        data_dim=cfg.data_dim,
        premade_mmap_dir=data_cfg.memmap_dir,
    )
| |
|
| |
|
def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
    """Build the training dataset together with its sampler and loader.

    Exactly one of three dataset configurations is selected:

    * ``cfg.example_train``: the example video/audio pair.
    * ``cfg.mini_train``: the VGG validation split plus AudioCaps.
    * otherwise: the full training mix (VGG repeated
      ``cfg.vgg_oversample_rate`` times, plus AudioCaps, AudioSetSL,
      BBCSound, FreeSound and Clotho).

    ``example_train`` takes priority over ``mini_train`` when both are set.

    Args:
        cfg: Global config providing the flags above, the ``data`` entries,
            and ``batch_size`` / ``num_workers`` / ``pin_memory``.

    Returns:
        ``(dataset, sampler, loader)`` for training; the loader shuffles and
        drops the last incomplete batch.
    """
    # The branches must be mutually exclusive. Two independent ``if``
    # statements would let the full-data ``else`` branch run (and overwrite
    # the mini dataset) whenever ``mini_train`` is set without
    # ``example_train``.
    if cfg.example_train:
        video = load_vgg_data(cfg, cfg.data.Example_video)
        audio = load_audio_data(cfg, cfg.data.Example_audio)
        dataset = MultiModalDataset([video], [audio])
    elif cfg.mini_train:
        vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
        audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
        dataset = MultiModalDataset([vgg], [audiocaps])
    else:
        # FreeSound is constructed before the other datasets; this
        # construction order is kept as-is.
        freesound = load_audio_data(cfg, cfg.data.FreeSound)
        vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
        audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
        audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
        bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
        clotho = load_audio_data(cfg, cfg.data.Clotho)
        dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
                                    [audiocaps, audioset_sl, bbcsound, freesound, clotho])

    batch_size = cfg.batch_size
    num_workers = cfg.num_workers
    pin_memory = cfg.pin_memory
    sampler, loader = construct_loader(dataset,
                                       batch_size,
                                       num_workers,
                                       shuffle=True,
                                       drop_last=True,
                                       pin_memory=pin_memory)

    return dataset, sampler, loader
| |
|
| |
|
def setup_test_datasets(cfg):
    """Build the VGG test dataset together with its sampler and loader.

    Args:
        cfg: Global config providing ``data.ExtractedVGG_test``,
            ``batch_size``, ``num_workers`` and ``pin_memory``.

    Returns:
        ``(dataset, sampler, loader)``; the loader neither shuffles nor drops
        the last batch.
    """
    test_set = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)

    # Read the loader settings up front, in the same order as before.
    bs, workers, pin = cfg.batch_size, cfg.num_workers, cfg.pin_memory
    test_sampler, test_loader = construct_loader(
        test_set,
        bs,
        workers,
        shuffle=False,
        drop_last=False,
        pin_memory=pin,
    )
    return test_set, test_sampler, test_loader
| |
|
| |
|
def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
    """Build the validation dataset and two loaders over it.

    The dataset is the example video set when ``cfg.example_train`` is set,
    otherwise the VGG validation split. Two loaders are built on the same
    dataset: one using ``cfg.batch_size`` and one using
    ``cfg.eval_batch_size``.

    Args:
        cfg: Global config.

    Returns:
        ``(dataset, val_loader, eval_loader)``.
    """
    data_cfg = cfg.data.Example_video if cfg.example_train else cfg.data.ExtractedVGG_val
    dataset = load_vgg_data(cfg, data_cfg)

    # Read all loader settings before building either loader.
    val_bs = cfg.batch_size
    eval_bs = cfg.eval_batch_size
    workers = cfg.num_workers
    shared = dict(shuffle=False, drop_last=False, pin_memory=cfg.pin_memory)

    # Only the loaders are returned; the samplers are discarded.
    val_loader = construct_loader(dataset, val_bs, workers, **shared)[1]
    eval_loader = construct_loader(dataset, eval_bs, workers, **shared)[1]

    return dataset, val_loader, eval_loader
| |
|
| |
|
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
    """Build an evaluation dataset, selected by name prefix, and its loader.

    Recognised prefixes are ``audiocaps_full``, ``audiocaps``, ``moviegen``
    and ``vggsound``.

    Args:
        dataset_name: Name whose prefix selects the dataset.
        cfg: Global config providing ``eval_data`` entries, ``duration_s``,
            ``batch_size``, ``num_workers`` and ``pin_memory``.

    Returns:
        ``(dataset, loader)``; the loader uses the None-filtering collate.

    Raises:
        ValueError: If ``dataset_name`` matches none of the known prefixes.
    """
    # 'audiocaps_full' must be tested before 'audiocaps', which is its prefix.
    if dataset_name.startswith('audiocaps_full'):
        src = cfg.eval_data.AudioCaps_full
        dataset = AudioCapsData(src.audio_path, src.csv_path)
    elif dataset_name.startswith('audiocaps'):
        src = cfg.eval_data.AudioCaps
        dataset = AudioCapsData(src.audio_path, src.csv_path)
    elif dataset_name.startswith('moviegen'):
        src = cfg.eval_data.MovieGen
        dataset = MovieGen(src.video_path, src.jsonl_path, duration_sec=cfg.duration_s)
    elif dataset_name.startswith('vggsound'):
        src = cfg.eval_data.VGGSound
        dataset = VGGSound(src.video_path, src.csv_path, duration_sec=cfg.duration_s)
    else:
        raise ValueError(f'Invalid dataset name: {dataset_name}')

    bs, workers, pin = cfg.batch_size, cfg.num_workers, cfg.pin_memory
    loader = construct_loader(
        dataset,
        bs,
        workers,
        shuffle=False,
        drop_last=False,
        pin_memory=pin,
        error_avoidance=True,
    )[1]
    return dataset, loader
| |
|
| |
|
def error_avoidance_collate(batch):
    """Collate a batch after discarding every ``None`` sample.

    Args:
        batch: Sequence of samples, some of which may be ``None``.

    Returns:
        The result of ``default_collate`` on the remaining samples.
    """
    kept = [sample for sample in batch if sample is not None]
    return default_collate(kept)
| |
|
| |
|
def construct_loader(dataset: Dataset,
                     batch_size: int,
                     num_workers: int,
                     *,
                     shuffle: bool = True,
                     drop_last: bool = True,
                     pin_memory: bool = False,
                     error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
    """Create a ``DistributedSampler`` and a ``DataLoader`` that uses it.

    Args:
        dataset: Dataset to load from.
        batch_size: Batch size passed to the loader.
        num_workers: Number of loader workers; workers are kept persistent
            whenever this is greater than zero.
        shuffle: Whether the sampler shuffles.
        drop_last: Whether the loader drops the last incomplete batch.
        pin_memory: Forwarded to the loader.
        error_avoidance: If true, collate with ``error_avoidance_collate``
            instead of the loader's default collate.

    Returns:
        ``(sampler, loader)``.
    """
    # collate_fn=None lets the DataLoader fall back to its default collate.
    if error_avoidance:
        collate = error_avoidance_collate
    else:
        collate = None

    sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        worker_init_fn=worker_init_fn,
        drop_last=drop_last,
        persistent_workers=num_workers > 0,
        pin_memory=pin_memory,
        collate_fn=collate,
    )
    return sampler, loader
| |
|