| | import argparse |
| | import logging |
| | import warnings |
| | from pathlib import Path |
| |
|
| | import matplotlib |
| | from gluonts.model.evaluation import evaluate_model |
| | from gluonts.time_feature import get_seasonality |
| | from linear_operator.utils.cholesky import NumericalWarning |
| |
|
| | from src.gift_eval.constants import ( |
| | DATASET_PROPERTIES, |
| | MED_LONG_DATASETS, |
| | METRICS, |
| | PRETTY_NAMES, |
| | ) |
| | from src.gift_eval.core import DatasetMetadata, EvaluationItem, expand_datasets_arg |
| | from src.gift_eval.data import Dataset |
| | from src.gift_eval.predictor import TimeSeriesPredictor |
| | from src.gift_eval.results import write_results_to_disk |
| | from src.plotting.gift_eval_utils import create_plots_for_dataset |
| |
|
# Module-level logger for this evaluation script.
logger = logging.getLogger(__name__)


# Silence known-noisy warnings and chatty third-party loggers so the
# evaluation output stays readable.
warnings.filterwarnings("ignore", category=NumericalWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
matplotlib.set_loglevel("WARNING")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
logging.getLogger("PIL").setLevel(logging.WARNING)
| |
|
| |
|
class WarningFilter(logging.Filter):
    """Logging filter that suppresses records containing a given substring."""

    def __init__(self, text_to_filter: str) -> None:
        """Store the substring whose presence should drop a log record."""
        super().__init__()
        self.text_to_filter = text_to_filter

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True (keep the record) only when the substring is absent."""
        message = record.getMessage()
        return self.text_to_filter not in message
| |
|
| |
|

# Suppress a known, benign gluonts warning about forecasts that lack a
# stored mean prediction.
gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(WarningFilter("The mean prediction is not stored in the forecast data"))
| |
|
| |
|
def construct_evaluation_data(
    dataset_name: str,
    dataset_storage_path: str,
    terms: list[str] | None = None,
    max_windows: int | None = None,
) -> list[tuple[Dataset, DatasetMetadata]]:
    """Build datasets and rich metadata per term for a dataset name.

    Args:
        dataset_name: Dataset name, optionally in ``"<key>/<freq>"`` form.
        dataset_storage_path: Root directory where datasets are stored.
        terms: Prediction terms to build; defaults to short/medium/long.
        max_windows: Optional cap on the number of evaluation windows.

    Returns:
        One ``(dataset, metadata)`` pair per applicable term.
    """
    if terms is None:
        terms = ["short", "medium", "long"]

    # Resolve the canonical key and frequency. A "/" embeds the frequency in
    # the name (split with maxsplit=1 so extra slashes don't raise); otherwise
    # the frequency comes from the properties table and may be None.
    if "/" in dataset_name:
        raw_key, ds_freq = dataset_name.split("/", 1)
    else:
        raw_key, ds_freq = dataset_name, None
    ds_key = PRETTY_NAMES.get(raw_key.lower(), raw_key.lower())
    if ds_freq is None:
        ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")

    sub_datasets: list[tuple[Dataset, DatasetMetadata]] = []
    for term in terms:
        # Medium/long horizons are only defined for a subset of datasets.
        if term in ("medium", "long") and dataset_name not in MED_LONG_DATASETS:
            continue

        # Probe once to learn the target dimensionality; multivariate
        # datasets are converted to univariate series for evaluation.
        probe_dataset = Dataset(
            name=dataset_name,
            term=term,
            to_univariate=False,
            storage_path=dataset_storage_path,
            max_windows=max_windows,
        )
        to_univariate = probe_dataset.target_dim > 1

        # Only rebuild when a univariate conversion is actually needed;
        # otherwise the probe is already the dataset we want.
        if to_univariate:
            dataset = Dataset(
                name=dataset_name,
                term=term,
                to_univariate=True,
                storage_path=dataset_storage_path,
                max_windows=max_windows,
            )
        else:
            dataset = probe_dataset

        season_length = get_seasonality(dataset.freq)
        # Prefer the explicitly resolved frequency; fall back to the one the
        # dataset itself reports.
        actual_freq = ds_freq if ds_freq else dataset.freq

        metadata = DatasetMetadata(
            full_name=f"{ds_key}/{actual_freq}/{term}",
            key=ds_key,
            freq=actual_freq,
            term=term,
            season_length=season_length,
            target_dim=probe_dataset.target_dim,
            to_univariate=to_univariate,
            prediction_length=dataset.prediction_length,
            windows=dataset.windows,
        )

        sub_datasets.append((dataset, metadata))

    return sub_datasets
| |
|
| |
|
def evaluate_datasets(
    predictor: TimeSeriesPredictor,
    dataset: str,
    dataset_storage_path: str,
    terms: list[str] | None = None,
    max_windows: int | None = None,
    batch_size: int = 48,
    max_context_length: int | None = 1024,
    create_plots: bool = False,
    max_plots_per_dataset: int = 10,
) -> list[EvaluationItem]:
    """Evaluate predictor on one dataset across the requested terms.

    For each term, the predictor context is re-targeted to that
    sub-dataset's horizon and frequency before metrics (and optional
    forecast plots) are produced.
    """
    if terms is None:
        terms = ["short", "medium", "long"]

    sub_datasets = construct_evaluation_data(
        dataset_name=dataset,
        dataset_storage_path=dataset_storage_path,
        terms=terms,
        max_windows=max_windows,
    )

    results: list[EvaluationItem] = []
    for i, (sub_dataset, metadata) in enumerate(sub_datasets):
        # Summarize the sub-dataset before running inference on it.
        logger.info(f"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}")
        logger.info(f" Dataset size: {len(sub_dataset.test_data)}")
        logger.info(f" Frequency: {sub_dataset.freq}")
        logger.info(f" Term: {metadata.term}")
        logger.info(f" Prediction length: {sub_dataset.prediction_length}")
        logger.info(f" Target dimensions: {sub_dataset.target_dim}")
        logger.info(f" Windows: {sub_dataset.windows}")

        # Point the predictor at this sub-dataset's horizon and frequency.
        predictor.set_dataset_context(
            prediction_length=sub_dataset.prediction_length,
            freq=sub_dataset.freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
        )

        term_metrics = evaluate_model(
            model=predictor,
            test_data=sub_dataset.test_data,
            metrics=METRICS,
            axis=None,
            mask_invalid_label=True,
            allow_nan_forecast=False,
            seasonality=metadata.season_length,
        )

        # Plots are opt-in; they require a second forward pass over the
        # test inputs.
        figures: list[tuple[object, str]] = []
        if create_plots:
            forecasts = predictor.predict(sub_dataset.test_data.input)
            figures = create_plots_for_dataset(
                forecasts=forecasts,
                test_data=sub_dataset.test_data,
                dataset_metadata=metadata,
                max_plots=max_plots_per_dataset,
                max_context_length=max_context_length,
            )

        results.append(EvaluationItem(dataset_metadata=metadata, metrics=term_metrics, figures=figures))

    return results
| |
|
| |
|
def _run_evaluation(
    predictor: TimeSeriesPredictor,
    datasets: list[str] | str,
    terms: list[str],
    dataset_storage_path: str,
    max_windows: int | None = None,
    batch_size: int = 48,
    max_context_length: int | None = 1024,
    output_dir: str = "gift_eval_results",
    model_name: str = "TimeSeriesModel",
    create_plots: bool = False,
    max_plots: int = 10,
) -> None:
    """Shared evaluation workflow used by both entry points.

    Expands the datasets argument, evaluates each dataset across the
    requested terms, and persists metrics (and optional figures) under
    ``output_dir``.
    """
    results_root = Path(output_dir)

    for dataset_name in expand_datasets_arg(datasets):
        evaluation_items = evaluate_datasets(
            predictor=predictor,
            dataset=dataset_name,
            dataset_storage_path=dataset_storage_path,
            terms=terms,
            max_windows=max_windows,
            batch_size=batch_size,
            max_context_length=max_context_length,
            create_plots=create_plots,
            max_plots_per_dataset=max_plots,
        )
        # Persist this dataset's results immediately so partial progress
        # survives a failure on a later dataset.
        write_results_to_disk(
            items=evaluation_items,
            dataset_name=dataset_name,
            output_dir=results_root,
            model_name=model_name,
            create_plots=create_plots,
        )
| |
|
| |
|
def evaluate_from_paths(
    model_path: str,
    config_path: str,
    datasets: list[str] | str,
    terms: list[str],
    dataset_storage_path: str,
    max_windows: int | None = None,
    batch_size: int = 48,
    max_context_length: int | None = 1024,
    output_dir: str = "gift_eval_results",
    model_name: str = "TimeSeriesModel",
    create_plots: bool = False,
    max_plots: int = 10,
) -> None:
    """Entry point: load model from disk and save metrics/plots to disk.

    Raises:
        FileNotFoundError: If the model checkpoint or config file is missing.
    """
    # Fail fast on bad paths before any expensive model loading.
    if not Path(model_path).exists():
        raise FileNotFoundError(f"Model path does not exist: {model_path}")
    if not Path(config_path).exists():
        raise FileNotFoundError(f"Config path does not exist: {config_path}")

    # ds_prediction_length/ds_freq here are initial values only; the
    # predictor is re-targeted per sub-dataset via set_dataset_context
    # during evaluation.
    loaded_predictor = TimeSeriesPredictor.from_paths(
        model_path=model_path,
        config_path=config_path,
        ds_prediction_length=1,
        ds_freq="D",
        batch_size=batch_size,
        max_context_length=max_context_length,
    )

    _run_evaluation(
        predictor=loaded_predictor,
        datasets=datasets,
        terms=terms,
        dataset_storage_path=dataset_storage_path,
        max_windows=max_windows,
        batch_size=batch_size,
        max_context_length=max_context_length,
        output_dir=output_dir,
        model_name=model_name,
        create_plots=create_plots,
        max_plots=max_plots,
    )
| |
|
| |
|
def evaluate_in_memory(
    model,
    config: dict,
    datasets: list[str] | str,
    terms: list[str],
    dataset_storage_path: str,
    max_windows: int | None = None,
    batch_size: int = 48,
    max_context_length: int | None = 1024,
    output_dir: str = "gift_eval_results",
    model_name: str = "TimeSeriesModel",
    create_plots: bool = False,
    max_plots: int = 10,
) -> None:
    """Entry point: evaluate an in-memory model and write results to disk.

    Wraps the given model and config in a ``TimeSeriesPredictor`` and runs
    the shared evaluation workflow; metrics (and optional plots) are saved
    under ``output_dir``. Nothing is returned.
    """
    # ds_prediction_length/ds_freq are initial values only; the predictor is
    # re-targeted per sub-dataset via set_dataset_context during evaluation.
    predictor = TimeSeriesPredictor.from_model(
        model=model,
        config=config,
        ds_prediction_length=1,
        ds_freq="D",
        batch_size=batch_size,
        max_context_length=max_context_length,
    )

    _run_evaluation(
        predictor=predictor,
        datasets=datasets,
        terms=terms,
        dataset_storage_path=dataset_storage_path,
        max_windows=max_windows,
        batch_size=batch_size,
        max_context_length=max_context_length,
        output_dir=output_dir,
        model_name=model_name,
        create_plots=create_plots,
        max_plots=max_plots,
    )
| |
|
| |
|
| | def _parse_args() -> argparse.Namespace: |
| | parser = argparse.ArgumentParser(description="Evaluate TimeSeriesModel on GIFT-Eval datasets") |
| |
|
| | |
| | parser.add_argument( |
| | "--model_path", |
| | type=str, |
| | required=True, |
| | help="Path to the trained model checkpoint", |
| | ) |
| | parser.add_argument( |
| | "--config_path", |
| | type=str, |
| | required=True, |
| | help="Path to the model configuration YAML file", |
| | ) |
| | parser.add_argument( |
| | "--model_name", |
| | type=str, |
| | default="TimeSeriesModel", |
| | help="Name identifier for the model", |
| | ) |
| |
|
| | |
| | parser.add_argument( |
| | "--datasets", |
| | type=str, |
| | default="all", |
| | help="Comma-separated list of dataset names to evaluate (or 'all')", |
| | ) |
| | parser.add_argument( |
| | "--dataset_storage_path", |
| | type=str, |
| | default="/work/dlclarge2/moroshav-GiftEvalPretrain/gift_eval", |
| | help="Path to the dataset storage directory (default: GIFT_EVAL)", |
| | ) |
| | parser.add_argument( |
| | "--terms", |
| | type=str, |
| | default="short,medium,long", |
| | help="Comma-separated list of prediction terms to evaluate", |
| | ) |
| | parser.add_argument( |
| | "--max_windows", |
| | type=int, |
| | default=None, |
| | help="Maximum number of windows to use for evaluation", |
| | ) |
| |
|
| | |
| | parser.add_argument("--batch_size", type=int, default=48, help="Batch size for model inference") |
| | parser.add_argument( |
| | "--max_context_length", |
| | type=int, |
| | default=1024, |
| | help="Maximum context length to use (None for no limit)", |
| | ) |
| |
|
| | |
| | parser.add_argument( |
| | "--output_dir", |
| | type=str, |
| | default="gift_eval_results", |
| | help="Directory to save evaluation results", |
| | ) |
| |
|
| | |
| | parser.add_argument( |
| | "--create_plots", |
| | action="store_true", |
| | help="Create and save plots for each evaluation window", |
| | ) |
| | parser.add_argument( |
| | "--max_plots_per_dataset", |
| | type=int, |
| | default=10, |
| | help="Maximum number of plots to create per dataset term", |
| | ) |
| |
|
| | args = parser.parse_args() |
| | args.terms = args.terms.split(",") |
| | args.datasets = args.datasets.split(",") |
| | return args |
| |
|
| |
|
def _configure_logging() -> None:
    """Configure root logging at INFO level with a timestamped format."""
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: configure logging, parse CLI arguments, and run
    # the disk-based evaluation workflow.
    _configure_logging()
    args = _parse_args()
    logger.info(f"Command Line Arguments: {vars(args)}")
    try:
        evaluate_from_paths(
            model_path=args.model_path,
            config_path=args.config_path,
            datasets=args.datasets,
            terms=args.terms,
            dataset_storage_path=args.dataset_storage_path,
            max_windows=args.max_windows,
            batch_size=args.batch_size,
            max_context_length=args.max_context_length,
            output_dir=args.output_dir,
            model_name=args.model_name,
            create_plots=args.create_plots,
            max_plots=args.max_plots_per_dataset,
        )
    except Exception as e:
        # Record the failure in the evaluation log, then re-raise so the
        # process exits with the original traceback.
        logger.error(f"Evaluation failed: {str(e)}")
        raise
| |
|