import json
import logging
import random
from typing import Any

import numpy as np
import pandas as pd
import torch

from src.data.augmentations import (
    NanAugmenter,
)
from src.data.constants import DEFAULT_NAN_STATS_PATH, LENGTH_CHOICES, LENGTH_WEIGHTS
from src.data.containers import BatchTimeSeriesContainer
from src.data.datasets import CyclicalBatchDataset
from src.data.frequency import Frequency
from src.data.scalers import MeanScaler, MedianScaler, MinMaxScaler, RobustScaler
from src.data.utils import sample_future_length

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class BatchComposer:
    """
    Composes batches from saved generator data according to specified proportions.
    Manages multiple CyclicalBatchDataset instances and creates uniform or mixed batches.
    """

    def __init__(
        self,
        base_data_dir: str,
        generator_proportions: dict[str, float] | None = None,
        mixed_batches: bool = True,
        device: torch.device | None = None,
        augmentations: dict[str, bool] | None = None,
        augmentation_probabilities: dict[str, float] | None = None,
        nan_stats_path: str | None = None,
        nan_patterns_path: str | None = None,
        global_seed: int = 42,
        chosen_scaler_name: str | None = None,
        rank: int = 0,
        world_size: int = 1,
    ):
        """
        Initialize the BatchComposer.

        Args:
            base_data_dir: Base directory containing generator subdirectories
            generator_proportions: Dict mapping generator names to proportions
            mixed_batches: If True, create mixed batches; if False, uniform batches
            device: Device to load tensors to
            augmentations: Dict mapping augmentation names to booleans
            augmentation_probabilities: Dict mapping augmentation names to probabilities
            nan_stats_path: Optional path to NaN statistics JSON (defaults to DEFAULT_NAN_STATS_PATH)
            nan_patterns_path: Optional path to NaN patterns file
            global_seed: Global random seed
            chosen_scaler_name: Name of the scaler used in training
            rank: Rank of the current process for distributed data loading
            world_size: Total number of processes for distributed data loading
        """
        self.base_data_dir = base_data_dir
        self.mixed_batches = mixed_batches
        self.device = device
        self.global_seed = global_seed
        self.nan_stats_path = nan_stats_path
        self.nan_patterns_path = nan_patterns_path
        self.rank = rank
        self.world_size = world_size
        self.augmentation_probabilities = augmentation_probabilities or {
            "noise_augmentation": 0.3,
            "scaler_augmentation": 0.5,
        }
        self.chosen_scaler_name = chosen_scaler_name.lower() if chosen_scaler_name is not None else None

        # Seed every RNG source for reproducibility.
        self.rng = np.random.default_rng(global_seed)
        random.seed(global_seed)
        torch.manual_seed(global_seed)

        self._setup_augmentations(augmentations)
        self._setup_proportions(generator_proportions)
        self.datasets = self._initialize_datasets()

        logger.info(
            f"Initialized BatchComposer with {len(self.datasets)} generators, "
            f"mixed_batches={mixed_batches}, proportions={self.generator_proportions}, "
            f"augmentations={self.augmentations}, "
            f"augmentation_probabilities={self.augmentation_probabilities}"
        )
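
    # Construction sketch for the initializer above. The path and toggles are
    # illustrative placeholders, not defaults of this module:
    #
    #   composer = BatchComposer(
    #       base_data_dir="data/generators",
    #       mixed_batches=True,
    #       augmentations={"nan_augmentation": True, "scaler_augmentation": True},
    #       chosen_scaler_name="custom_robust",
    #   )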

    def _setup_augmentations(self, augmentations: dict[str, bool] | None):
        """Set up only the augmentations that should remain online (NaN)."""
        default_augmentations = {
            "nan_augmentation": False,
            "scaler_augmentation": False,
            "length_shortening": False,
        }
        self.augmentations = augmentations or default_augmentations

        self.nan_augmenter = None
        if self.augmentations.get("nan_augmentation", False):
            stats_path_to_use = self.nan_stats_path or DEFAULT_NAN_STATS_PATH
            with open(stats_path_to_use) as f:
                stats = json.load(f)
            self.nan_augmenter = NanAugmenter(
                p_series_has_nan=stats["p_series_has_nan"],
                nan_ratio_distribution=stats["nan_ratio_distribution"],
                nan_length_distribution=stats["nan_length_distribution"],
                nan_patterns_path=self.nan_patterns_path,
            )

    def _should_apply_scaler_augmentation(self) -> bool:
        """
        Decide whether to apply scaler augmentation for a single series based on
        the boolean toggle and probability from the configuration.
        """
        if not self.augmentations.get("scaler_augmentation", False):
            return False
        probability = float(self.augmentation_probabilities.get("scaler_augmentation", 0.0))
        probability = max(0.0, min(1.0, probability))  # clamp to [0, 1]
        return bool(self.rng.random() < probability)

    def _choose_random_scaler(self) -> object | None:
        """
        Choose a random scaler for augmentation, explicitly avoiding the one that
        is already selected in the training configuration (if any).

        Returns an instance of the selected scaler, or None when no valid option exists.
        """
        chosen: str | None = None
        if self.chosen_scaler_name is not None:
            chosen = self.chosen_scaler_name.strip().lower()

        candidates = ["custom_robust", "minmax", "median", "mean"]

        # Exclude the scaler already used in training.
        if chosen in candidates:
            candidates = [c for c in candidates if c != chosen]
        if not candidates:
            return None

        pick = str(self.rng.choice(candidates))
        if pick == "custom_robust":
            return RobustScaler()
        if pick == "minmax":
            return MinMaxScaler()
        if pick == "median":
            return MedianScaler()
        if pick == "mean":
            return MeanScaler()
        return None
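
    # Example of the exclusion above: with chosen_scaler_name="minmax", the
    # draw is restricted to ["custom_robust", "median", "mean"], so the
    # augmentation never re-applies the scaler already used in training.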

    def _setup_proportions(self, generator_proportions: dict[str, float] | None):
        """Set up default or custom generator proportions."""
        default_proportions = {
            "forecast_pfn": 1.0,
            "gp": 1.0,
            "kernel": 1.0,
            "sinewave": 1.0,
            "sawtooth": 1.0,
            "step": 0.1,
            "anomaly": 1.0,
            "spike": 2.0,
            "cauker_univariate": 2.0,
            "cauker_multivariate": 0.00,
            "lmc": 0.00,
            "ou_process": 1.0,
            "audio_financial_volatility": 0.1,
            "audio_multi_scale_fractal": 0.1,
            "audio_network_topology": 0.5,
            "audio_stochastic_rhythm": 1.0,
            "augmented_per_sample_2048": 3.0,
            "augmented_temp_batch_2048": 3.0,
        }
        self.generator_proportions = generator_proportions or default_proportions

        # Normalize the proportions so they sum to 1.
        total = sum(self.generator_proportions.values())
        if total <= 0:
            raise ValueError("Total generator proportions must be positive")
        self.generator_proportions = {k: v / total for k, v in self.generator_proportions.items()}
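
    # Normalization example for the method above (illustrative weights):
    # {"gp": 1.0, "spike": 2.0, "step": 1.0} has total 4.0 and normalizes to
    # {"gp": 0.25, "spike": 0.5, "step": 0.25}.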

    def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]:
        """Initialize a CyclicalBatchDataset for each generator with proportion > 0."""
        datasets = {}

        for generator_name, proportion in self.generator_proportions.items():
            if proportion <= 0:
                logger.info(f"Skipping {generator_name} (proportion = {proportion})")
                continue

            batches_dir = f"{self.base_data_dir}/{generator_name}"

            try:
                dataset = CyclicalBatchDataset(
                    batches_dir=batches_dir,
                    generator_type=generator_name,
                    device=None,
                    prefetch_next=True,
                    prefetch_threshold=32,
                    rank=self.rank,
                    world_size=self.world_size,
                )
                datasets[generator_name] = dataset
                logger.info(f"Loaded dataset for {generator_name} (proportion = {proportion})")

            except Exception as e:
                logger.warning(f"Failed to load dataset for {generator_name}: {e}")
                continue

        if not datasets:
            raise ValueError(f"No valid datasets found in {self.base_data_dir} or all generators have proportion <= 0")

        return datasets
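
    # On-disk layout assumed by the batches_dir construction above (the
    # generator names are examples from the default proportions):
    #
    #   <base_data_dir>/
    #       forecast_pfn/   <- batch files consumed by CyclicalBatchDataset
    #       gp/
    #       spike/
    #       ...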

    def _convert_sample_to_tensors(
        self, sample: dict, future_length: int | None = None
    ) -> tuple[torch.Tensor, np.datetime64, Frequency]:
        """
        Convert a sample dict to tensors and metadata.

        Args:
            sample: Sample dict from CyclicalBatchDataset
            future_length: Accepted for interface compatibility but currently unused
                here; the history/future split happens in the batch-creation methods

        Returns:
            Tuple of (values, start, frequency)
        """
        num_channels = sample.get("num_channels", 1)
        values_data = sample["values"]
        generator_type = sample.get("generator_type", "unknown")

        if num_channels == 1:
            if isinstance(values_data[0], list):
                # New univariate format: the series is nested one level deep.
                values = torch.tensor(values_data[0], dtype=torch.float32)
                logger.debug(f"{generator_type}: Using new univariate format, shape: {values.shape}")
            else:
                # Legacy flat format: add batch and channel dimensions.
                values = torch.tensor(values_data, dtype=torch.float32)
                values = values.unsqueeze(0).unsqueeze(-1)
        else:
            # Multivariate format: one list of values per channel.
            channel_tensors = []
            for channel_values in values_data:
                channel_tensor = torch.tensor(channel_values, dtype=torch.float32)
                channel_tensors.append(channel_tensor)

            # Stack channels along the last dimension, then add a batch dimension.
            values = torch.stack(channel_tensors, dim=-1).unsqueeze(0)
            logger.debug(f"{generator_type}: Using multivariate format, {num_channels} channels, shape: {values.shape}")

        # Parse the frequency string, falling back to a manual alias table.
        freq_str = sample["frequency"]
        try:
            frequency = Frequency(freq_str)
        except ValueError:
            freq_mapping = {
                "h": Frequency.H,
                "D": Frequency.D,
                "W": Frequency.W,
                "M": Frequency.M,
                "Q": Frequency.Q,
                "A": Frequency.A,
                "Y": Frequency.A,
                "1min": Frequency.T1,
                "5min": Frequency.T5,
                "10min": Frequency.T10,
                "15min": Frequency.T15,
                "30min": Frequency.T30,
                "s": Frequency.S,
            }
            frequency = freq_mapping.get(freq_str, Frequency.H)

        # Normalize the start timestamp to np.datetime64.
        if isinstance(sample["start"], pd.Timestamp):
            start = sample["start"].to_numpy()
        else:
            start = np.datetime64(sample["start"])

        return values, start, frequency
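
    # Shape notes for the conversion above, in (batch, time, channels) layout:
    #   - legacy univariate list of length T            -> values of shape (1, T, 1)
    #   - multivariate sample, C channels of length T   -> values of shape (1, T, C)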

    def _effective_proportions_for_length(self, total_length_for_batch: int) -> dict[str, float]:
        """
        Build a simple, length-aware proportion map for the current batch.

        Rules:
        - For generators named 'augmented{L}', keep only the one matching the
          chosen length L; zero out the others.
        - Keep non-augmented generators as-is.
        - Drop generators that are unavailable (not loaded) or zero-weight.
        - If nothing remains, fall back to 'augmented{L}' if available, else any dataset.
        - Normalize the final map to sum to 1.
        """

        def augmented_length_from_name(name: str) -> int | None:
            if not name.startswith("augmented"):
                return None
            suffix = name[len("augmented") :]
            if not suffix:
                return None
            try:
                return int(suffix)
            except ValueError:
                return None

        # Zero out augmented generators whose length does not match.
        adjusted: dict[str, float] = {}
        for name, proportion in self.generator_proportions.items():
            aug_len = augmented_length_from_name(name)
            if aug_len is None:
                adjusted[name] = proportion
            else:
                adjusted[name] = proportion if aug_len == total_length_for_batch else 0.0

        # Keep only loaded, positively weighted generators.
        adjusted = {name: p for name, p in adjusted.items() if name in self.datasets and p > 0.0}

        # Fall back if the filter removed everything.
        if not adjusted:
            preferred = f"augmented{total_length_for_batch}"
            if preferred in self.datasets:
                adjusted = {preferred: 1.0}
            elif self.datasets:
                first_key = next(iter(self.datasets.keys()))
                adjusted = {first_key: 1.0}
            else:
                raise ValueError("No datasets available to create batch")

        # Normalize to a probability map.
        total = sum(adjusted.values())
        return {name: p / total for name, p in adjusted.items()}
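
    # Worked example for the method above (illustrative names and weights,
    # assuming all three datasets loaded): with proportions
    # {"gp": 0.5, "augmented512": 0.25, "augmented2048": 0.25} and a batch
    # length of 2048, "augmented512" is zeroed out, leaving
    # {"gp": 0.5, "augmented2048": 0.25}, which renormalizes to
    # {"gp": 2/3, "augmented2048": 1/3}.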

    def _compute_sample_counts_for_batch(self, proportions: dict[str, float], batch_size: int) -> dict[str, int]:
        """
        Convert a proportion map into integer sample counts that sum to batch_size.

        Strategy: allocate floor(batch_size * p) to each generator in order, and let the
        last generator absorb any remainder to ensure the total matches exactly.
        """
        counts: dict[str, int] = {}
        remaining = batch_size
        names = list(proportions.keys())
        values = list(proportions.values())
        for index, (name, p) in enumerate(zip(names, values, strict=True)):
            if index == len(names) - 1:
                counts[name] = remaining
            else:
                n = int(batch_size * p)
                counts[name] = n
                remaining -= n
        return counts
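
    # Allocation example for the method above: batch_size=8 with proportions
    # {"a": 0.5, "b": 0.25, "c": 0.25} gives floor(4)=4 for "a" and floor(2)=2
    # for "b"; "c", being last, absorbs the remaining 2, so the counts sum to 8.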

    def _calculate_generator_samples(self, batch_size: int) -> dict[str, int]:
        """
        Calculate the number of samples each generator should contribute.

        Args:
            batch_size: Total batch size

        Returns:
            Dict mapping generator names to sample counts
        """
        generator_samples = {}
        remaining_samples = batch_size

        # Only generators with a loaded dataset can contribute; the last of
        # them absorbs the rounding remainder so the counts sum to batch_size
        # even when some configured generators failed to load.
        available = [
            (generator, proportion)
            for generator, proportion in self.generator_proportions.items()
            if generator in self.datasets
        ]

        for i, (generator, proportion) in enumerate(available):
            if i == len(available) - 1:
                samples = remaining_samples
            else:
                samples = int(batch_size * proportion)
                remaining_samples -= samples
            generator_samples[generator] = samples

        return generator_samples

    def create_batch(
        self,
        batch_size: int = 128,
        seed: int | None = None,
        future_length: int | None = None,
    ) -> tuple[BatchTimeSeriesContainer, str]:
        """
        Create a batch of the specified size.

        Args:
            batch_size: Size of the batch to create
            seed: Random seed for this batch
            future_length: Fixed future length to use. If None, samples from the gift_eval range

        Returns:
            Tuple of (batch_container, generator_info)
        """
        if seed is not None:
            batch_rng = np.random.default_rng(seed)
            random.seed(seed)
        else:
            batch_rng = self.rng

        if self.mixed_batches:
            # Note: mixed batches draw from the composer-level RNG (self.rng),
            # so a per-call seed reseeds Python's `random` module but does not
            # affect the numpy draws inside _create_mixed_batch.
            return self._create_mixed_batch(batch_size, future_length)
        else:
            return self._create_uniform_batch(batch_size, batch_rng, future_length)
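
    # Usage sketch for the method above (illustrative values):
    #
    #   batch, source = composer.create_batch(batch_size=64, seed=7)
    #   # source is "MixedBatch" when mixed_batches=True, otherwise the name
    #   # of the single generator that produced the whole batch.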

    def _create_mixed_batch(
        self, batch_size: int, future_length: int | None = None
    ) -> tuple[BatchTimeSeriesContainer, str]:
        """Create a mixed batch with samples from multiple generators, rejecting NaNs."""

        # Choose the total series length for this batch: sampled from the
        # configured length weights, or fixed to the maximum supported length.
        if self.augmentations.get("length_shortening", False):
            lengths = list(LENGTH_WEIGHTS.keys())
            probs = list(LENGTH_WEIGHTS.values())
            total_length_for_batch = int(self.rng.choice(lengths, p=probs))
        else:
            total_length_for_batch = int(max(LENGTH_CHOICES))

        if future_length is None:
            prediction_length = int(sample_future_length(range="gift_eval", total_length=total_length_for_batch))
        else:
            prediction_length = future_length

        history_length = total_length_for_batch - prediction_length

        # Work out how many samples each generator contributes at this length.
        effective_props = self._effective_proportions_for_length(total_length_for_batch)
        generator_samples = self._compute_sample_counts_for_batch(effective_props, batch_size)

        all_values = []
        all_starts = []
        all_frequencies = []
        actual_proportions = {}

        # Collect clean samples per generator via rejection sampling.
        for generator_name, num_samples in generator_samples.items():
            if num_samples == 0 or generator_name not in self.datasets:
                continue

            dataset = self.datasets[generator_name]

            generator_values = []
            generator_starts = []
            generator_frequencies = []

            # Over-fetch and retry until enough NaN-free samples are collected.
            max_attempts = 50
            attempts = 0
            while len(generator_values) < num_samples and attempts < max_attempts:
                attempts += 1
                need = num_samples - len(generator_values)
                fetch_n = max(need * 2, 8)
                samples = dataset.get_samples(fetch_n)

                for sample in samples:
                    if len(generator_values) >= num_samples:
                        break

                    values, sample_start, sample_freq = self._convert_sample_to_tensors(sample, future_length)

                    # Reject samples that contain NaNs.
                    if torch.isnan(values).any():
                        continue

                    # Shorten over-long series by cutting a random window or subsampling.
                    if total_length_for_batch < values.shape[1]:
                        strategy = self.rng.choice(["cut", "subsample"])
                        if strategy == "cut":
                            max_start_idx = values.shape[1] - total_length_for_batch
                            start_idx = int(self.rng.integers(0, max_start_idx + 1))
                            values = values[:, start_idx : start_idx + total_length_for_batch, :]
                        else:
                            indices = np.linspace(
                                0,
                                values.shape[1] - 1,
                                total_length_for_batch,
                                dtype=int,
                            )
                            values = values[:, indices, :]

                    # Optionally re-scale the series with a randomly chosen scaler.
                    if self._should_apply_scaler_augmentation():
                        scaler = self._choose_random_scaler()
                        if scaler is not None:
                            values = scaler.scale(values, scaler.compute_statistics(values))

                    generator_values.append(values)
                    generator_starts.append(sample_start)
                    generator_frequencies.append(sample_freq)

            if len(generator_values) < num_samples:
                logger.warning(
                    f"Generator {generator_name}: collected {len(generator_values)}/"
                    f"{num_samples} after {attempts} attempts"
                )

            if generator_values:
                all_values.extend(generator_values)
                all_starts.extend(generator_starts)
                all_frequencies.extend(generator_frequencies)
                actual_proportions[generator_name] = len(generator_values)

        if not all_values:
            raise RuntimeError("No valid samples could be collected from any generator.")

        combined_values = torch.cat(all_values, dim=0)
        # Split every series into history and future windows.
        combined_history = combined_values[:, :history_length, :]
        combined_future = combined_values[:, history_length : history_length + prediction_length, :]

        # Inject NaNs into the history only, never into the forecast targets.
        if self.nan_augmenter is not None:
            combined_history = self.nan_augmenter.transform(combined_history)

        container = BatchTimeSeriesContainer(
            history_values=combined_history,
            future_values=combined_future,
            start=all_starts,
            frequency=all_frequencies,
        )

        return container, "MixedBatch"
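
    # Worked example of the split above (illustrative lengths): with
    # total_length_for_batch=2048 and prediction_length=128, history_length is
    # 1920, so combined_history has shape (batch, 1920, channels) and
    # combined_future has shape (batch, 128, channels).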

    def _create_uniform_batch(
        self,
        batch_size: int,
        batch_rng: np.random.Generator,
        future_length: int | None = None,
    ) -> tuple[BatchTimeSeriesContainer, str]:
        """Create a uniform batch with samples from a single generator."""

        # Select a generator, renormalizing over the datasets that actually
        # loaded (the stored proportions may not sum to 1 if some failed).
        generators = list(self.datasets.keys())
        proportions = np.asarray([self.generator_proportions[gen] for gen in generators], dtype=float)
        proportions = proportions / proportions.sum()
        selected_generator = str(batch_rng.choice(generators, p=proportions))

        # Sample a future length if none was given.
        if future_length is None:
            future_length = sample_future_length(range="gift_eval")

        # Fetch samples from the selected dataset.
        dataset = self.datasets[selected_generator]
        samples = dataset.get_samples(batch_size)

        all_history_values = []
        all_future_values = []
        all_starts = []
        all_frequencies = []

        for sample in samples:
            values, sample_start, sample_freq = self._convert_sample_to_tensors(sample, future_length)

            total_length = values.shape[1]
            history_length = max(1, total_length - future_length)

            # Optionally re-scale the series with a randomly chosen scaler.
            if self._should_apply_scaler_augmentation():
                scaler = self._choose_random_scaler()
                if scaler is not None:
                    values = scaler.scale(values, scaler.compute_statistics(values))

            # Split into history and future windows.
            hist_vals = values[:, :history_length, :]
            fut_vals = values[:, history_length : history_length + future_length, :]

            all_history_values.append(hist_vals)
            all_future_values.append(fut_vals)
            all_starts.append(sample_start)
            all_frequencies.append(sample_freq)

        combined_history = torch.cat(all_history_values, dim=0)
        combined_future = torch.cat(all_future_values, dim=0)

        container = BatchTimeSeriesContainer(
            history_values=combined_history,
            future_values=combined_future,
            start=all_starts,
            frequency=all_frequencies,
        )

        return container, selected_generator

    def get_dataset_info(self) -> dict[str, dict]:
        """Get information about all datasets."""
        info = {}
        for name, dataset in self.datasets.items():
            info[name] = dataset.get_info()
        return info

    def get_generator_info(self) -> dict[str, Any]:
        """Get information about the composer configuration."""
        return {
            "mixed_batches": self.mixed_batches,
            "generator_proportions": self.generator_proportions,
            "active_generators": list(self.datasets.keys()),
            "total_generators": len(self.datasets),
            "augmentations": self.augmentations,
            "augmentation_probabilities": self.augmentation_probabilities,
            "nan_augmenter_enabled": self.nan_augmenter is not None,
        }


class ComposedDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset wrapper around BatchComposer for training pipeline integration.
    """

    def __init__(
        self,
        batch_composer: BatchComposer,
        num_batches_per_epoch: int = 100,
        batch_size: int = 128,
    ):
        """
        Initialize the dataset.

        Args:
            batch_composer: The BatchComposer instance
            num_batches_per_epoch: Number of batches to generate per epoch
            batch_size: Size of each batch
        """
        self.batch_composer = batch_composer
        self.num_batches_per_epoch = num_batches_per_epoch
        self.batch_size = batch_size

    def __len__(self) -> int:
        return self.num_batches_per_epoch

    def __getitem__(self, idx: int) -> BatchTimeSeriesContainer:
        """
        Get a batch by index.

        Args:
            idx: Batch index (combined with the global seed for reproducibility)

        Returns:
            BatchTimeSeriesContainer
        """
        # Derive a deterministic per-batch seed from the global seed and index.
        batch, _ = self.batch_composer.create_batch(
            batch_size=self.batch_size, seed=self.batch_composer.global_seed + idx
        )
        return batch
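

if __name__ == "__main__":
    # Minimal smoke-test sketch. "data/generators" is a placeholder path and
    # assumes the per-generator directory layout described in
    # _initialize_datasets; nothing here is required by the training pipeline.
    composer = BatchComposer(base_data_dir="data/generators", mixed_batches=True)
    dataset = ComposedDataset(composer, num_batches_per_epoch=2, batch_size=16)
    # batch_size=None disables collation: each dataset item is already a batch.
    loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=0)
    for i, batch in enumerate(loader):
        logger.info(
            f"batch {i}: history {tuple(batch.history_values.shape)}, "
            f"future {tuple(batch.future_values.shape)}"
        )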