| | |
| |
|
| | import math |
| | from functools import partial |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.nn import init |
| |
|
| |
|
class StructuredLinear(nn.Module):
    """Base class for linear layers whose weight is stored in a structured form.

    The base class owns the feature sizes and the optional bias, and provides
    bias initialisation, padding/cropping helpers and the ``forward`` wrapper.
    Subclasses must implement ``set_weights_from_dense_init``,
    ``forward_matmul`` and ``saving``. ``convert_to_dense_weight`` also reads
    ``self.weight``, which this base class does not create.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None):
        """Record the layer sizes and create the (zero-initialised) bias.

        Subclasses should call reset_parameters

        Args:
            in_features: size of the last dimension of the input.
            out_features: size of the last dimension of the output.
            bias: when true, a learnable bias of shape ``(out_features,)`` is
                created; otherwise ``self.bias`` is registered as ``None``.
            device: forwarded to the bias tensor constructor.
            dtype: forwarded to the bias tensor constructor.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Keep any "extended" size that is already visible on the object and
        # only fall back to the plain sizes when none is present.
        if not hasattr(self, 'in_features_extended'):
            self.in_features_extended = in_features
        if not hasattr(self, 'out_features_extended'):
            self.out_features_extended = out_features
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

    def reset_parameters(self) -> None:
        """Re-initialise the weights and then the bias.

        The weights are set through ``set_weights_from_dense_init`` using
        ``init.kaiming_uniform_`` with ``a=sqrt(5)``.
        """
        dense_init_fn_ = partial(init.kaiming_uniform_, a=math.sqrt(5))
        self.set_weights_from_dense_init(dense_init_fn_=dense_init_fn_)
        self.reset_parameters_bias()

    def set_weights_from_dense_init(self, dense_init_fn_):
        """Initialise the structured weights from a dense init function.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def reset_parameters_bias(self) -> None:
        """Draw the bias uniformly from ``[-bound, bound]``; no-op without bias.

        ``bound`` is ``1 / sqrt(n)`` where ``n`` is the bias length, or ``0``
        when the bias is empty.
        """
        if self.bias is None:
            return
        # NOTE(review): the bias has length out_features, so this "fan_in" is
        # out_features, whereas nn.Linear derives its bound from in_features.
        # Behaviour is kept as-is; confirm whether this is intentional.
        fan_in = self.bias.shape[-1]
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.bias, -bound, bound)

    @property
    def saving(self):
        """Subclass-defined property; not implemented in the base class.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def convert_to_dense_weight(self) -> torch.Tensor:
        """Return the dense matrix obtained by pushing an identity through the layer.

        An ``in_features x in_features`` identity, created on the device and
        dtype of ``self.weight``, is passed to ``forward_matmul`` and the
        result is transposed.
        """
        factory_kwargs = {'device': self.weight.device, 'dtype': self.weight.dtype}
        identity = torch.eye(self.in_features, **factory_kwargs)
        return self.forward_matmul(identity).T

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the last dimension of ``x`` up to ``in_features_extended``.

        Inputs whose last dimension is already at least that wide are
        returned unchanged.
        """
        missing = self.in_features_extended - x.shape[-1]
        if missing > 0:
            # F.pad with a 2-tuple pads only the last dimension: (left, right).
            x = F.pad(x, (0, missing))
        return x

    def postprocess(self, output: torch.Tensor) -> torch.Tensor:
        """Crop the last dimension of ``output`` down to ``out_features``.

        Outputs that are not wider than ``out_features`` are returned
        unchanged.
        """
        if output.shape[-1] > self.out_features:
            output = output[..., :self.out_features]
        return output

    def forward_matmul(self, x):
        """Apply the structured weight to ``x`` (without the bias).

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def forward(self, x):
        """Return ``forward_matmul(x)`` plus the bias, if there is one.

        ``preprocess``/``postprocess`` are not called here.
        """
        output = self.forward_matmul(x)
        if self.bias is None:
            return output
        # Cast the bias to the dtype of the matmul result before adding.
        return output + self.bias.to(dtype=output.dtype)