| import os |
| from functools import partial |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import lightning as L |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torchaudio |
| from torch.utils.data import DataLoader, Dataset |
|
|
| from src.data.audio_utils import DatasetResamplerCropper, collate_audio_batch |
|
|
|
|
class YT1BDataset(Dataset):
    """
    Dataset for YT-Temporal-1B data using Parquet metadata files.

    Each item is decoded from disk with torchaudio, mixed down to mono,
    resampled/cropped via ``DatasetResamplerCropper``, and returned as a dict
    with keys ``waveform`` (1, T), ``audio_name`` and ``index``. Items that
    fail to decode return a zero waveform plus an ``error: True`` flag so the
    collate function can filter them out.

    Args:
        parquet_path (str): Path to the parquet file containing metadata (must include 'file_path', 'video_id', 'duration_sec').
            If a 'sample_rate' column exists, it is used to avoid probing files for source sample rate.
        min_duration (Optional[float]): Minimum duration in seconds to include a file.
        max_duration (Optional[float]): Maximum duration in seconds to include a file.
        transform (Optional[callable]): Optional transform to apply to the waveform.
        max_length (Optional[int]): Maximum length of the waveform in samples (at target_sample_rate).
        target_sample_rate (int): Target sample rate for the waveform. Defaults to 16000.
        decode_window_sec (Optional[float]): Optional decode window length in seconds. If None,
            defaults to max_length / target_sample_rate (when max_length is set).

    Raises:
        FileNotFoundError: If ``parquet_path`` does not exist.
        ImportError: If no parquet engine (pyarrow) is installed.
        ValueError: If required columns are missing or the duration filters
            are negative or inconsistent.
    """

    def __init__(
        self,
        parquet_path: str,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = 30.0,
        transform: Optional[Any] = None,
        max_length: Optional[int] = None,
        target_sample_rate: int = 16000,
        decode_window_sec: Optional[float] = None,
    ):
        print(f"Loading metadata from {parquet_path}...")
        self.transform = transform
        self.max_length = max_length
        self.target_sample_rate = target_sample_rate
        self.decode_window_sec = decode_window_sec

        if not os.path.exists(parquet_path):
            raise FileNotFoundError(f"Parquet file not found at: {parquet_path}")

        try:
            df = pd.read_parquet(parquet_path)
        except ImportError as e:
            # Chain the original exception so the actual import failure
            # (which engine/module was missing) stays in the traceback.
            raise ImportError(
                "Please install pyarrow to read parquet files: `uv add pyarrow`"
            ) from e

        required_cols = {"file_path", "video_id", "duration_sec"}
        if not required_cols.issubset(df.columns):
            raise ValueError(
                f"Parquet file must contain columns: {required_cols}. Found: {df.columns.tolist()}"
            )

        # Validate the duration filters before applying them.
        if min_duration is not None and min_duration < 0:
            raise ValueError(f"min_duration must be >= 0, got {min_duration}")
        if max_duration is not None and max_duration < 0:
            raise ValueError(f"max_duration must be >= 0, got {max_duration}")
        if (
            min_duration is not None
            and max_duration is not None
            and min_duration > max_duration
        ):
            raise ValueError(
                "min_duration must be <= max_duration; "
                f"got min_duration={min_duration}, max_duration={max_duration}"
            )

        if min_duration is not None:
            df = df[df["duration_sec"] >= min_duration]
        if max_duration is not None:
            df = df[df["duration_sec"] <= max_duration]

        self.ids = df["video_id"].tolist()
        self.paths = df["file_path"].tolist()
        self.durations_sec = df["duration_sec"].tolist()
        if "sample_rate" in df.columns:
            # Coerce non-numeric entries to NaN; NaN/non-positive rates become
            # None so __getitem__ falls back to probing the file itself.
            sample_rates = pd.to_numeric(df["sample_rate"], errors="coerce").to_numpy(
                dtype=np.float64
            )
            self.source_sample_rates: Optional[List[Optional[int]]] = [
                int(sr) if np.isfinite(sr) and sr > 0 else None for sr in sample_rates
            ]
        else:
            self.source_sample_rates = None
        self.length = len(self.ids)

        # Resamples decoded audio to target_sample_rate and crops/limits it
        # to max_length samples.
        self.resampler = DatasetResamplerCropper(
            target_sr=target_sample_rate, max_length=max_length
        )

        print(f"Dataset loaded. Length: {self.length:,}")

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, str, int]]:
        """Decode one audio file and return a mono waveform dict.

        Returns:
            Dict with ``waveform`` (1, T) at target_sample_rate,
            ``audio_name`` and ``index``; on decode failure, a zero
            waveform plus ``error: True``.
        """
        audio_path = self.paths[idx]
        audio_id = self.ids[idx]

        try:
            decode_window_sec = self.decode_window_sec
            if decode_window_sec is None and self.max_length is not None:
                # Only decode as much audio as the cropper will keep.
                decode_window_sec = self.max_length / self.target_sample_rate

            if decode_window_sec is None:
                waveform, sr = torchaudio.load(audio_path)
            else:
                duration_sec = float(self.durations_sec[idx])
                if duration_sec <= 0:
                    # Untrusted metadata: fall back to decoding the whole file.
                    waveform, sr = torchaudio.load(audio_path)
                else:
                    source_sr: Optional[int]
                    if self.source_sample_rates is not None:
                        source_sr = self.source_sample_rates[idx]
                    else:
                        source_sr = None

                    if source_sr is None:
                        # Probe a single frame just to learn the sample rate.
                        _, source_sr = torchaudio.load(
                            audio_path, frame_offset=0, num_frames=1
                        )

                    total_frames = max(1, int(duration_sec * source_sr))
                    max_decode_frames = max(1, int(decode_window_sec * source_sr))
                    decode_frames = min(max_decode_frames, total_frames)

                    if total_frames > decode_frames:
                        max_start = total_frames - decode_frames
                        # Use torch's RNG for the random crop start: DataLoader
                        # seeds it per worker, whereas NumPy's global RNG state
                        # is duplicated into forked workers, which would make
                        # every worker pick identical "random" crops.
                        frame_offset = int(
                            torch.randint(0, max_start + 1, (1,)).item()
                        )
                    else:
                        frame_offset = 0

                    waveform, sr = torchaudio.load(
                        audio_path,
                        frame_offset=frame_offset,
                        num_frames=decode_frames,
                    )
        except Exception as e:
            # Best-effort loading: report the failure and return a silent
            # placeholder that collate_fn can drop via the "error" flag.
            print(f"Error loading {audio_path}: {e}")
            len_samples = (
                self.max_length if self.max_length else self.target_sample_rate
            )
            return {
                "waveform": torch.zeros(1, len_samples),
                "audio_name": audio_id,
                "index": idx,
                "error": True,
            }

        # Mix multi-channel audio down to mono before resampling.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        waveform = self.resampler(waveform, source_sr=sr)

        # The resampler may return (T,); normalize back to (1, T).
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # `is not None` rather than truthiness: transform objects may define
        # __bool__/__len__ and evaluate falsy while still being callable.
        if self.transform is not None:
            waveform = self.transform(waveform)

        return {
            "waveform": waveform,
            "audio_name": audio_id,
            "index": idx,
        }
