import os

import torch


def load_model_checkpoint(model, ckpt_path):
    """Load weights from a checkpoint file into ``model.net`` when appropriate.

    ``.pth`` files are loaded eagerly. The ``'state_dict'`` entry is read onto
    the CPU, a leading ``'net.'`` is stripped from each key, and the result is
    loaded into ``model.net``. ``.ckpt`` files are not read here; their path is
    handed back to the caller (presumably so a training framework can resume
    from it -- TODO confirm with callers).

    :param model: The model to load the state dict into. Must expose a ``net``
        submodule when ``ckpt_path`` is a ``.pth`` file.
    :param ckpt_path: The path to the checkpoint file (``str``, ``bytes`` or
        ``os.PathLike`` such as ``pathlib.Path``), or ``None`` to skip loading.
    :return: A ``(model, ckpt_path)`` tuple. The second element is ``None``
        when nothing is left to load (the input was ``None``, or a ``.pth``
        file was consumed here). For ``.ckpt`` files it is the original
        ``ckpt_path`` argument, unchanged.
    :raises ValueError: If ``ckpt_path`` ends with neither ``.pth`` nor
        ``.ckpt``.
    """
    if ckpt_path is None:
        return model, None

    # fsdecode accepts str, bytes and os.PathLike, so the suffix checks below
    # work for pathlib.Path as well as plain strings.
    path_str = os.fsdecode(ckpt_path)

    if path_str.endswith(".pth"):
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(path_str, map_location=torch.device('cpu'))
        prefix = 'net.'
        # Strip only a *leading* "net." so nested module names containing
        # "net." (e.g. "net.subnet.weight") keep their structure. The former
        # str.replace removed every occurrence and corrupted such keys.
        net_params = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['state_dict'].items()
        }
        model.net.load_state_dict(net_params)
        # Weights are already in place; signal that no path is left to use.
        return model, None

    if path_str.endswith(".ckpt"):
        # Not loaded here; the caller receives the path back untouched.
        return model, ckpt_path

    raise ValueError(f"ckpt_path {ckpt_path} is not a valid checkpoint file.")