from torch import nn

from .quantization import BitLinear, activation_quant, weight_quant
|
|
|
|
def replace_linears_in_hf(
    model, name_skip='lm_head'
):
    """
    Recursively replace every ``nn.Linear`` in ``model`` with a ``BitLinear``
    of the same shape, mutating the model in place.

    Args:
        model (nn.Module): The model to modify in place.
        name_skip (str): Child attribute name to leave untouched. Defaults to
            ``'lm_head'`` (the output projection in HF causal-LM models,
            which is typically kept in full precision).

    Returns:
        None
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name != name_skip:
            setattr(
                model,
                name,
                BitLinear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                ),
            )
        elif not isinstance(module, nn.Linear):
            # Descend into container modules. BUG FIX: propagate name_skip —
            # the original recursive call dropped it, so a skipped layer
            # nested below the top level would have been replaced anyway.
            replace_linears_in_hf(module, name_skip=name_skip)
|
|
|
|
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def final_quantization(model):
    """
    Recursively bake quantization into every ``BitLinear`` layer of ``model``,
    in place: weights are passed through ``weight_quant``, and biases (when
    present) through ``activation_quant`` at the layer's ``input_bits``
    precision. Non-``BitLinear`` submodules are traversed, not modified.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, BitLinear):
            # Not a BitLinear leaf — keep walking down the module tree.
            final_quantization(child)
            continue
        child.weight.data = weight_quant(child.weight.data)
        if child.bias is not None:
            child.bias.data = activation_quant(child.bias.data, child.input_bits)
|
|