| |
| from typing import Any, Callable, Dict, List, Optional, Union, Iterable |
| import lightning.pytorch as pl |
| import torch |
| from pathlib import Path |
| import os |
| import re |
| from loguru import logger |
| from lightning.pytorch.utilities.consolidate_checkpoint import ( |
| _format_checkpoint, |
| _load_distributed_checkpoint, |
| ) |
| from glob import glob |
|
|
| from sam3d_objects.data.utils import get_child, set_child |
|
|
|
|
def rename_checkpoint_weights_using_suffix_matching(
    checkpoint_path_in,
    checkpoint_path_out,
    model: torch.nn.Module,
    strict: bool = True,
    keys: Optional[List[Any]] = (),
):
    """Rename the weights of a checkpoint so they match the names used by ``model``.

    The checkpoint is loaded, the weight dictionary found under ``keys`` is
    re-keyed with the parameter/buffer names of ``model`` and the whole
    checkpoint is written to ``checkpoint_path_out``.

    Names are paired by longest common suffix: both name lists are reversed
    character-wise and sorted, so names sharing a suffix become neighbours,
    then a two-pointer sweep assigns to each model name the checkpoint name
    at which the common suffix length stops growing.

    Args:
        checkpoint_path_in: checkpoint to read (passed to ``torch.load``).
        checkpoint_path_out: destination of the renamed checkpoint
            (passed to ``torch.save``).
        model: module whose ``named_parameters()`` and ``named_buffers()``
            provide the target names.
        strict: when true, require the model and the checkpoint to hold the
            same number of entries.
        keys: arguments forwarded to ``get_child`` / ``set_child`` to locate
            the weight dictionary inside the checkpoint (presumably a path of
            nested keys -- confirm against ``get_child``). ``None`` is treated
            like an empty path.

    Raises:
        RuntimeError: if ``strict`` is set and the entry counts differ, or if
            some model names could not be matched.
    """
    if keys is None:
        keys = ()

    # target names: what the model expects
    param_names = [n for n, _ in model.named_parameters()]
    buffer_names = [n for n, _ in model.named_buffers()]
    model_names = param_names + buffer_names

    # NOTE(review): weights_only=False unpickles arbitrary objects; only use
    # this on checkpoints coming from a trusted source.
    state = torch.load(checkpoint_path_in, weights_only=False)

    # source names: what the checkpoint contains
    model_state = get_child(state, *keys)
    model_state_names = list(model_state.keys())

    # reversing turns "common suffix" into "common prefix", which sorting
    # groups together
    model_names_rev = sorted([n[::-1] for n in model_names])
    model_state_names_rev = sorted([n[::-1] for n in model_state_names])

    if strict and len(model_names) != len(model_state_names):
        raise RuntimeError(
            f"model and state don't have the same number of parameters ({len(model_names)} != {len(model_state_names)}), cannot match them (set strict = False to relax constraint)"
        )

    def common_prefix_length(str_0: str, str_1: str):
        # Count matching leading characters. Counting explicitly (instead of
        # returning the loop index) gives the right answer when one string is
        # a full prefix of the other and when either string is empty.
        count = 0
        for char_0, char_1 in zip(str_0, str_1):
            if char_0 != char_1:
                break
            count += 1
        return count

    # two-pointer sweep: i walks the model names, j walks the checkpoint names
    name_mapping = {}
    i, j = 0, 0
    last_n = 0
    while i < len(model_names_rev):
        if j < len(model_state_names_rev):
            n = common_prefix_length(model_names_rev[i], model_state_names_rev[j])
        else:
            # checkpoint names exhausted: force the pending match to be emitted
            n = 0

        if n >= last_n:
            # match is not getting worse, keep advancing in the checkpoint names
            last_n = n
            j += 1
        else:
            # match got worse: the previous checkpoint name was the best one
            last_n = 0
            name_mapping[model_names_rev[i][::-1]] = model_state_names_rev[j - 1][::-1]
            i += 1

        # j may go one past the end (to flush the last match), but no further
        if not j < len(model_state_names_rev) + 1:
            break

    # the sweep stopped before every model name received a checkpoint name
    if i < len(model_names):
        raise RuntimeError("could not suffix match parameter names")

    for k, v in name_mapping.items():
        logger.debug(f"{k} <- {v}")

    # rebuild the weight dictionary under the model's names and save
    model_state_out = {k: model_state[v] for k, v in name_mapping.items()}
    set_child(state, model_state_out, *keys)
    torch.save(state, checkpoint_path_out)
