File size: 1,816 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Safer checkpoint loading.

Centralizes ``torch.load`` calls so the kit has one place to enforce loading
policy. By default this uses ``weights_only=True`` (PyTorch 2.6+ default),
which refuses pickled Python objects and only accepts plain tensors and
basic containers — neutralizing the standard pickle-based code-execution
vector against malicious .pt files.

Older Tilelli checkpoints carry richer Python objects (config dicts,
metadata blobs) and need the legacy unpickling path. Callers that load
such checkpoints from trusted sources pass ``trusted=True``, which is an
explicit, greppable opt-in instead of a silent ``weights_only=False``
scattered across the codebase.
"""
from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Union

import torch

# Types accepted wherever this module takes a filesystem path: a plain
# string or any object implementing the ``os.PathLike`` protocol.
PathLike = Union[str, os.PathLike]


def safe_load_checkpoint(
    path: PathLike,
    *,
    map_location: str = "cpu",
    trusted: bool = False,
) -> Any:
    """Read a checkpoint through ``torch.load`` with a restrictive default.

    Unless the caller opts in with ``trusted=True``, the file is loaded with
    ``weights_only=True`` so arbitrary pickled Python objects are refused.

    Args:
        path: location of the checkpoint on disk (string or path-like).
        map_location: forwarded unchanged to ``torch.load``; defaults to
            ``'cpu'``.
        trusted: set to True to load with ``weights_only=False`` (the legacy
            full-unpickling route). Reserve this for checkpoints whose
            origin the caller has verified, such as the kit's own training
            checkpoints that store a config dict next to the state_dict.

    Returns:
        The object produced by ``torch.load`` — a state_dict, a wrapper
        dict, or a richer object when a legacy checkpoint is loaded.

    Raises:
        FileNotFoundError: when nothing exists at ``path``.
    """
    checkpoint = Path(path)
    if not checkpoint.exists():
        raise FileNotFoundError(f"checkpoint not found: {checkpoint}")

    # The restrictive mode is switched off only by an explicit opt-in.
    restrict_to_weights = not trusted
    return torch.load(
        str(checkpoint),
        map_location=map_location,
        weights_only=restrict_to_weights,
    )