| import torch |
| from allennlp.modules.feedforward import FeedForward |
| from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper |
| from higher.patch import monkeypatch as make_functional |
|
|
|
|
class ConditionedParameter(torch.nn.Module):
    """Predict an additive update for a single 1-D or 2-D parameter.

    A two-layer weight-normalised MLP maps conditioning vectors to the pieces
    of the update ``max_scale * gate * (grad * a + b)``. The update has the
    shape of the ``parameter`` given at construction time.

    Args:
        parameter: tensor whose shape the update must match; only its
            ``shape`` is read.
        condition_dim: size of the last dimension of the conditioning input.
        hidden_dim: width of the MLP hidden layer.
        max_scale: constant multiplier applied to the whole update.

    Raises:
        RuntimeError: if ``parameter`` is neither 1-D nor 2-D.
    """

    def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
        super().__init__()
        self.parameter_shape = parameter.shape

        if len(self.parameter_shape) == 2:
            # Per condition: col_a, row_a, col_b, row_b and one gate logit.
            out_dim = 2 * (parameter.shape[0] + parameter.shape[1]) + 1
        elif len(self.parameter_shape) == 1:
            # Per condition: a, b and one gate logit.
            out_dim = 2 * parameter.shape[0] + 1
        else:
            raise RuntimeError(
                "ConditionedParameter supports only 1-D or 2-D parameters, "
                "got shape {}".format(tuple(self.parameter_shape))
            )

        self.conditioners = torch.nn.Sequential(
            torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
            torch.nn.Tanh(),
            torch.nn.utils.weight_norm(torch.nn.Linear(hidden_dim, out_dim)),
        )

        self.max_scale = max_scale

    def forward(self, inputs, grad):
        """Return the predicted update, shaped like the conditioned parameter.

        Args:
            inputs: conditioning vectors. The 2-D path transposes the split
                MLP outputs with ``.T``, so this is expected to be
                ``(batch, condition_dim)``.
            grad: gradient of the parameter (broadcastable to its shape), or
                ``None``. ``None`` is treated as a zero gradient, so only the
                ``b`` term contributes.

        Returns:
            ``max_scale * gate * (grad * a + b)``. Here ``gate`` is the
            batch-mean sigmoid of the gate logit, and ``a``/``b`` are summed
            over the batch of conditions.
        """
        n_rows = self.parameter_shape[0]

        if len(self.parameter_shape) == 2:
            n_cols = self.parameter_shape[1]
            col_a, row_a, col_b, row_b, gate_logit = self.conditioners(inputs).split(
                [n_cols, n_rows, n_cols, n_rows, 1], dim=-1
            )
            # (rows, batch) @ (batch, cols): sums the per-condition outer
            # products into a (rows, cols) matrix.
            a = row_a.softmax(-1).T @ col_a
            b = row_b.softmax(-1).T @ col_b
            # Deliberately no squeeze(). For a (rows, 1) parameter it would
            # make a/b (rows,) and broadcast against grad to (rows, rows).
        elif len(self.parameter_shape) == 1:
            a, b, gate_logit = self.conditioners(inputs).split(
                [n_rows, n_rows, 1], dim=-1
            )
            # Sum over the batch, as the 2-D path's matmul does. This keeps the
            # update (n,) for any batch size; for one condition it equals the
            # old squeeze().
            a = a.reshape(-1, n_rows).sum(0)
            b = b.reshape(-1, n_rows).sum(0)
        else:
            raise RuntimeError(
                "ConditionedParameter supports only 1-D or 2-D parameters, "
                "got shape {}".format(tuple(self.parameter_shape))
            )

        gate = torch.mean(gate_logit.sigmoid(), dim=0).squeeze()
        update = b if grad is None else grad * a + b
        return self.max_scale * gate * update
