from dataclasses import dataclass

import numpy as np
import torch

from src.data.frequency import Frequency
| |
|
| |
|
@dataclass
class BatchTimeSeriesContainer:
    """
    Container for a batch of multivariate time series data and their associated features.

    Attributes:
        history_values: Tensor of historical observations.
            Shape: [batch_size, seq_len, num_channels]
        future_values: Tensor of future observations to predict.
            Shape: [batch_size, pred_len, num_channels]
        start: Timestamp of the first history value for each series.
            Type: list[np.datetime64], length batch_size
        frequency: Frequency of each time series.
            Type: list[Frequency], length batch_size
        history_mask: Optional boolean/float tensor indicating missing entries in history_values.
            Shape: [batch_size, seq_len] or [batch_size, seq_len, num_channels]
        future_mask: Optional boolean/float tensor indicating missing entries in future_values.
            Shape: [batch_size, pred_len] or [batch_size, pred_len, num_channels]
    """

    history_values: torch.Tensor
    future_values: torch.Tensor
    start: list[np.datetime64]
    frequency: list[Frequency]

    history_mask: torch.Tensor | None = None
    future_mask: torch.Tensor | None = None

    def __post_init__(self):
        """Validate types, tensor shapes, and per-series metadata lengths.

        Raises:
            TypeError: If any attribute has the wrong type.
            ValueError: If tensor shapes or metadata lengths are inconsistent.
        """
        if not isinstance(self.history_values, torch.Tensor):
            raise TypeError("history_values must be a torch.Tensor")
        if not isinstance(self.future_values, torch.Tensor):
            raise TypeError("future_values must be a torch.Tensor")
        if not isinstance(self.start, list) or not all(isinstance(x, np.datetime64) for x in self.start):
            raise TypeError("start must be a List[np.datetime64]")
        if not isinstance(self.frequency, list) or not all(isinstance(x, Frequency) for x in self.frequency):
            raise TypeError("frequency must be a List[Frequency]")

        # Explicit rank checks give a clearer error than a failed tuple unpack.
        if self.history_values.dim() != 3:
            raise ValueError(
                f"history_values must be 3-dimensional, got shape {tuple(self.history_values.shape)}"
            )
        if self.future_values.dim() != 3:
            raise ValueError(
                f"future_values must be 3-dimensional, got shape {tuple(self.future_values.shape)}"
            )

        batch_size, seq_len, num_channels = self.history_values.shape
        pred_len = self.future_values.shape[1]

        if self.future_values.shape[0] != batch_size:
            raise ValueError("Batch size mismatch between history and future_values")
        if self.future_values.shape[2] != num_channels:
            raise ValueError("Channel size mismatch between history and future_values")

        # Per-series metadata must cover every series in the batch
        # (mirrors the length validation in TimeSeriesContainer).
        if len(self.start) != batch_size:
            raise ValueError(f"Length of start ({len(self.start)}) must match batch_size ({batch_size})")
        if len(self.frequency) != batch_size:
            raise ValueError(f"Length of frequency ({len(self.frequency)}) must match batch_size ({batch_size})")

        if self.history_mask is not None:
            if not isinstance(self.history_mask, torch.Tensor):
                raise TypeError("history_mask must be a Tensor or None")
            # Accept a per-timestep mask or a full per-channel mask,
            # consistent with the future_mask check below. (The previous
            # shape[:2] comparison accepted 3-D masks with ANY channel count.)
            if not (
                self.history_mask.shape == (batch_size, seq_len)
                or self.history_mask.shape == self.history_values.shape
            ):
                raise ValueError(
                    "Shape mismatch in history_mask: "
                    f"expected {(batch_size, seq_len)} or {tuple(self.history_values.shape)}, "
                    f"got {tuple(self.history_mask.shape)}"
                )

        if self.future_mask is not None:
            if not isinstance(self.future_mask, torch.Tensor):
                raise TypeError("future_mask must be a Tensor or None")
            if not (
                self.future_mask.shape == (batch_size, pred_len) or self.future_mask.shape == self.future_values.shape
            ):
                raise ValueError(
                    "Shape mismatch in future_mask: "
                    f"expected {(batch_size, pred_len)} or {self.future_values.shape}, got {self.future_mask.shape}"
                )

    def to_device(self, device: torch.device, attributes: list[str] | None = None) -> None:
        """
        Move specified tensors to the target device in place.

        Args:
            device: Target device (e.g., 'cpu', 'cuda').
            attributes: Optional list of attribute names to move. If None, move all tensors.

        Raises:
            ValueError: If an invalid attribute is specified.
        """
        all_tensors = {
            "history_values": self.history_values,
            "future_values": self.future_values,
            "history_mask": self.history_mask,
            "future_mask": self.future_mask,
        }

        if attributes is None:
            # Default to every tensor that is actually present (masks may be None).
            attributes = [k for k, v in all_tensors.items() if v is not None]

        for attr in attributes:
            if attr not in all_tensors:
                raise ValueError(f"Invalid attribute: {attr}")
            if all_tensors[attr] is not None:
                setattr(self, attr, all_tensors[attr].to(device))

    def to(self, device: torch.device, attributes: list[str] | None = None):
        """
        Alias for to_device method for consistency with PyTorch conventions.

        Args:
            device: Target device (e.g., 'cpu', 'cuda').
            attributes: Optional list of attribute names to move. If None, move all tensors.

        Returns:
            Self, to allow chained calls as with torch.nn.Module.to.
        """
        self.to_device(device, attributes)
        return self

    @property
    def batch_size(self) -> int:
        """Number of series in the batch."""
        return self.history_values.shape[0]

    @property
    def history_length(self) -> int:
        """Number of historical time steps per series."""
        return self.history_values.shape[1]

    @property
    def future_length(self) -> int:
        """Number of future (prediction) time steps per series."""
        return self.future_values.shape[1]

    @property
    def num_channels(self) -> int:
        """Number of channels (variates) per series."""
        return self.history_values.shape[2]
