| | import logging |
| | import os |
| |
|
| | import numpy as np |
| | import torch |
| | import yaml |
| | from src.data.containers import BatchTimeSeriesContainer |
| | from src.models.model import TimeSeriesModel |
| | from src.plotting.plot_timeseries import plot_from_container |
| |
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
def load_model(config_path: str, model_path: str, device: torch.device) -> TimeSeriesModel:
    """Build a ``TimeSeriesModel`` from a YAML config and load checkpoint weights.

    Args:
        config_path: Path to a YAML file. Its top-level ``"TimeSeriesModel"``
            mapping is passed as keyword arguments to ``TimeSeriesModel``.
        model_path: Path to a checkpoint written by ``torch.save``. This is either
            a dict holding the weights under ``"model_state_dict"`` or a bare
            state dict.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model with weights loaded, switched to evaluation mode.

    Raises:
        KeyError: If the config has no ``"TimeSeriesModel"`` section.
    """
    # Explicit encoding: the default text encoding is locale-dependent.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both training checkpoints ({"model_state_dict": ..., ...}) and a bare
    # state dict saved via torch.save(model.state_dict()). A mismatched dict
    # still fails loudly in load_state_dict (strict key matching).
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    # Inference mode for layers such as dropout / batch-norm.
    model.eval()
    logger.info("Successfully loaded TimeSeriesModel from %s on %s", model_path, device)
    return model
| |
|
| |
|
def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,
    model_quantiles: list[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
    title_prefix: str = "Sine Wave Time Series Prediction",
    file_prefix: str = "sine_wave_prediction",
) -> None:
    """Plot every sample in ``container`` together with the model predictions.

    Calls ``plot_from_container`` once per sample, for indices
    ``0 .. container.batch_size - 1``.

    Args:
        container: Batch whose samples are plotted.
        predictions_np: Predictions for the whole batch. The full array is
            forwarded for every sample; presumably ``plot_from_container``
            selects the row via ``sample_idx`` (TODO confirm).
        model_quantiles: Quantile levels forwarded unchanged to the plotter,
            or ``None`` when the predictions are not quantile outputs.
        output_dir: Directory for the PNG files. Created only when saving.
        show_plots: Forwarded as ``show`` to the plotter.
        save_plots: When true, each plot is written to
            ``<output_dir>/<file_prefix>_sample_<k>.png``, where ``k`` is 1-based.
        title_prefix: Leading part of each plot title; `` - Sample <k>`` is appended.
        file_prefix: Leading part of each saved file name.
    """
    # Only touch the filesystem when plots are actually written to disk.
    if save_plots:
        os.makedirs(output_dir, exist_ok=True)

    for i in range(container.batch_size):
        sample_no = i + 1  # 1-based numbering for human-facing titles and file names
        output_file = (
            os.path.join(output_dir, f"{file_prefix}_sample_{sample_no}.png") if save_plots else None
        )
        plot_from_container(
            batch=container,
            sample_idx=i,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"{title_prefix} - Sample {sample_no}",
            output_file=output_file,
            show=show_plots,
        )
| |
|
| |
|
def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Run one forward pass over ``container`` and plot the resulting predictions.

    Args:
        model: Model called as ``model(container)``. It must return a dict with a
            ``"result"`` tensor and, optionally, ``"scale_statistics"``.
        container: Input batch. The device of its ``history_values`` decides
            whether CUDA autocast can be used.
        output_dir: Directory for saved plots (passed to ``plot_with_library``).
        use_bfloat16: Request bfloat16 autocast. It is honoured only on CUDA
            devices that support bf16; otherwise inference runs in full precision.
        show_plots: Whether plots are displayed.
        save_plots: Whether plots are written to ``output_dir``.
    """
    device_type = "cuda" if container.history_values.device.type == "cuda" else "cpu"
    # torch.autocast raises on CUDA GPUs without bfloat16 support, so fall back to
    # fp32 there. The `and` ordering keeps the CUDA query off the CPU path.
    autocast_enabled = use_bfloat16 and device_type == "cuda" and torch.cuda.is_bf16_supported()

    with (
        torch.no_grad(),
        torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled),
    ):
        model_output = model(container)

    # Upcast before any post-processing: NumPy has no bfloat16 dtype, so .numpy()
    # on a bf16 tensor would fail.
    preds = model_output["result"].to(torch.float32)
    if hasattr(model, "scaler") and "scale_statistics" in model_output:
        # Presumably maps predictions back to the original data scale.
        preds = model.scaler.inverse_scale(preds, model_output["scale_statistics"])

    preds_np = preds.detach().cpu().numpy()
    # Quantile levels are only meaningful for models trained with a quantile loss.
    model_quantiles = model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=show_plots,
        save_plots=save_plots,
    )
| |
|