File size: 1,107 Bytes
ca7299e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

def load_model_checkpoint(model, ckpt_path):
    """Load network weights from a checkpoint file, dispatching on its suffix.

    * ``.pth``: treated as a weights-only checkpoint. The ``'state_dict'``
      entry of the file is loaded into ``model.net`` immediately and ``None``
      is returned as the path, so nothing else (e.g. #epochs / lr) is
      restored when finetuning.
    * ``.ckpt``: treated as a pytorch-lightning checkpoint. Nothing is loaded
      here; the path is returned unchanged so it can be handled later by the
      trainer.

    :param model: The model to load the state dict into. Must expose the
        network as ``model.net`` when a ``.pth`` file is given.
    :param ckpt_path: The path to the checkpoint file, or ``None`` to skip
        loading altogether.
    :returns: A ``(model, ckpt_path)`` tuple, where ``ckpt_path`` is ``None``
        if no path was given or the weights were already loaded here.
    :raises ValueError: If the path ends with neither ``.pth`` nor ``.ckpt``.
    """
    if ckpt_path is None:
        return model, None

    # Check the suffix on the string form so that path-like objects
    # (e.g. pathlib.Path) are accepted as well as plain strings.
    path_str = str(ckpt_path)

    if path_str.endswith(".ckpt"):
        # Will be handled later by the trainer.
        return model, ckpt_path

    if not path_str.endswith(".pth"):
        raise ValueError(f"ckpt_path {ckpt_path} is not a valid checkpoint file.")

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))

    # Keys are stored relative to the wrapping model ("net.<param>"), while
    # model.net expects them relative to itself. Strip only the *leading*
    # "net." -- a blanket str.replace would also mangle inner names that
    # happen to contain "net." (e.g. "net.subnet.weight" -> "subweight").
    prefix = 'net.'
    net_params = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }
    model.net.load_state_dict(net_params)

    # Weights are already in place; returning None signals that there is
    # nothing left to resume from.
    return model, None