|
|
|
|
class YT1BDataModule(L.LightningDataModule):
    """
    LightningDataModule for YT-Temporal-1B.

    Wraps three :class:`YT1BDataset` instances (train/val/test) built from
    parquet metadata files under ``data_dir``. Datasets are created lazily in
    :meth:`setup` and only if their parquet file exists.

    Args:
        data_dir (str): Root directory for data.
        train_parquet (str): Filename of training parquet file.
        val_parquet (str): Filename of validation parquet file.
        test_parquet (str): Filename of test parquet file.
        batch_size (int): Batch size for dataloaders.
        num_workers (int): Number of workers for dataloaders.
        pin_memory (bool): Whether to pin memory in dataloaders.
        max_audio_length_sec (Optional[float]): Maximum audio length in seconds.
        min_duration_sec (Optional[float]): Minimum audio duration in seconds to filter.
        max_duration_sec (Optional[float]): Maximum audio duration in seconds to filter.
        target_sample_rate (int): Target sample rate.
        collate_mode (str): 'pad' or 'truncate'.
        decode_window_sec (Optional[float]): Optional decode window length in seconds. If None,
            defaults to max_audio_length_sec.
    """

    def __init__(
        self,
        data_dir: str = "data/YT-Temporal-1B",
        train_parquet: str = "train_metadata.parquet",
        val_parquet: str = "val_metadata.parquet",
        test_parquet: str = "val_metadata.parquet",
        batch_size: int = 64,
        num_workers: int = 4,
        pin_memory: bool = True,
        max_audio_length_sec: Optional[float] = 10.0,
        min_duration_sec: Optional[float] = None,
        max_duration_sec: Optional[float] = 30.0,
        target_sample_rate: int = 16000,
        collate_mode: str = "pad",
        decode_window_sec: Optional[float] = None,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.data_dir = data_dir
        self.train_parquet_path = os.path.join(data_dir, train_parquet)
        self.val_parquet_path = os.path.join(data_dir, val_parquet)
        self.test_parquet_path = os.path.join(data_dir, test_parquet)

        # Convert the seconds-based length cap to samples at the target rate.
        if max_audio_length_sec is not None:
            self.max_audio_length = int(max_audio_length_sec * target_sample_rate)
        else:
            self.max_audio_length = None

        self.train_dataset: Optional[YT1BDataset] = None
        self.val_dataset: Optional[YT1BDataset] = None
        self.test_dataset: Optional[YT1BDataset] = None

    def _build_dataset(self, parquet_path: str) -> Optional[YT1BDataset]:
        """Build a YT1BDataset for ``parquet_path``, or None if it is missing."""
        if not os.path.exists(parquet_path):
            return None
        return YT1BDataset(
            parquet_path,
            min_duration=self.hparams["min_duration_sec"],
            max_duration=self.hparams["max_duration_sec"],
            max_length=self.max_audio_length,
            target_sample_rate=self.hparams["target_sample_rate"],
            decode_window_sec=self.hparams["decode_window_sec"],
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the datasets for the requested stage (all when stage is None)."""
        if stage == "fit" or stage is None:
            self.train_dataset = self._build_dataset(self.train_parquet_path)
            self.val_dataset = self._build_dataset(self.val_parquet_path)

        # Also handle stage=None here, consistent with the fit branch.
        if stage == "test" or stage is None:
            self.test_dataset = self._build_dataset(self.test_parquet_path)

    def train_dataloader(self) -> DataLoader:
        # `is None`, not truthiness: an empty Dataset is falsy via __len__, so
        # a fully-filtered dataset would raise a misleading "file not found".
        if self.train_dataset is None:
            raise RuntimeError(
                f"Train dataset not initialized. File not found: {self.train_parquet_path}"
            )
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams["batch_size"],
            shuffle=True,
            num_workers=self.hparams["num_workers"],
            pin_memory=self.hparams["pin_memory"],
            persistent_workers=self.hparams["num_workers"] > 0,
            collate_fn=partial(self.collate_fn, mode=self.hparams["collate_mode"]),
        )

    def val_dataloader(self) -> DataLoader:
        # See train_dataloader for why this is an identity check, not truthiness.
        if self.val_dataset is None:
            raise RuntimeError(
                f"Val dataset not initialized. File not found: {self.val_parquet_path}"
            )

        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams["batch_size"],
            shuffle=False,
            num_workers=self.hparams["num_workers"],
            pin_memory=self.hparams["pin_memory"],
            persistent_workers=self.hparams["num_workers"] > 0,
            collate_fn=partial(self.collate_fn, mode=self.hparams["collate_mode"]),
        )

    def test_dataloader(self) -> DataLoader:
        # See train_dataloader for why this is an identity check, not truthiness.
        if self.test_dataset is None:
            raise RuntimeError(
                f"Test dataset not initialized. File not found: {self.test_parquet_path}"
            )

        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams["batch_size"],
            shuffle=False,
            num_workers=self.hparams["num_workers"],
            pin_memory=self.hparams["pin_memory"],
            # Match train/val: keep workers alive across epochs when enabled.
            persistent_workers=self.hparams["num_workers"] > 0,
            collate_fn=partial(self.collate_fn, mode=self.hparams["collate_mode"]),
        )

    @staticmethod
    def collate_fn(batch: List[Dict[str, Any]], mode: str = "pad") -> Dict[str, Any]:
        """Drop failed items (error=True), then delegate to collate_audio_batch."""
        batch = [x for x in batch if not x.get("error", False)]
        if len(batch) == 0:
            raise RuntimeError("All items in batch failed to load.")

        return collate_audio_batch(
            batch=batch,
            waveform_key="waveform",
            mode=mode,
        )
|
|