class AI(nn.Module):
    """Embedding -> LSTM -> linear model over sequences of integer ids.

    Each id is embedded, the embedded sequence is run through a
    batch-first LSTM, and the LSTM output at the last time step is projected
    by a linear layer to ``vocab_size`` values (no softmax is applied here).

    Args:
        vocab_size: Number of embedding rows and size of the output vector.
        embed_dim: Size of each embedding vector (LSTM input size).
        hidden_dim: LSTM hidden size (linear layer input size).
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability between stacked LSTM layers. Ignored
            when ``num_layers`` is 1, because ``nn.LSTM`` applies dropout
            only after every layer except the last.

    The defaults reproduce the original fixed architecture, and the submodule
    names (``embai``, ``lsai``, ``linai``) are unchanged so existing
    ``state_dict`` checkpoints keep loading.
    """

    def __init__(
        self,
        vocab_size: int = 230,
        embed_dim: int = 256,
        hidden_dim: int = 512,
        num_layers: int = 1,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()

        # A non-zero dropout on a single-layer LSTM has no effect and makes
        # PyTorch emit a UserWarning on every construction, so only forward
        # the value when there is a layer boundary for it to act on.
        lstm_dropout = dropout if num_layers > 1 else 0.0

        self.embai = nn.Embedding(vocab_size, embed_dim)
        self.lsai = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        )
        self.linai = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return the projected LSTM output for the last time step.

        Args:
            x: Integer id tensor, either ``(seq_len,)`` for a single
                sequence or ``(batch, seq_len)`` for a batch.

        Returns:
            Tensor of shape ``(batch, vocab_size)``; a single unbatched
            sequence yields ``(1, vocab_size)``.
        """
        embedded = self.embai(x)

        # An unbatched sequence embeds to (seq_len, embed_dim); add the batch
        # axis the batch-first LSTM and the indexing below rely on.
        if embedded.dim() == 2:
            embedded = embedded.unsqueeze(0)

        outputs, _ = self.lsai(embedded)

        # Keep only the final time step of every sequence.
        last_step = outputs[:, -1, :]
        return self.linai(last_step)