"""Safer checkpoint loading. Centralizes ``torch.load`` calls so the kit has one place to enforce loading policy. By default this uses ``weights_only=True`` (PyTorch 2.6+ default), which refuses pickled Python objects and only accepts plain tensors and basic containers — neutralizing the standard pickle-based code-execution vector against malicious .pt files. Older Tilelli checkpoints carry richer Python objects (config dicts, metadata blobs) and need the legacy unpickling path. Callers that load such checkpoints from trusted sources pass ``trusted=True``, which is an explicit, greppable opt-in instead of a silent ``weights_only=False`` scattered across the codebase. """ from __future__ import annotations import os from pathlib import Path from typing import Any, Union import torch PathLike = Union[str, os.PathLike] def safe_load_checkpoint( path: PathLike, *, map_location: str = "cpu", trusted: bool = False, ) -> Any: """Load a .pt file with safety defaults. Args: path: checkpoint file path. map_location: torch.load map_location (default 'cpu'). trusted: when True, allows the legacy pickled-object path (``weights_only=False``). Use only for checkpoints whose provenance the caller has verified. Required for the kit's own training checkpoints, which serialize a config dict alongside state_dict. Returns: Whatever torch.load returns: a state_dict, a wrapper dict, or a richer object for legacy checkpoints. Raises: FileNotFoundError: if the path does not exist. """ p = Path(path) if not p.exists(): raise FileNotFoundError(f"checkpoint not found: {p}") return torch.load(str(p), map_location=map_location, weights_only=not trusted)