English
VideoCLIP-XL / data_dataloaders.py
zarus03's picture
Upload folder using huggingface_hub
a3c8a6a verified
import torch
from torch.utils.data import DataLoader
from dataloaders.dataloader_msvd_retrieval import MSVD_DataLoader
def dataloader_msvd_train(args, tokenizer):
    """Create the MSVD training ``DataLoader`` backed by a distributed sampler.

    Args:
        args: Configuration object. The attributes read here are
            ``data_path``, ``features_path``, ``train_csv``, ``max_words``,
            ``feature_framerate``, ``max_frames``, ``train_frame_order``,
            ``slice_framepos``, ``batch_size``, ``n_gpu`` and
            ``num_thread_reader``.
        tokenizer: Passed through unchanged to ``MSVD_DataLoader``.

    Returns:
        A 3-tuple ``(dataloader, number_of_dataset_items, sampler)``. The
        sampler is returned alongside the loader so the caller can keep a
        handle on it.
    """
    dataset_kwargs = dict(
        subset="train",
        data_path=args.data_path,
        features_path=args.features_path,
        csv_path=args.train_csv,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.train_frame_order,
        slice_framepos=args.slice_framepos,
    )
    train_set = MSVD_DataLoader(**dataset_kwargs)

    # No explicit num_replicas/rank is given, so the sampler takes them from
    # the current torch.distributed state.
    sampler = torch.utils.data.distributed.DistributedSampler(train_set)

    # args.batch_size is split evenly (integer division) across args.n_gpu.
    per_gpu_batch = args.batch_size // args.n_gpu

    # The sampler is always constructed above, so this is always False here;
    # ordering is left entirely to the sampler.
    shuffle_flag = sampler is None

    loader = DataLoader(
        train_set,
        batch_size=per_gpu_batch,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        shuffle=shuffle_flag,
        sampler=sampler,
        drop_last=True,
    )
    return loader, len(train_set), sampler
def dataloader_msvd_test(args, tokenizer, subset="test"):
    """Create a sequential (unshuffled) MSVD ``DataLoader`` for evaluation.

    Args:
        args: Configuration object. The attributes read here are
            ``data_path``, ``features_path``, ``val_csv``, ``max_words``,
            ``feature_framerate``, ``max_frames``, ``eval_frame_order``,
            ``slice_framepos``, ``batch_size_val`` and ``num_thread_reader``.
        tokenizer: Passed through unchanged to ``MSVD_DataLoader``.
        subset: Split name forwarded to ``MSVD_DataLoader``; defaults to
            ``"test"``. Note that ``args.val_csv`` is used as the CSV path
            regardless of this value.

    Returns:
        A 2-tuple ``(dataloader, number_of_dataset_items)``.
    """
    dataset_kwargs = dict(
        subset=subset,
        data_path=args.data_path,
        features_path=args.features_path,
        csv_path=args.val_csv,
        max_words=args.max_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        max_frames=args.max_frames,
        frame_order=args.eval_frame_order,
        slice_framepos=args.slice_framepos,
    )
    eval_set = MSVD_DataLoader(**dataset_kwargs)

    # Keep every sample and preserve dataset order: no shuffling, no dropped
    # final partial batch.
    eval_loader = DataLoader(
        eval_set,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        shuffle=False,
        drop_last=False,
    )
    return eval_loader, len(eval_set)
# Registry of loader factories: dataset name -> split name -> factory function.
# The "train" factory returns (loader, length, sampler); the "val" and "test"
# entries share one factory that returns (loader, length).
DATALOADER_DICT = {
    "msvd": {
        "train": dataloader_msvd_train,
        "val": dataloader_msvd_test,
        "test": dataloader_msvd_test,
    },
}