Spaces:
Sleeping
Sleeping
| """ | |
| MLP Hyperparameter Sweep (Optuna). | |
| Run: python -m models.mlp.sweep | |
| """ | |
| import os | |
| import sys | |
| import optuna | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
# Put the repository root (two directories above this file) at the front of
# sys.path so the absolute `data_preparation` / `models` imports resolve when
# this file is executed directly or via `python -m`.
REPO_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if REPO_ROOT not in sys.path:
    sys.path[:0] = [REPO_ROOT]
| from data_preparation.prepare_dataset import get_dataloaders | |
| from models.mlp.train import BaseModel, set_seed | |
# Sweep configuration.
SEED: int = 42  # passed to set_seed() and get_dataloaders() in every trial
N_TRIALS: int = 20  # number of Optuna trials to run
EPOCHS_PER_TRIAL: int = 15  # training epochs per sampled configuration
def objective(trial):
    """Train one sampled MLP configuration and return its validation error.

    Samples a learning rate (log-uniform in [1e-4, 1e-2]) and a batch size
    from {16, 32, 64}. It then trains ``BaseModel`` with Adam and
    cross-entropy loss for ``EPOCHS_PER_TRIAL`` epochs and keeps the best
    validation accuracy observed after any epoch.

    Args:
        trial: The Optuna trial object supplied by ``study.optimize``.

    Returns:
        ``1.0 - best_val_acc`` as a float, so the study can minimise it.
    """
    # Re-seed at the start of every trial; presumably this makes each
    # configuration start from the same RNG state (see set_seed in train.py).
    set_seed(SEED)
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    # Loaders are rebuilt per trial because batch_size is a searched
    # hyperparameter. The 3rd and 6th return values (presumably the test
    # loader and extra metadata) are not needed for the sweep.
    train_loader, val_loader, _, num_features, num_classes, _ = get_dataloaders(
        model_name="face_orientation",
        batch_size=batch_size,
        split_ratios=(0.7, 0.15, 0.15),
        seed=SEED,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BaseModel(num_features, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_acc = 0.0
    for _ in range(EPOCHS_PER_TRIAL):
        model.training_step(train_loader, optimizer, criterion, device)
        # validation_step yields (loss, accuracy); only accuracy drives the search.
        _, val_acc = model.validation_step(val_loader, criterion, device)
        if val_acc > best_val_acc:
            best_val_acc = val_acc

    # float() normalises a possible 0-d tensor / NumPy scalar accuracy into a
    # plain Python float; it is a no-op for ordinary floats.
    return 1.0 - float(best_val_acc)  # minimize (1 - accuracy)
def main():
    """Run the MLP Optuna sweep and print the best completed trials.

    Creates a study that minimises ``1 - validation accuracy`` (the value
    returned by ``objective``) and runs ``N_TRIALS`` trials. It then prints
    up to five COMPLETE trials, best first, with accuracy shown as a
    percentage.
    """
    study = optuna.create_study(direction="minimize", study_name="mlp_sweep")
    print(f"[SWEEP] MLP Optuna sweep: {N_TRIALS} trials, {EPOCHS_PER_TRIAL} epochs each")
    study.optimize(objective, n_trials=N_TRIALS)

    print("\n[SWEEP] Top-5 trials by validation accuracy")
    # Only COMPLETE trials are guaranteed to carry an objective value. Failed
    # or unfinished trials have value None, which would crash the
    # `1.0 - t.value` conversion below, so exclude them up front.
    completed = study.get_trials(
        deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)
    )
    if not completed:
        print("  (no completed trials)")
        return

    best = sorted(completed, key=lambda t: t.value)[:5]
    for i, t in enumerate(best, 1):
        # Undo the (1 - accuracy) transform from objective(). The *100 assumes
        # accuracy is a fraction in [0, 1] (TODO confirm with validation_step).
        acc = (1.0 - t.value) * 100
        print(f" #{i} Val Acc: {acc:.2f}% params={t.params}")


if __name__ == "__main__":
    main()