| | from torch import nn |
| |
|
| | from .quantization import BitLinear |
| |
|
| |
|
def replace_linears_in_hf(
    model, name_skip = 'lm_head'
):
    """
    Recursively replace every instance of ``nn.Linear`` in the given model
    with a ``BitLinear`` of matching shape, skipping children named
    *name_skip*.

    Args:
        model (nn.Module): The model to modify in place.
        name_skip (str): Name of a child module to leave untouched.
            Defaults to ``'lm_head'`` (the usual HF output projection).

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name != name_skip:
            # Swap in a BitLinear with the same geometry. Note the original
            # trained weights are NOT copied — the replacement layer is
            # freshly initialized.
            setattr(
                model,
                name,
                BitLinear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                ),
            )
        else:
            # Recurse into container modules. Bug fix: forward name_skip so
            # a non-default skip name is honored below the top level too
            # (the original recursion silently reset it to 'lm_head').
            replace_linears_in_hf(module, name_skip)
| |
|
| |
|
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def final_quantization(model):
    """
    Recursively quantize every ``BitLinear`` layer of *model* in place:
    weights through ``weight_quant`` and, when a bias exists, the bias
    through ``activation_quant`` at the layer's ``input_bits``.

    NOTE(review): ``weight_quant`` and ``activation_quant`` are not imported
    in this file's visible header — presumably defined elsewhere in the
    package; verify they are in scope at call time.

    Args:
        model (nn.Module): Model whose BitLinear layers are quantized in place.

    Returns:
        None
    """
    for _child_name, child in model.named_children():
        if not isinstance(child, BitLinear):
            # Not a BitLinear leaf: descend into the submodule.
            final_quantization(child)
            continue
        # Quantize the weight tensor in place.
        child.weight.data = weight_quant(child.weight.data)
        if child.bias is not None:
            child.bias.data = activation_quant(child.bias.data, child.input_bits)
| |
|