File size: 2,016 Bytes
8bbb872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
MLP Hyperparameter Sweep (Optuna).
Run: python -m models.mlp.sweep
"""
import os
import sys

import optuna
import torch
import torch.nn as nn
import torch.optim as optim

# Prepend the repository root (the grandparent of this file's directory) to
# sys.path so the project-local ``data_preparation`` and ``models`` imports
# below resolve even when the file is not launched via ``python -m``.
REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

from data_preparation.prepare_dataset import get_dataloaders
from models.mlp.train import BaseModel, set_seed

# Seed passed to set_seed() and to the train/val/test split in every trial.
SEED: int = 42
# Number of Optuna trials run by main().
N_TRIALS: int = 20
# Training epochs per trial; the best validation accuracy over these is scored.
EPOCHS_PER_TRIAL: int = 15


def objective(trial):
    """Train one sampled MLP configuration and return its objective value.

    Samples a learning rate (log-uniform in [1e-4, 1e-2]) and a batch size
    (16, 32 or 64). It then trains ``BaseModel`` on the ``face_orientation``
    split for ``EPOCHS_PER_TRIAL`` epochs and tracks the best validation
    accuracy seen over those epochs. Each epoch's validation accuracy is also
    recorded on the trial as an intermediate value.

    Args:
        trial: The Optuna trial object supplied by ``study.optimize``.

    Returns:
        ``1.0 - best_val_acc``, so a ``direction="minimize"`` study prefers
        higher accuracy. NOTE(review): assumes ``validation_step`` returns
        accuracy as a fraction in [0, 1] (``main`` multiplies by 100 to get a
        percentage) -- confirm against ``models.mlp.train``.
    """
    # Re-seed at the start of every trial so each configuration starts from
    # the same seeded state (set_seed is the project helper from
    # models.mlp.train).
    set_seed(SEED)

    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    # The 3rd and 6th return values are not needed here (presumably the test
    # loader and extra metadata -- verify in data_preparation.prepare_dataset).
    train_loader, val_loader, _, num_features, num_classes, _ = get_dataloaders(
        model_name="face_orientation",
        batch_size=batch_size,
        split_ratios=(0.7, 0.15, 0.15),
        seed=SEED,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BaseModel(num_features, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_acc = 0.0
    for epoch in range(1, EPOCHS_PER_TRIAL + 1):
        model.training_step(train_loader, optimizer, criterion, device)
        _, val_acc = model.validation_step(val_loader, criterion, device)
        # Store the per-epoch learning curve on the trial. This alone never
        # prunes; pruning would additionally require checking
        # trial.should_prune() and raising optuna.TrialPruned.
        trial.report(val_acc, step=epoch)
        # max() keeps the earlier value on ties, matching a strict '>' update.
        best_val_acc = max(best_val_acc, val_acc)

    return 1.0 - best_val_acc  # minimize (1 - accuracy)


def main():
    """Run the Optuna sweep and print the best completed trials.

    Creates a study that minimizes ``objective`` (1 - validation accuracy) and
    runs ``N_TRIALS`` trials. It then prints up to five best trials, with
    their objective values converted back to validation-accuracy percentages.
    No ``storage`` is passed, so the study is not persisted.
    """
    study = optuna.create_study(direction="minimize", study_name="mlp_sweep")
    print(f"[SWEEP] MLP Optuna sweep: {N_TRIALS} trials, {EPOCHS_PER_TRIAL} epochs each")
    study.optimize(objective, n_trials=N_TRIALS)

    print("\n[SWEEP] Top-5 trials by validation accuracy")
    # Rank only COMPLETE trials. Trials in other states (e.g. FAIL) carry no
    # objective value (t.value is None). If one of them landed in the top
    # five, ``1.0 - t.value`` below would raise TypeError.
    completed = study.get_trials(
        deep=False, states=(optuna.trial.TrialState.COMPLETE,)
    )
    if not completed:
        print("  No trials completed.")
        return

    best = sorted(completed, key=lambda t: t.value)[:5]
    for rank, t in enumerate(best, 1):
        acc = (1.0 - t.value) * 100
        print(f"  #{rank}  Val Acc: {acc:.2f}%  params={t.params}")


# Script entry point: runs the sweep when executed directly or via
# ``python -m models.mlp.sweep``, but not when this module is imported.
if __name__ == "__main__":
    main()