| |
|
| |
|
@dataclass
class TimeSeriesContainer:
    """
    Container for a batch of time series without an explicit history/future split.

    Holds generated synthetic series as whole sequences, typically destined
    for later processing or for splitting into history/future components.

    Attributes:
        values: Series values as an array.
            Shape: [batch_size, seq_len, num_channels] for multivariate series,
            [batch_size, seq_len] for univariate series.
        start: Start timestamp of each series; length equals batch_size.
        frequency: Sampling frequency of each series; length equals batch_size.
    """

    values: np.ndarray
    start: list[np.datetime64]
    frequency: list[Frequency]

    def __post_init__(self):
        """Validate types, dimensionality, and per-series metadata lengths."""
        if not isinstance(self.values, np.ndarray):
            raise TypeError("values must be a np.ndarray")

        starts_ok = isinstance(self.start, list) and all(isinstance(s, np.datetime64) for s in self.start)
        if not starts_ok:
            raise TypeError("start must be a List[np.datetime64]")

        freqs_ok = isinstance(self.frequency, list) and all(isinstance(f, Frequency) for f in self.frequency)
        if not freqs_ok:
            raise TypeError("frequency must be a List[Frequency]")

        # Only 2-D (univariate) and 3-D (multivariate) layouts are supported.
        if self.values.ndim not in (2, 3):
            raise ValueError(
                "values must have 2 or 3 dimensions "
                "[batch_size, seq_len] or [batch_size, seq_len, num_channels], "
                f"got shape {self.values.shape}"
            )

        n_series = self.values.shape[0]

        # Every series needs exactly one start timestamp and one frequency.
        if len(self.start) != n_series:
            raise ValueError(f"Length of start ({len(self.start)}) must match batch_size ({n_series})")
        if len(self.frequency) != n_series:
            raise ValueError(f"Length of frequency ({len(self.frequency)}) must match batch_size ({n_series})")

    @property
    def batch_size(self) -> int:
        """Number of series in the batch."""
        return self.values.shape[0]

    @property
    def seq_length(self) -> int:
        """Number of time steps per series."""
        return self.values.shape[1]

    @property
    def num_channels(self) -> int:
        """Number of channels; univariate (2-D) batches carry one implicit channel."""
        if self.values.ndim == 3:
            return self.values.shape[2]
        return 1
| |
|