| import torch |
| import torch.nn as nn |
|
|
class RecoveredBaselineModel(nn.Module):
    """MLP with a linear skip projection added after the first hidden layer.

    Computation performed by ``forward``::

        x --> fc1 --> ReLU --+
        x --> proj ----------+--(sum)--> fc2 --> ReLU --> Dropout --> out

    The submodule attribute names (``fc1``, ``proj``, ``fc2``, ``out``) are
    the prefixes of the ``state_dict()`` keys, so renaming any of them breaks
    loading weights saved from this class. NOTE(review): the names presumably
    mirror the architecture the pickled checkpoint was created with — confirm.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of both hidden layers (and of the skip projection).
        output_dim: Size of the last dimension of the output tensor.
        dropout: Drop probability used after the second hidden layer.
    """

    def __init__(self, input_dim=320, hidden_dim=1024, output_dim=1, dropout=0.2):
        super(RecoveredBaselineModel, self).__init__()
        # Registration order is significant: it fixes the ordering of
        # parameters() and state_dict(), so it is kept exactly as recovered.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.proj = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` (last dim ``input_dim``) to an output with last dim ``output_dim``."""
        # First hidden layer, then the (un-activated) skip projection is added.
        first = self.relu(self.fc1(x))
        merged = first + self.proj(x)
        # Second hidden layer; dropout is only active in training mode.
        second = self.drop(self.relu(self.fc2(merged)))
        return self.out(second)
|
|
| |
# Convert a checkpoint that pickles an entire ``nn.Module`` into a file that
# holds only its ``state_dict`` (a mapping of names to tensors).
#
# SECURITY: ``weights_only=False`` uses the full pickle machinery, which can
# execute arbitrary code embedded in the file. Only run this on a checkpoint
# whose provenance you trust — NOTE(review): "public" in the filename suggests
# it may have been downloaded; verify its source/hash before running.
# Unpickling a full module also needs the class it was saved from (presumably
# ``RecoveredBaselineModel`` defined above, under the same module path) to be
# importable at load time.
model = torch.load("baseline_public_v1.pth", map_location="cpu", weights_only=False)

# Fail with a clear message rather than an opaque AttributeError when the file
# holds something other than a module (e.g. it is already a plain state_dict).
if not isinstance(model, nn.Module):
    raise TypeError(
        "expected baseline_public_v1.pth to contain a pickled nn.Module, "
        f"got {type(model).__name__}"
    )

# Saving only tensors means the output can later be read back with
# torch.load(..., weights_only=True), avoiding arbitrary unpickling.
torch.save(model.state_dict(), "baseline_state_dict.pth")

print("Successfully extracted state_dict -> baseline_state_dict.pth")
|
|