| """Safer checkpoint loading. |
| |
| Centralizes ``torch.load`` calls so the kit has one place to enforce loading |
| policy. By default this uses ``weights_only=True`` (PyTorch 2.6+ default), |
| which refuses pickled Python objects and only accepts plain tensors and |
| basic containers — neutralizing the standard pickle-based code-execution |
| vector against malicious .pt files. |
| |
| Older Tilelli checkpoints carry richer Python objects (config dicts, |
| metadata blobs) and need the legacy unpickling path. Callers that load |
| such checkpoints from trusted sources pass ``trusted=True``, which is an |
| explicit, greppable opt-in instead of a silent ``weights_only=False`` |
| scattered across the codebase. |
| """ |
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
| from typing import Any, Union |
|
|
| import torch |
|
|
| PathLike = Union[str, os.PathLike] |
|
|
|
|
def safe_load_checkpoint(
    path: PathLike,
    *,
    map_location: str = "cpu",
    trusted: bool = False,
) -> Any:
    """Read a checkpoint through ``torch.load`` with restrictive defaults.

    Unless the caller opts in with ``trusted=True``, the file is read with
    ``weights_only=True``, so arbitrary pickled Python objects are rejected.

    Args:
        path: location of the checkpoint file on disk.
        map_location: forwarded unchanged to ``torch.load``; defaults to
            ``'cpu'``.
        trusted: set to True to read with ``weights_only=False`` (the legacy
            full-unpickling path). Only do this for files whose origin the
            caller has verified; the kit's own training checkpoints need
            it because they store a config dict next to the state_dict.

    Returns:
        The object produced by ``torch.load``: a state_dict, a wrapper
        dict, or a richer object when a legacy checkpoint is loaded.

    Raises:
        FileNotFoundError: when nothing exists at ``path``.
    """
    checkpoint_path = Path(path)
    if checkpoint_path.exists():
        # SECURITY: trusted=True turns off weights_only, which means full
        # pickle deserialization of the file's contents.
        weights_only = not trusted
        return torch.load(
            str(checkpoint_path),
            map_location=map_location,
            weights_only=weights_only,
        )
    raise FileNotFoundError(f"checkpoint not found: {checkpoint_path}")
|
|