| from torch import Tensor | |
| from torch import nn | |
| from typing import Dict | |
| import torch.nn.functional as F | |
class Normalize(nn.Module):
    """Normalize sentence embeddings to unit length.

    Replaces the ``'sentence_embedding'`` entry of the feature dict with
    its row-wise L2-normalized version (each embedding divided by its
    Euclidean norm along dim 1). The module is stateless: it has no
    parameters, so ``save`` is a no-op and ``load`` simply constructs a
    fresh instance.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """L2-normalize ``features['sentence_embedding']`` in place and return the dict.

        Assumes the embedding tensor is at least 2-D with rows along dim 0
        (normalization is applied over dim 1).
        """
        features['sentence_embedding'] = F.normalize(
            features['sentence_embedding'], p=2, dim=1
        )
        return features

    def save(self, output_path: str) -> None:
        """No-op: this module has no weights or configuration to persist."""
        pass

    @staticmethod
    def load(input_path: str) -> 'Normalize':
        # Stateless module, so the saved path carries no information;
        # declared @staticmethod so both Normalize.load(p) and
        # instance.load(p) work (the original raised TypeError on the latter).
        return Normalize()