MagicText-1.2_fine-tuning / model_class.py
BIGAI-models's picture
Create model_class.py
6bb1778 verified
raw
history blame contribute delete
443 Bytes
import torch
from torch import nn


class AI(nn.Module):
    """Embedding -> LSTM -> Linear model over integer token ids.

    The forward pass embeds the ids, runs them through an LSTM, and projects
    the LSTM output at the *last* time step to ``vocab_size`` unnormalized
    scores (no softmax is applied here). Presumably used for next-token
    prediction — NOTE(review): confirm against the training/inference code.

    All constructor arguments are keyword-only and default to the values the
    model was originally hard-coded with, so ``AI()`` builds the same
    architecture (and accepts the same ``state_dict``) as before.
    """

    def __init__(
        self,
        *,
        vocab_size: int = 230,
        embed_dim: int = 256,
        hidden_size: int = 512,
        num_layers: int = 1,
        dropout: float = 0.3,
    ) -> None:
        """Build the model.

        Args:
            vocab_size: Number of token ids; sizes both the embedding table
                and the output projection.
            embed_dim: Width of each token embedding (LSTM input size).
            hidden_size: LSTM hidden-state width.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout between stacked LSTM layers. Only takes effect
                when ``num_layers > 1`` (see comment below).
        """
        super().__init__()
        # Attribute names are kept verbatim so previously saved state_dicts
        # (keys "embai.*", "lsai.*", "linai.*") still load.
        self.embai = nn.Embedding(vocab_size, embed_dim)
        # nn.LSTM applies dropout only *between* stacked layers. With a single
        # layer it is a no-op and PyTorch warns about it, so only forward the
        # value when it can actually take effect.
        self.lsai = nn.LSTM(
            embed_dim,
            hidden_size,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            num_layers=num_layers,
        )
        self.linai = nn.Linear(hidden_size, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score the sequence(s) in ``x`` using the final LSTM time step.

        Args:
            x: Integer token ids in ``[0, vocab_size)``, either a single
                sequence of shape ``(seq_len,)`` or a batch of shape
                ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, vocab_size)``. An unbatched input is
            treated as a batch of one, giving shape ``(1, vocab_size)``.
        """
        x = self.embai(x)
        # An unbatched id sequence embeds to (seq_len, embed_dim); add a batch
        # axis so the batch_first LSTM and the [:, -1, :] slice below work.
        if x.dim() == 2:
            x = x.unsqueeze(0)
        x, _ = self.lsai(x)
        # Keep only the output of the last time step for each sequence.
        x = x[:, -1, :]
        return self.linai(x)