| import torch |
| import numpy as np |
| from graph_decoder.diffusion_model import GraphDiT |
|
|
def count_parameters(model):
    r"""
    Count the parameters of ``model``.

    Each parameter contributes ``param.numel()`` elements to the total.
    Parameters with ``requires_grad`` set also contribute to the trainable
    count.

    Returns:
        A ``(trainable_params, all_params)`` tuple of element counts.
    """
    # Query each parameter once; keep its size alongside its grad flag.
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(n for n, _ in sizes)
    trainable = sum(n for n, needs_grad in sizes if needs_grad)
    return trainable, total
|
|
def load_graph_decoder(path='model_labeled'):
    """
    Build a GraphDiT from the files under ``path``, load it, and freeze it.

    ``path`` is expected to contain ``config.yaml`` (model config) and
    ``data.meta.json`` (data info). The same directory is then passed to
    ``GraphDiT.init_model``. After loading, ``disable_grads`` is called and a
    one-line parameter summary is printed to stdout.

    Args:
        path: Directory holding the model files (default ``'model_labeled'``).

    Returns:
        The initialized GraphDiT instance.
    """
    decoder = GraphDiT(
        model_config_path=f"{path}/config.yaml",
        data_info_path=f"{path}/data.meta.json",
        model_dtype=torch.float32,
    )
    decoder.init_model(path)
    # NOTE(review): presumably turns off requires_grad on all weights — confirm in GraphDiT.
    decoder.disable_grads()

    trainable, total = count_parameters(decoder)
    share = 100 * trainable / total
    print(
        f"Loaded Graph DiT from {path} trainable params: {trainable:,} "
        f"|| all params: {total:,} || trainable%: {share:.4f}"
    )
    return decoder
|
|