import copy

import torch
import torch.nn as nn
def MultiwayWrapper(args, module, dim=1):
    """Wrap *module* in a two-branch ``MultiwayNetwork`` when ``args.multiway``
    is truthy; otherwise hand the module back untouched.

    ``dim`` is forwarded to ``MultiwayNetwork`` as the split dimension.
    """
    if not args.multiway:
        return module
    return MultiwayNetwork(module, dim=dim)
| |
|
| |
|
def set_split_position(position):
    """Build a callback for ``nn.Module.apply`` that stamps *position* onto
    every submodule exposing a ``split_position`` attribute (i.e. every
    ``MultiwayNetwork`` in the tree); other modules are left alone.
    """

    def _assign(m):
        if hasattr(m, "split_position"):
            m.split_position = position

    return _assign
| |
|
| |
|
class MultiwayNetwork(nn.Module):
    """Run two copies of a module over two slices of the input.

    Branch ``A`` is the wrapped module itself; branch ``B`` is a deep copy
    with freshly reset parameters (the wrapped module must therefore expose
    ``reset_parameters()``). ``split_position`` selects the routing:
    ``-1`` sends everything through ``A``, ``0`` sends everything through
    ``B``, and any other value splits the input along ``dim`` at that index,
    feeding the first part to ``A`` and the rest to ``B``.
    """

    def __init__(self, module, dim=1):
        super().__init__()
        self.dim = dim
        self.A = module
        self.B = copy.deepcopy(module)
        # The copy must not share (or keep) A's weights.
        self.B.reset_parameters()
        self.split_position = -1

    def forward(self, x, **kwargs):
        """Dispatch ``x`` to A, B, or both according to ``split_position``.

        Extra keyword arguments are forwarded unchanged to both branches.
        """
        split = self.split_position
        if split == -1:
            return self.A(x, **kwargs)
        if split == 0:
            return self.B(x, **kwargs)
        head, tail = torch.split(
            x,
            [split, x.size(self.dim) - split],
            dim=self.dim,
        )
        return torch.cat(
            [self.A(head, **kwargs), self.B(tail, **kwargs)], dim=self.dim
        )
| |
|
| |
|
class MutliwayEmbedding(MultiwayNetwork):
    """Multiway variant built from two pre-constructed modules.

    Unlike ``MultiwayNetwork``, no deep copy / parameter reset happens:
    the caller supplies both branch modules explicitly. (Class name typo
    is preserved — it is the public API.)
    """

    def __init__(self, modules, dim=1):
        # Skip MultiwayNetwork.__init__ on purpose (no deepcopy/reset);
        # go straight to nn.Module's initializer.
        super(MultiwayNetwork, self).__init__()
        assert len(modules) == 2
        self.dim = dim
        self.A = modules[0]
        self.B = modules[1]
        self.split_position = -1
| |
|