| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
class MLP(nn.Module):
    """SwiGLU-style feed-forward block.

    Projects the input up to ``hidden_dim`` through two parallel linear
    maps, gates one branch with SiLU, and projects the elementwise
    product back down to ``dim``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        """
        Initializes the multilayer perceptron (MLP) module.

        Args:
            dim: The input and output dimensionality.
            hidden_dim: The dimensionality of the hidden layer.
        """
        super().__init__()
        # Layers are created in the order w1, w2, w3 so parameter
        # initialization consumes the RNG in the same sequence as before.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the MLP module.

        Args:
            x: The input tensor of shape (batch_size, dim).

        Returns:
            The output tensor of shape (batch_size, dim).
        """
        gate = F.silu(self.w1(x))  # SiLU-activated gating branch
        value = self.w3(x)         # ungated linear branch
        return self.w2(gate * value)
| |
|