|
|
|
|
def remove_prefix_state_dict_fn(prefix: str):
    """Build a function that strips ``prefix`` from the keys of a state dict.

    The returned function produces a new dict; keys that do not start with
    ``prefix`` are kept unchanged.
    """
    prefix_len = len(prefix)

    def _strip(name):
        # only cut names that actually carry the prefix
        if name.startswith(prefix):
            return name[prefix_len:]
        return name

    def state_dict_fn(state_dict):
        renamed = {}
        for name, value in state_dict.items():
            renamed[_strip(name)] = value
        return renamed

    return state_dict_fn
|
|
|
|
def add_prefix_state_dict_fn(prefix: str):
    """Build a function that prepends ``prefix`` to every key of a state dict.

    The returned function produces a new dict and leaves its input untouched.
    """

    def state_dict_fn(state_dict):
        prefixed = {}
        for name, value in state_dict.items():
            prefixed[prefix + name] = value
        return prefixed

    return state_dict_fn
|
|
|
|
def filter_and_remove_prefix_state_dict_fn(prefix: str):
    """Build a function that keeps only the keys starting with ``prefix``.

    The returned function produces a new dict in which the kept keys have
    ``prefix`` stripped; every other entry is dropped.
    """
    cut = len(prefix)

    def state_dict_fn(state_dict):
        kept = {}
        for name, value in state_dict.items():
            if not name.startswith(prefix):
                # entry belongs to something else, drop it
                continue
            kept[name[cut:]] = value
        return kept

    return state_dict_fn
|
|
|
|
def get_last_checkpoint(path: str):
    """Return the most recent ``epoch=<E>-step=<S>.ckpt`` file found in ``path``.

    Checkpoints are ordered by epoch number, then step number (both compared
    as integers); remaining ties are broken by the file path.

    Args:
        path: directory to search (not recursive).

    Returns:
        The path of the checkpoint with the highest (epoch, step).

    Raises:
        RuntimeError: if no file in ``path`` has a matching name.
    """
    # local import: the module only imports the `glob` function at file level
    from glob import escape as glob_escape

    # escape the directory so that characters such as "[" or "*" in `path`
    # are taken literally and only the file name part is a pattern
    pattern = os.path.join(glob_escape(os.fspath(path)), "epoch=*-step=*.ckpt")
    # the glob above is looser than what we want (e.g. "epoch=1-step=2-v1.ckpt"),
    # so the exact name is validated with an anchored regex
    prog = re.compile(r"epoch=(\d+)-step=(\d+)\.ckpt")

    candidates = []
    for checkpoint in glob(pattern):
        match = prog.fullmatch(os.path.basename(checkpoint))
        if match is None:
            continue
        n_epoch, n_step = (int(group) for group in match.groups())
        candidates.append((n_epoch, n_step, checkpoint))

    if not candidates:
        raise RuntimeError(f"no checkpoint has been found at path : {path}")
    return max(candidates)[2]
|
|
|
|
def load_sharded_checkpoint(path: str, device: Optional[str]):
    """Load a sharded (distributed) checkpoint directory into one checkpoint.

    The directory is read with Lightning's ``_load_distributed_checkpoint``
    and the result is passed through ``_format_checkpoint``.

    Args:
        path: directory holding the sharded checkpoint.
        device: must be ``"cpu"``; any other value (including ``None``) is
            rejected.

    Raises:
        RuntimeError: if ``device`` is not ``"cpu"``.
    """
    if device == "cpu":
        raw_checkpoint = _load_distributed_checkpoint(Path(path))
        return _format_checkpoint(raw_checkpoint)

    raise RuntimeError(
        f'loading sharded weights on device "{device}" is not available, please use the "cpu" device instead'
    )
