from typing import Dict, Any
import torch.nn as nn

def extract_model_weights(reference_model: nn.Module, n_layers: int) -> Dict[str, Any]:
    """Clone every parameter tensor of `reference_model` up to and including
    transformer layer `n_layers - 1`, plus the final norm and LM head."""
    params = {}

    # named_modules() walks submodules in registration order, so the
    # embedding and per-layer blocks appear before the final norm and head.
    for name, module in reference_model.named_modules():
        # Stop before cloning anything from a layer past the requested
        # count; checking first avoids copying part of an extra layer.
        if 'model.layers.' in name:
            layer_index = int(name.split('.')[2])
            if layer_index > n_layers - 1:
                break

        if hasattr(module, 'weight') and module.weight is not None:
            params[name + '.weight'] = module.weight.data.clone()
        if hasattr(module, 'bias') and module.bias is not None:
            params[name + '.bias'] = module.bias.data.clone()

    # The early break skips the trailing modules, so copy the final norm
    # explicitly.
    norm_layer = reference_model.model.norm
    if hasattr(norm_layer, 'weight') and norm_layer.weight is not None:
        params['model.norm.weight'] = norm_layer.weight.data.clone()
    if hasattr(norm_layer, 'bias') and norm_layer.bias is not None:
        params['model.norm.bias'] = norm_layer.bias.data.clone()

    # Clone the LM head as well, so no returned tensor shares storage
    # with the reference model.
    lm_head = reference_model.lm_head
    if hasattr(lm_head, 'weight') and lm_head.weight is not None:
        params['lm_head.weight'] = lm_head.weight.data.clone()
    if hasattr(lm_head, 'bias') and lm_head.bias is not None:
        params['lm_head.bias'] = lm_head.bias.data.clone()

    return params
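

# Usage sketch (an assumption, not from the original): extract_model_weights
# expects Llama-style module names ('model.layers.N', 'model.norm',
# 'lm_head'), so any Hugging Face Llama-family checkpoint should work. The
# checkpoint id below is purely illustrative.
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM

    reference = AutoModelForCausalLM.from_pretrained(
        'TinyLlama/TinyLlama-1.1B-Chat-v1.0')  # illustrative checkpoint
    weights = extract_model_weights(reference, n_layers=4)
    # Keys look like 'model.layers.0.self_attn.q_proj.weight'.
    print(len(weights), sorted(weights)[:3])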