| from typing import Dict, Any |
| import torch.nn as nn |
|
|
def extract_model_weights(reference_model, n_layers) -> Dict[str, Any]:
    """Extract cloned parameter tensors for the first ``n_layers`` transformer
    layers of *reference_model*, plus its final norm and LM head.

    Args:
        reference_model: A HuggingFace-style causal LM: a module exposing
            ``model.layers`` (transformer blocks named ``model.layers.<i>...``),
            ``model.norm`` and ``lm_head`` sub-modules.
            NOTE(review): assumes ``named_modules()`` yields the layers in
            ascending index order, which holds for ``nn.ModuleList``.
        n_layers: Number of leading transformer layers to extract.

    Returns:
        Dict mapping fully-qualified parameter names
        (e.g. ``"model.layers.0.weight"``) to *clones* of the tensors, so the
        result is decoupled from subsequent in-place updates to the model.
    """
    params: Dict[str, Any] = {}

    for name, module in reference_model.named_modules():
        # Check the layer index BEFORE copying: the previous version copied
        # first and checked after, so the first module of layer ``n_layers``
        # leaked into the result before the loop broke.
        if 'model.layers.' in name:
            layer_index = int(name.split('.')[2])
            if layer_index >= n_layers:
                break

        if hasattr(module, 'weight') and module.weight is not None:
            params[name + '.weight'] = module.weight.data.clone()
        if hasattr(module, 'bias') and module.bias is not None:
            params[name + '.bias'] = module.bias.data.clone()

    # Final norm — not reached in the loop because of the early break above.
    norm_layer = reference_model.model.norm
    if hasattr(norm_layer, 'weight') and norm_layer.weight is not None:
        params['model.norm.weight'] = norm_layer.weight.data.clone()
    if hasattr(norm_layer, 'bias') and norm_layer.bias is not None:
        params['model.norm.bias'] = norm_layer.bias.data.clone()

    # LM head — cloned for consistency with every other entry.  The previous
    # version stored live ``.data`` references here, so later in-place updates
    # to the reference model silently mutated the extracted tensors.
    lm_head = reference_model.lm_head
    if hasattr(lm_head, 'weight') and lm_head.weight is not None:
        params["lm_head.weight"] = lm_head.weight.data.clone()
    if hasattr(lm_head, 'bias') and lm_head.bias is not None:
        params["lm_head.bias"] = lm_head.bias.data.clone()

    return params
|
|
|
|