| import torch |
| import torch.nn as nn |
| from typing import Union, List, Tuple |
|
|
| |
|
|
class LSTM(nn.Module):
    """Stacked LSTM with a linear head that emits one value per input sequence.

    The input sequence is run through ``nn.LSTM`` (``batch_first=True``). The
    output at the last time step goes through dropout and is then projected to
    a single value by a bias-free linear layer.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden and cell state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability. It is used between stacked LSTM layers
            (only when ``num_layers > 1``) and on the last-step output before
            the projection.
        lookback: Accepted for interface compatibility but not used by this
            class. NOTE(review): presumably the expected input sequence length;
            confirm with callers.
    """

    def __init__(
        self,
        input_size: int = 8,
        hidden_size: int = 40,
        num_layers: int = 2,
        dropout: float = 0.1,
        lookback: int = 8,
    ):
        super().__init__()

        # Kept so init_h_c_ can build correctly shaped initial states.
        self.hidden_size, self.num_layers = hidden_size, num_layers

        # nn.LSTM applies `dropout` only between stacked layers. With a single
        # layer it has no effect and PyTorch emits a UserWarning, so pass 0.0
        # in that case; the model's output is unchanged.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=False,
            proj_size=0,
            device=None,
        )

        # Maps the last-step LSTM output to one scalar per sample (no bias).
        self.proj = nn.Linear(in_features=hidden_size, out_features=1, bias=False)

        # Regularises the last-step output before the projection.
        self.dropout = nn.Dropout(p=dropout)

    def init_h_c_(self, B: int, device, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return zero-filled initial hidden and cell states.

        Args:
            B: Batch size.
            device: Device on which to allocate the states.
            dtype: dtype of the states.

        Returns:
            ``(h, c)``. Each has shape ``(num_layers, B, hidden_size)`` and is
            filled with zeros.
        """
        shape = (self.num_layers, B, self.hidden_size)
        h = torch.zeros(shape, dtype=dtype, device=device)
        c = torch.zeros(shape, dtype=dtype, device=device)
        return h, c

    def forward(self, x: torch.Tensor, fut_time=None) -> torch.Tensor:
        """Predict one value per sequence from the last LSTM time step.

        Args:
            x: Batched input of shape ``(batch, seq_len, input_size)``. The LSTM
                is built with ``batch_first=True``, and the last time step is
                read via ``out[:, -1, :]``.
            fut_time: Not used by this model. It is kept so existing callers
                that pass it keep working. NOTE(review): purpose unknown from
                here; confirm with callers.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # States are created on x's device and dtype so mixed-device or
        # mixed-precision inputs work.
        h, c = self.init_h_c_(x.shape[0], x.device, x.dtype)
        out, _ = self.lstm(x, (h, c))
        return self.proj(self.dropout(out[:, -1, :]))