| | import logging |
| | from typing import Any |
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import scipy.fft as fft |
| | import torch |
| | from gluonts.time_feature import time_features_from_frequency_str |
| | from gluonts.time_feature._base import ( |
| | day_of_month, |
| | day_of_month_index, |
| | day_of_week, |
| | day_of_week_index, |
| | day_of_year, |
| | hour_of_day, |
| | hour_of_day_index, |
| | minute_of_hour, |
| | minute_of_hour_index, |
| | month_of_year, |
| | month_of_year_index, |
| | second_of_minute, |
| | second_of_minute_index, |
| | week_of_year, |
| | week_of_year_index, |
| | ) |
| | from gluonts.time_feature.holiday import ( |
| | BLACK_FRIDAY, |
| | CHRISTMAS_DAY, |
| | CHRISTMAS_EVE, |
| | CYBER_MONDAY, |
| | EASTER_MONDAY, |
| | EASTER_SUNDAY, |
| | GOOD_FRIDAY, |
| | INDEPENDENCE_DAY, |
| | LABOR_DAY, |
| | MEMORIAL_DAY, |
| | NEW_YEARS_DAY, |
| | NEW_YEARS_EVE, |
| | THANKSGIVING, |
| | SpecialDateFeatureSet, |
| | exponential_kernel, |
| | squared_exponential_kernel, |
| | ) |
| | from gluonts.time_feature.seasonality import get_seasonality |
| | from scipy.signal import find_peaks |
| |
|
| | from src.data.constants import BASE_END_DATE, BASE_START_DATE |
| | from src.data.frequency import ( |
| | Frequency, |
| | validate_frequency_safety, |
| | ) |
| | from src.utils.utils import device |
| |
|
| | |
# NOTE(review): basicConfig at import time configures the process-wide root
# logger and forces DEBUG verbosity as a side effect of importing this module —
# consider moving logging configuration to the application entry point.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
| |
|
| |
|
| | |
# Frequency-specific GluonTS time features, keyed by the sampling-rate category
# returned by TimeFeatureGenerator._get_feature_category. Each category holds:
#   "normalized" — GluonTS feature callables applied to a PeriodIndex
#   "index"      — raw integer-position feature callables (rescaled to [0, 1]
#                  by the caller before use)
ENHANCED_TIME_FEATURES = {
    # Second- to sub-hourly data: intra-day structure dominates.
    "high_freq": {
        "normalized": [
            second_of_minute,
            minute_of_hour,
            hour_of_day,
            day_of_week,
            day_of_month,
        ],
        "index": [
            second_of_minute_index,
            minute_of_hour_index,
            hour_of_day_index,
            day_of_week_index,
        ],
    },
    # Hourly/daily data: mix of intra-week and intra-year structure.
    "medium_freq": {
        "normalized": [
            hour_of_day,
            day_of_week,
            day_of_month,
            day_of_year,
            month_of_year,
        ],
        "index": [
            hour_of_day_index,
            day_of_week_index,
            day_of_month_index,
            week_of_year_index,
        ],
    },
    # Weekly and slower data: calendar position within month/year.
    "low_freq": {
        "normalized": [day_of_week, day_of_month, month_of_year, week_of_year],
        "index": [day_of_week_index, month_of_year_index, week_of_year_index],
    },
}
| |
|
| | |
# Named bundles of GluonTS holiday definitions. A bundle is passed to
# SpecialDateFeatureSet, which yields one feature channel per holiday.
HOLIDAY_FEATURE_SETS = {
    # US business calendar (federal-style observances plus Christmas/NYE eves).
    "us_business": [
        NEW_YEARS_DAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
    # US business calendar extended with retail shopping events.
    "us_retail": [
        NEW_YEARS_DAY,
        EASTER_SUNDAY,
        MEMORIAL_DAY,
        INDEPENDENCE_DAY,
        LABOR_DAY,
        THANKSGIVING,
        BLACK_FRIDAY,
        CYBER_MONDAY,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
    # Christian holidays, including the movable Easter feasts.
    "christian": [
        NEW_YEARS_DAY,
        GOOD_FRIDAY,
        EASTER_SUNDAY,
        EASTER_MONDAY,
        CHRISTMAS_EVE,
        CHRISTMAS_DAY,
        NEW_YEARS_EVE,
    ],
}
| |
|
| |
|
class TimeFeatureGenerator:
    """
    Enhanced time feature generator that leverages full GluonTS capabilities.

    Combines up to four feature families into one ``[time, features]`` array:

    * standard GluonTS features from ``time_features_from_frequency_str``,
    * frequency-category features (``ENHANCED_TIME_FEATURES``),
    * kernel-weighted holiday features (``HOLIDAY_FEATURE_SETS``),
    * sin/cos seasonal encodings (frequency-implied and, optionally,
      FFT-detected from the series values).

    Feature computation is deliberately best-effort: individual feature
    functions that raise are skipped so one bad frequency/feature combination
    never aborts the whole batch.
    """

    def __init__(
        self,
        use_enhanced_features: bool = True,
        use_holiday_features: bool = True,
        holiday_set: str = "us_business",
        holiday_kernel: str = "exponential",
        holiday_kernel_alpha: float = 1.0,
        use_index_features: bool = True,
        k_max: int = 15,
        include_seasonality_info: bool = True,
        use_auto_seasonality: bool = False,
        max_seasonal_periods: int = 3,
    ):
        """
        Initialize enhanced time feature generator.

        Parameters
        ----------
        use_enhanced_features : bool
            Whether to use frequency-specific enhanced features
        use_holiday_features : bool
            Whether to include holiday features
        holiday_set : str
            Which holiday set to use ('us_business', 'us_retail', 'christian');
            an unknown name silently disables holiday features
        holiday_kernel : str
            Holiday kernel type ('indicator', 'exponential', 'squared_exponential')
        holiday_kernel_alpha : float
            Kernel parameter for exponential kernels
        use_index_features : bool
            Whether to include index-based features alongside normalized ones
        k_max : int
            Maximum number of time features to pad to
        include_seasonality_info : bool
            Whether to include seasonality information as features
        use_auto_seasonality : bool
            Whether to use automatic FFT-based seasonality detection
        max_seasonal_periods : int
            Maximum number of seasonal periods to detect automatically
        """
        self.use_enhanced_features = use_enhanced_features
        self.use_holiday_features = use_holiday_features
        self.holiday_set = holiday_set
        self.use_index_features = use_index_features
        # NOTE(review): k_max is stored but never read inside this class;
        # padding/truncation to a fixed width happens in the caller
        # (_pad_or_truncate_features) — confirm it is still needed here.
        self.k_max = k_max
        self.include_seasonality_info = include_seasonality_info
        self.use_auto_seasonality = use_auto_seasonality
        self.max_seasonal_periods = max_seasonal_periods

        # Build the holiday feature set once up front; stays None (disabled)
        # when holiday features are off or the set name is unknown.
        self.holiday_feature_set = None
        if use_holiday_features and holiday_set in HOLIDAY_FEATURE_SETS:
            kernel_func = self._get_holiday_kernel(holiday_kernel, holiday_kernel_alpha)
            self.holiday_feature_set = SpecialDateFeatureSet(HOLIDAY_FEATURE_SETS[holiday_set], kernel_func)

    def _get_holiday_kernel(self, kernel_type: str, alpha: float):
        """
        Map a kernel name to a GluonTS holiday kernel function.

        Any unrecognized name (including 'indicator') falls back to a hard
        0/1 indicator: 1.0 exactly on the holiday, 0.0 elsewhere.
        """
        if kernel_type == "exponential":
            return exponential_kernel(alpha)
        elif kernel_type == "squared_exponential":
            return squared_exponential_kernel(alpha)
        else:
            # Fallback: binary indicator kernel (distance 0 == the holiday).
            return lambda x: float(x == 0)

    def _get_feature_category(self, freq_str: str) -> str:
        """
        Bucket a pandas frequency string into an ENHANCED_TIME_FEATURES key.

        NOTE(review): matching is by exact string, so unlisted sub-hourly
        aliases (e.g. "30min") fall through to "low_freq" — confirm intended.
        """
        if freq_str in ["s", "1min", "5min", "10min", "15min"]:
            return "high_freq"
        elif freq_str in ["h", "D"]:
            return "medium_freq"
        else:
            return "low_freq"

    def _compute_enhanced_features(self, period_index: pd.PeriodIndex, freq_str: str) -> np.ndarray:
        """
        Compute enhanced time features based on frequency.

        Returns an array of shape [len(period_index), n_features]; n_features
        is 0 when enhanced features are disabled or every feature function
        fails for this frequency.
        """
        if not self.use_enhanced_features:
            return np.array([]).reshape(len(period_index), 0)

        category = self._get_feature_category(freq_str)
        feature_config = ENHANCED_TIME_FEATURES[category]

        features = []

        # Normalized GluonTS features, used as-is.
        for feat_func in feature_config["normalized"]:
            try:
                feat_values = feat_func(period_index)
                features.append(feat_values)
            except Exception:
                # Best-effort: skip features unsupported for this frequency.
                continue

        # Index features: raw integer positions, rescaled here to [0, 1] by
        # the maximum value observed within this window.
        if self.use_index_features:
            for feat_func in feature_config["index"]:
                try:
                    feat_values = feat_func(period_index)
                    if feat_values.max() > 0:
                        feat_values = feat_values / feat_values.max()
                    features.append(feat_values)
                except Exception:
                    continue

        if features:
            return np.stack(features, axis=-1)
        else:
            return np.array([]).reshape(len(period_index), 0)

    def _compute_holiday_features(self, date_range: pd.DatetimeIndex) -> np.ndarray:
        """
        Compute holiday features for a datetime index.

        Returns shape [len(date_range), n_holidays]; a 0-column array when
        holiday features are disabled or the computation fails.
        """
        if not self.use_holiday_features or self.holiday_feature_set is None:
            return np.array([]).reshape(len(date_range), 0)

        try:
            holiday_features = self.holiday_feature_set(date_range)
            # SpecialDateFeatureSet yields (n_holidays, time); transpose to
            # time-major layout to match the other feature families.
            return holiday_features.T
        except Exception:
            # Best-effort: fall back to no holiday features on any failure.
            return np.array([]).reshape(len(date_range), 0)

    def _detect_auto_seasonality(self, time_series_values: np.ndarray) -> list:
        """
        Detect seasonal periods automatically using FFT analysis.

        Parameters
        ----------
        time_series_values : np.ndarray
            Time series values for seasonality detection

        Returns
        -------
        list
            Unique integer seasonal periods; empty when detection is disabled,
            data is too short, or no significant spectral peaks are found.
        """
        if not self.use_auto_seasonality or len(time_series_values) < 10:
            return []

        try:
            # Drop NaNs before spectral analysis.
            values = time_series_values[~np.isnan(time_series_values)]
            if len(values) < 10:
                return []

            # Remove a linear trend so low-frequency drift does not dominate
            # the spectrum.
            x = np.arange(len(values))
            coeffs = np.polyfit(x, values, 1)
            trend = np.polyval(coeffs, x)
            detrended = values - trend

            # Hann window to reduce spectral leakage.
            window = np.hanning(len(detrended))
            windowed = detrended * window

            # Zero-pad to twice the length for finer frequency resolution.
            padded_length = len(windowed) * 2
            padded_values = np.zeros(padded_length)
            padded_values[: len(windowed)] = windowed

            # Real-valued FFT magnitude spectrum.
            fft_values = fft.rfft(padded_values)
            fft_magnitudes = np.abs(fft_values)
            freqs = np.fft.rfftfreq(padded_length)

            # Suppress the DC component so it cannot register as a peak.
            fft_magnitudes[0] = 0.0

            # Keep only peaks above 5% of the strongest component.
            threshold = 0.05 * np.max(fft_magnitudes)
            peak_indices, _ = find_peaks(fft_magnitudes, height=threshold)

            if len(peak_indices) == 0:
                return []

            # Strongest peaks first, capped at max_seasonal_periods.
            sorted_indices = peak_indices[np.argsort(fft_magnitudes[peak_indices])[::-1]]
            top_indices = sorted_indices[: self.max_seasonal_periods]

            # Convert peak frequencies (cycles/sample) into integer periods,
            # keeping only periods that fit at least twice into the data.
            periods = []
            for idx in top_indices:
                if freqs[idx] > 0:
                    period = 1.0 / freqs[idx]
                    # NOTE(review): the halving looks intended to compensate
                    # for the 2x zero-padding, but rfftfreq already returns
                    # cycles per sample, so this may halve the true period —
                    # verify against known-seasonal data.
                    period = round(period / 2)
                    if 2 <= period <= len(values) // 2:
                        periods.append(period)

            # Deduplicate; note set() does not preserve magnitude ordering.
            return list(set(periods))

        except Exception:
            # Best-effort: any numerical failure disables auto-seasonality.
            return []

    def _compute_seasonality_features(
        self,
        period_index: pd.PeriodIndex,
        freq_str: str,
        time_series_values: np.ndarray | None = None,
    ) -> np.ndarray:
        """
        Compute sin/cos seasonality features.

        Emits one (sin, cos) pair for the frequency-implied seasonality and,
        when auto-detection is enabled and values are provided, one pair per
        FFT-detected period. Shape is [len(period_index), 2 * n_pairs]
        (0 columns when disabled or nothing is detected).
        """
        if not self.include_seasonality_info:
            return np.array([]).reshape(len(period_index), 0)

        all_seasonal_features = []

        # Seasonality implied by the frequency string via GluonTS.
        try:
            seasonality = get_seasonality(freq_str)
            if seasonality > 1:
                positions = np.arange(len(period_index))
                sin_feat = np.sin(2 * np.pi * positions / seasonality)
                cos_feat = np.cos(2 * np.pi * positions / seasonality)
                all_seasonal_features.extend([sin_feat, cos_feat])
        except Exception:
            pass

        # Extra encodings for FFT-detected periods, when values are given.
        if self.use_auto_seasonality and time_series_values is not None:
            auto_periods = self._detect_auto_seasonality(time_series_values)
            for period in auto_periods:
                try:
                    positions = np.arange(len(period_index))
                    sin_feat = np.sin(2 * np.pi * positions / period)
                    cos_feat = np.cos(2 * np.pi * positions / period)
                    all_seasonal_features.extend([sin_feat, cos_feat])
                except Exception:
                    continue

        if all_seasonal_features:
            return np.stack(all_seasonal_features, axis=-1)
        else:
            return np.array([]).reshape(len(period_index), 0)

    def compute_features(
        self,
        period_index: pd.PeriodIndex,
        date_range: pd.DatetimeIndex,
        freq_str: str,
        time_series_values: np.ndarray | None = None,
    ) -> np.ndarray:
        """
        Compute all time features for given period index.

        Parameters
        ----------
        period_index : pd.PeriodIndex
            Period index for computing features
        date_range : pd.DatetimeIndex
            Corresponding datetime index for holiday features
        freq_str : str
            Frequency string
        time_series_values : np.ndarray, optional
            Time series values for automatic seasonality detection

        Returns
        -------
        np.ndarray
            Time features array of shape [time_steps, num_features]; when no
            family produces features, a single all-zero column is returned so
            the output always has at least one feature.
        """
        all_features = []

        # Standard GluonTS features for this frequency (best-effort).
        try:
            standard_features = time_features_from_frequency_str(freq_str)
            if standard_features:
                std_feat = np.stack([feat(period_index) for feat in standard_features], axis=-1)
                all_features.append(std_feat)
        except Exception:
            pass

        # Frequency-category enhanced features.
        enhanced_feat = self._compute_enhanced_features(period_index, freq_str)
        if enhanced_feat.shape[1] > 0:
            all_features.append(enhanced_feat)

        # Holiday features (computed from the datetime index, not periods).
        holiday_feat = self._compute_holiday_features(date_range)
        if holiday_feat.shape[1] > 0:
            all_features.append(holiday_feat)

        # Seasonality sin/cos encodings.
        seasonality_feat = self._compute_seasonality_features(period_index, freq_str, time_series_values)
        if seasonality_feat.shape[1] > 0:
            all_features.append(seasonality_feat)

        if all_features:
            combined_features = np.concatenate(all_features, axis=-1)
        else:
            # Guarantee at least one (all-zero) feature column.
            combined_features = np.zeros((len(period_index), 1))

        return combined_features
| |
|
| |
|
def compute_batch_time_features(
    start: list[np.datetime64],
    history_length: int,
    future_length: int,
    batch_size: int,
    frequency: list[Frequency],
    K_max: int = 6,
    time_feature_config: dict[str, Any] | None = None,
):
    """
    Build padded time-feature tensors for a batch of series.

    For every batch element, a history date range is generated from its start
    timestamp and frequency, clamped into [BASE_START_DATE, BASE_END_DATE]
    when the requested span is unsafe, and turned into time features via
    TimeFeatureGenerator. Each feature matrix is then zero-padded or truncated
    to exactly K_max channels.

    Parameters
    ----------
    start : array-like, shape (batch_size,)
        Start timestamps for each batch item.
    history_length : int
        Length of history sequence.
    future_length : int
        Length of target sequence.
    batch_size : int
        Batch size.
    frequency : array-like, shape (batch_size,)
        Frequency of the time series.
    K_max : int, optional
        Maximum number of time features to pad to (default: 6).
    time_feature_config : dict, optional
        Configuration for enhanced time features, forwarded to
        TimeFeatureGenerator as keyword arguments.

    Returns
    -------
    tuple
        (history_time_features, target_time_features) where each is a
        torch.Tensor of shape (batch_size, length, K_max).
    """
    # One shared generator for the whole batch.
    generator = TimeFeatureGenerator(**(time_feature_config or {}))

    total_length = history_length + future_length
    hist_batches = []
    fut_batches = []

    for idx in range(batch_size):
        freq_obj = frequency[idx]
        range_freq = freq_obj.to_pandas_freq(for_date_range=True)
        period_freq = freq_obj.to_pandas_freq(for_date_range=False)

        # Fall back to BASE_START_DATE when the requested span would be
        # unsafe for this start/frequency combination.
        anchor = pd.Timestamp(start[idx])
        if not validate_frequency_safety(anchor, total_length, freq_obj):
            logger.debug(
                f"Start date {anchor} not safe for total_length={total_length}, frequency={freq_obj}. "
                f"Using BASE_START_DATE instead."
            )
            anchor = BASE_START_DATE

        hist_range = pd.date_range(start=anchor, periods=history_length, freq=range_freq)

        # Clamp the window so that history + future stays within
        # BASE_END_DATE (never starting before BASE_START_DATE).
        step = pd.tseries.frequencies.to_offset(range_freq)
        if hist_range[-1] > BASE_END_DATE:
            fallback = BASE_END_DATE - step * (history_length + future_length)
            if fallback < BASE_START_DATE:
                fallback = BASE_START_DATE
            hist_range = pd.date_range(start=fallback, periods=history_length, freq=range_freq)

        # Future window starts one step after the last history timestamp.
        fut_range = pd.date_range(start=hist_range[-1] + step, periods=future_length, freq=range_freq)

        # Compute features on the period view, then pad/trim to K_max.
        hist_feats = generator.compute_features(hist_range.to_period(period_freq), hist_range, range_freq)
        fut_feats = generator.compute_features(fut_range.to_period(period_freq), fut_range, range_freq)

        hist_batches.append(_pad_or_truncate_features(hist_feats, K_max))
        fut_batches.append(_pad_or_truncate_features(fut_feats, K_max))

    # Stack per-item matrices into (batch, length, K_max) tensors.
    return (
        torch.from_numpy(np.stack(hist_batches, axis=0)).float().to(device),
        torch.from_numpy(np.stack(fut_batches, axis=0)).float().to(device),
    )
| |
|
| |
|
| | def _pad_or_truncate_features(features: np.ndarray, K_max: int) -> np.ndarray: |
| | """Pad with zeros or truncate features to K_max dimensions.""" |
| | seq_len, num_features = features.shape |
| |
|
| | if num_features < K_max: |
| | |
| | padding = np.zeros((seq_len, K_max - num_features)) |
| | features = np.concatenate([features, padding], axis=-1) |
| | elif num_features > K_max: |
| | |
| | features = features[:, :K_max] |
| |
|
| | return features |
| |
|