|
|
|
|
class LSTMConditioner(torch.nn.Module):
    """Encode a batch of token sequences into conditioning vectors.

    The pipeline has three stages:

    1. A token embedding.
    2. A single-layer bidirectional LSTM, reduced to one vector per sequence
       by AllenNLP's ``PytorchSeq2VecWrapper``.
    3. One ``FeedForward`` layer with ``tanh`` that produces ``output_dim``
       features.

    Args:
        vocab_dim: number of rows in the embedding table.
        embedding_dim: embedding width, which is also the LSTM input size.
        hidden_dim: LSTM hidden size per direction.
        output_dim: width of the returned conditioning vectors.
        embedding_init: optional tensor used as the initial embedding weight.
    """

    def __init__(
        self,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=256,
        output_dim=1024,
        embedding_init=None,
    ):
        super().__init__()
        # Index 0 is registered as the padding index.
        self.embedding = torch.nn.Embedding(
            vocab_dim, embedding_dim, padding_idx=0, _weight=embedding_init
        )

        recurrent = torch.nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.lstm = PytorchSeq2VecWrapper(recurrent)

        # The LSTM is bidirectional, so the projection consumes twice the
        # per-direction hidden size.
        encoded_dim = 2 * hidden_dim
        self.linear = FeedForward(
            input_dim=encoded_dim,
            num_layers=1,
            hidden_dims=[output_dim],
            activations=[torch.nn.Tanh()],
        )

    def forward(self, inputs, masks):
        """Return one ``output_dim``-wide vector per input sequence.

        Args:
            inputs: token ids. Presumably ``(batch, seq_len)``, since the LSTM
                is ``batch_first``; confirm with callers.
            masks: padding mask forwarded to ``PytorchSeq2VecWrapper``.
                Presumably ``(batch, seq_len)`` and true for real tokens;
                verify.
        """
        embedded = self.embedding(inputs)
        sequence_vectors = self.lstm(embedded, masks)
        return self.linear(sequence_vectors)
|
|
|
|
class OneShotLearner(torch.nn.Module):
    """Predict updates for selected parameters of ``model`` from token sequences.

    An ``LSTMConditioner`` encodes the input tokens into a conditioning
    vector. One ``ConditionedParameter`` per selected parameter then turns
    that vector, plus optionally the parameter's gradient, into an update.

    Args:
        model: module whose ``named_parameters()`` are scanned. Only the
            names and tensors are read; the model is not modified.
        vocab_dim: vocabulary size of the token encoder.
        embedding_dim: embedding width of the token encoder.
        hidden_dim: hidden width, shared by the token encoder and every
            per-parameter conditioner.
        condition_dim: size of the conditioning vector.
        include_set: container of parameter names, as reported by
            ``model.named_parameters()``, to build conditioners for.
            ``None`` (the default) selects nothing.
        max_scale: forwarded to every ``ConditionedParameter``.
        embedding_init: optional initial embedding weights for the encoder.

    Raises:
        ValueError: if two selected parameter names flatten to the same
            conditioner key.
    """

    def __init__(
        self,
        model,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=128,
        condition_dim=1024,
        include_set=None,
        max_scale=1e-3,
        embedding_init=None,
    ):
        super().__init__()

        # `None` replaces the shared mutable `{}` default. An empty tuple keeps
        # the same "select nothing" result and the same `in` semantics.
        if include_set is None:
            include_set = ()

        # Scan the model once, so names and tensors stay paired in model order.
        selected = [(n, p) for n, p in model.named_parameters() if n in include_set]

        # PyTorch module names may not contain ".", so dotted names are flattened.
        self.param2conditioner_map = {
            n: "{}_conditioner".format(n).replace(".", "_") for n, _ in selected
        }

        # Distinct names such as "a.b_c" and "a_b.c" flatten to the same key.
        # Without this check the ModuleDict would silently keep only one of them.
        owner_of_key = {}
        for name, key in self.param2conditioner_map.items():
            if key in owner_of_key:
                raise ValueError(
                    "parameters {!r} and {!r} both map to conditioner key {!r}".format(
                        owner_of_key[key], name, key
                    )
                )
            owner_of_key[key] = name

        self.conditioners = torch.nn.ModuleDict(
            {
                self.param2conditioner_map[n]: ConditionedParameter(
                    p,
                    condition_dim,
                    hidden_dim,
                    max_scale=max_scale,
                )
                for n, p in selected
            }
        )

        self.condition = LSTMConditioner(
            vocab_dim,
            embedding_dim,
            hidden_dim,
            condition_dim,
            embedding_init=embedding_init,
        )

    def forward(self, inputs, masks, grads=None):
        """Return a dict mapping each selected parameter name to its update.

        Args:
            inputs: token ids passed to the ``LSTMConditioner``.
            masks: mask passed to the ``LSTMConditioner`` together with ``inputs``.
            grads: mapping from parameter name to gradient. When it is falsy
                (``None`` or empty), ``grad=None`` is passed to every
                conditioner. When it is truthy, it must contain every
                selected name.
        """
        condition = self.condition(inputs, masks)
        return {
            name: self.conditioners[key](
                condition,
                grad=grads[name] if grads else None,
            )
            for name, key in self.param2conditioner_map.items()
        }
|
|