Tilelli-llm / src /tilelli /utils /checkpoint.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
Raw
History Blame Contribute Delete
1.82 kB
"""Safer checkpoint loading.
Centralizes ``torch.load`` calls so the kit has one place to enforce loading
policy. By default this uses ``weights_only=True`` (PyTorch 2.6+ default),
which refuses pickled Python objects and only accepts plain tensors and
basic containers — neutralizing the standard pickle-based code-execution
vector against malicious .pt files.
Older Tilelli checkpoints carry richer Python objects (config dicts,
metadata blobs) and need the legacy unpickling path. Callers that load
such checkpoints from trusted sources pass ``trusted=True``, which is an
explicit, greppable opt-in instead of a silent ``weights_only=False``
scattered across the codebase.
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Union
import torch
PathLike = Union[str, os.PathLike]
def safe_load_checkpoint(
    path: PathLike,
    *,
    map_location: str = "cpu",
    trusted: bool = False,
) -> Any:
    """Read a checkpoint through ``torch.load`` under the kit's loading policy.

    The file is loaded with ``weights_only=True`` unless the caller opts out
    by passing ``trusted=True``, in which case ``weights_only=False`` is used
    and the legacy pickled-object path is allowed.

    Args:
        path: location of the checkpoint file (string or path-like).
        map_location: forwarded unchanged to ``torch.load``; defaults to
            ``"cpu"``.
        trusted: explicit opt-in to full unpickling. Set it only for
            checkpoints whose origin the caller has verified, such as the
            kit's own training checkpoints that store a config dict next
            to the state_dict.

    Returns:
        The object produced by ``torch.load``: a state_dict, a wrapper
        dict, or a richer object when a legacy checkpoint is loaded.

    Raises:
        FileNotFoundError: when nothing exists at ``path``.
    """
    checkpoint_path = Path(path)

    if checkpoint_path.exists():
        # Restricted unpickling is the default; only an explicit
        # ``trusted=True`` switches it off.
        restrict_to_weights = not trusted
        return torch.load(
            str(checkpoint_path),
            map_location=map_location,
            weights_only=restrict_to_weights,
        )

    raise FileNotFoundError(f"checkpoint not found: {checkpoint_path}")