| | from .models.autoencoders import create_autoencoder_from_config
|
| | import os
|
| | import json
|
| | import torch
|
| | from torch.nn.utils import remove_weight_norm
|
| |
|
| |
|
def remove_all_weight_norm(model):
    """Strip hook-based weight normalisation from every submodule of ``model``, in place.

    A submodule is treated as weight-normalised when it exposes a ``weight_g``
    attribute. That is the magnitude parameter created by
    ``torch.nn.utils.weight_norm`` for its default ``name='weight'``. For each
    such submodule, ``torch.nn.utils.remove_weight_norm`` folds ``weight_g`` and
    ``weight_v`` back into a single ``weight`` parameter and removes the
    recomputation hook. The module computes the same function afterwards.

    NOTE(review): modules normalised with the newer
    ``torch.nn.utils.parametrizations.weight_norm`` presumably do not expose
    ``weight_g`` and would be left untouched here. Confirm if that API is used.

    Args:
        model: The ``torch.nn.Module`` to modify. ``model`` itself is included
            in the traversal, not just its children.

    Returns:
        None. ``model`` is modified in place.
    """
    # Only the module objects are needed; the qualified names that
    # named_modules() would yield were never used.
    # Removing weight norm replaces parameters but does not add or remove
    # submodules, so mutating while traversing is safe.
    for module in model.modules():
        if hasattr(module, 'weight_g'):
            remove_weight_norm(module)
|
| |
|
| |
|
def load_vae(ckpt_path, remove_weight_norm=False):
    """Build an autoencoder from its sibling ``config.json`` and load checkpoint weights.

    The model config is read from ``config.json`` in the same directory as
    ``ckpt_path`` and passed to ``create_autoencoder_from_config``. The
    checkpoint must be a mapping with a ``'state_dict'`` entry. Only the keys in
    it that start with ``"autoencoder."`` are kept, with that prefix stripped,
    and they are loaded strictly into the model.

    Presumably the checkpoint was saved from a wrapper module that holds the
    autoencoder as an attribute named ``autoencoder``. Confirm against the
    training code.

    Args:
        ckpt_path: Path to the checkpoint file.
        remove_weight_norm: If True, call ``remove_all_weight_norm`` on the model
            after the weights are loaded. Note that inside this function the
            flag shadows the imported ``torch.nn.utils.remove_weight_norm``. The
            helper is unaffected because it uses the module-level name.

    Returns:
        The model, switched to eval mode. Checkpoint tensors are loaded with
        ``map_location='cpu'`` before being copied into the model.

    Raises:
        FileNotFoundError: If ``config.json`` or the checkpoint is missing.
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: Raised by ``load_state_dict`` when the filtered keys do
            not match the model's keys exactly.
    """
    config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json')

    # Read JSON as UTF-8 explicitly rather than the platform's locale encoding,
    # which can mis-decode non-ASCII config content (e.g. cp1252 on Windows).
    with open(config_file, encoding='utf-8') as f:
        model_config = json.load(f)

    model = create_autoencoder_from_config(model_config)

    # SECURITY NOTE: torch.load unpickles the file, so only load trusted
    # checkpoints.
    # NOTE(review): newer torch releases default to weights_only=True, which
    # may reject checkpoints that pickle non-tensor objects. Confirm against
    # the torch version in use.
    model_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']

    # Keep only the autoencoder's entries and strip the "autoencoder." prefix
    # so the keys match the standalone model's parameter names.
    prefix = "autoencoder."
    model_dict = {
        key[len(prefix):]: value
        for key, value in model_dict.items()
        if key.startswith(prefix)
    }

    model.load_state_dict(model_dict)

    if remove_weight_norm:
        remove_all_weight_norm(model)

    model.eval()

    return model
|
| |
|