import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from typing import Optional

from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_mapper import MapperConfig

class FeedForward(nn.Module):
    """Gated feed-forward block (SwiGLU-style): fc1 doubles the width so its
    output can be split into a SiLU gate and a value branch."""

    def __init__(self, d_in, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=-1)  # gate half and value half, each of width d_out
        x = self.fc2(F.silu(x1) * x2)
        return x

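# A quick illustration of the gating (a sketch; the dimensions are arbitrary):
# FeedForward(4, 8) maps 4 -> 16 in fc1, chunks that into two 8-dim halves,
# gates one with SiLU, and projects back through fc2, so
# FeedForward(4, 8)(torch.randn(2, 4)) has shape (2, 8).
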
class FeedForwardLayer(nn.Module):
    """FeedForward wrapped with input dropout, a skip connection (a linear
    projection when d_in != d_out), and an optional LayerNorm."""

    def __init__(self, d_in, d_out, dropout=0.1, layer_norm_eps=None):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.LayerNorm = nn.LayerNorm(d_out, eps=layer_norm_eps) if layer_norm_eps else None

    def forward(self, x):
        x = self.dropout(x)  # dropout is applied to the input of both branches
        x = self.ff(x) + self.skip(x)
        if self.LayerNorm is not None:
            x = self.LayerNorm(x)
        return x

class Mapper(nn.Module):
    """Stack of FeedForwardLayers mapping a (batch, d_in) embedding to
    (batch, n_out, d_out), i.e. n_out output vectors per input."""

    def __init__(self, d_in, d_hidden, d_out, n_out, n_layers, dropout=0.1, layer_norm_eps=None):
        super().__init__()
        self.n_out = n_out
        # Input projection without dropout, then n_layers hidden blocks.
        layers = [FeedForwardLayer(d_in, d_hidden, 0.0, layer_norm_eps)]
        layers += [FeedForwardLayer(d_hidden, d_hidden, dropout, layer_norm_eps)
                   for _ in range(n_layers)]
        self.layers = nn.Sequential(*layers)

        # Output head without dropout or LayerNorm; emits all n_out vectors at once.
        self.output_layer = FeedForwardLayer(d_hidden, d_out * n_out, 0.0, None)

    def forward(self, x):
        x = self.layers(x)
        x = self.output_layer(x)
        # Split the flat (batch, d_out * n_out) tensor into n_out chunks and
        # stack them along a new dim 1: (batch, n_out, d_out).
        x = torch.stack(torch.chunk(x, self.n_out, -1), 1)
        return x

@dataclass
class MapperModelOutput(ModelOutput):
    mapper_out: Optional[torch.FloatTensor] = None

class MapperModel(PreTrainedModel):
    """Thin PreTrainedModel wrapper around Mapper, configured by MapperConfig."""

    config_class = MapperConfig

    def __init__(self, config):
        super().__init__(config)
        self.mapper = Mapper(config.d_in, config.d_hidden, config.d_out, config.n_out,
                             config.n_layers, config.dropout, config.layer_norm_eps)

    def forward(self, embedding, return_dict=True):
        mapper_out = self.mapper(embedding)

        if not return_dict:
            return (mapper_out,)

        return MapperModelOutput(mapper_out=mapper_out)
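
# Minimal shape sanity check (a sketch, not part of the module's public API).
# Mapper can be exercised directly; MapperModel additionally needs a
# MapperConfig, whose constructor lives in configuration_mapper.py and is
# assumed (not shown here) to accept the same hyperparameters as kwargs.
if __name__ == "__main__":
    mapper = Mapper(d_in=768, d_hidden=1024, d_out=512, n_out=4,
                    n_layers=3, dropout=0.1, layer_norm_eps=1e-5)
    x = torch.randn(2, 768)  # (batch, d_in)
    y = mapper(x)
    print(y.shape)  # torch.Size([2, 4, 512]) -> (batch, n_out, d_out)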