|
|
|
|
def load_model_from_checkpoint(
    model: Union[pl.LightningModule, torch.nn.Module],
    checkpoint_path: str,
    strict: bool = True,
    device: Optional[str] = None,
    freeze: bool = False,
    eval: bool = False,
    map_name: Union[Dict[str, str], None] = None,
    remove_name: Union[List[str], None] = None,
    state_dict_key: Union[None, str, Iterable[str]] = "state_dict",
    state_dict_fn: Optional[Callable[[Any], Any]] = None,
):
    """Load weights from a checkpoint into ``model`` and return the model.

    The state dict goes through the following steps, in this order:
    extraction with ``state_dict_key``, deletion of ``remove_name`` entries,
    renaming with ``map_name``, transformation by ``state_dict_fn``, and
    finally ``model.load_state_dict``.

    Args:
        model: module receiving the weights. If it is a ``LightningModule``,
            its ``on_load_checkpoint`` hook is called with the full checkpoint
            before the state dict is extracted.
        checkpoint_path: a file (read with ``torch.load``) or a directory
            (read with ``load_sharded_checkpoint``, which only accepts
            ``device="cpu"``).
        strict: forwarded to ``model.load_state_dict``.
        device: ``map_location`` used by ``torch.load``; when not ``None`` the
            model is also moved to this device after loading.
        freeze: set ``requires_grad = False`` on every parameter; this also
            forces eval mode.
        eval: put the model in eval mode.
        map_name: ``{old_name: new_name}`` renames applied to the state dict;
            old names that are absent are skipped.
        remove_name: names deleted from the state dict; a missing name raises
            ``KeyError``.
        state_dict_key: key (or sequence of keys forwarded to ``get_child``)
            locating the state dict inside the checkpoint; ``None`` uses the
            checkpoint itself.
        state_dict_fn: optional callable whose return value replaces the
            state dict right before loading.

    Returns:
        The model (the result of ``model.to(device)`` when ``device`` is set).

    Raises:
        FileNotFoundError: if ``checkpoint_path`` is neither a file nor a
            directory.
    """
    logger.info(f"Loading checkpoint from {checkpoint_path}")
    if os.path.isfile(checkpoint_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints coming from a trusted source.
        checkpoint = torch.load(
            checkpoint_path,
            map_location=device,
            weights_only=False,
        )
    elif os.path.isdir(checkpoint_path):
        checkpoint = load_sharded_checkpoint(checkpoint_path, device=device)
    else:
        raise FileNotFoundError(checkpoint_path)

    if isinstance(model, pl.LightningModule):
        model.on_load_checkpoint(checkpoint)

    # locate the state dict inside the checkpoint
    state_dict = checkpoint
    if state_dict_key is not None:
        if isinstance(state_dict_key, str):
            state_dict_key = (state_dict_key,)
        state_dict = get_child(state_dict, *state_dict_key)

    # drop unwanted entries
    if remove_name is not None:
        for name in remove_name:
            del state_dict[name]

    # rename entries
    if map_name is not None:
        for src, dst in map_name.items():
            if src not in state_dict:
                continue
            # pop before assigning so that an identity mapping (src == dst)
            # keeps the entry instead of deleting it
            state_dict[dst] = state_dict.pop(src)

    # arbitrary user transformation
    if state_dict_fn is not None:
        state_dict = state_dict_fn(state_dict)

    model.load_state_dict(state_dict, strict=strict)

    if device is not None:
        model = model.to(device)

    if freeze:
        for param in model.parameters():
            param.requires_grad = False
        # a frozen model is always switched to eval mode
        eval = True

    if eval:
        model.eval()

    return model
|
|