# IntegrationTest / models/mlp/train.py
# Uploaded by Yingtao-Zheng: "Upload partially updated files" (commit 8bbb872)
import json
import os, sys
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score, roc_auc_score
from data_preparation.prepare_dataset import get_dataloaders
# ClearML experiment tracking is off by default. Opt in without editing this
# file by exporting USE_CLEARML=1 (also accepts "true"/"yes"/"on",
# case-insensitive).
USE_CLEARML = os.environ.get("USE_CLEARML", "").strip().lower() in {"1", "true", "yes", "on"}

# Two directories above this file; all output paths below hang off it.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))

# Training configuration read by main(); also registered with ClearML when enabled.
CFG = {
    # Passed to get_dataloaders and used to name the JSON training log.
    "model_name": "face_orientation",
    "epochs": 30,
    "batch_size": 32,
    # Adam learning rate.
    "lr": 1e-3,
    # Seeds Python/NumPy/PyTorch and is forwarded to get_dataloaders.
    "seed": 42,
    # Forwarded to get_dataloaders; presumably (train, val, test) fractions -- confirm there.
    "split_ratios": (0.7, 0.15, 0.15),
    # Where the best-validation checkpoint (mlp_best.pt) is written.
    "checkpoints_dir": os.path.join(_PROJECT_ROOT, "checkpoints"),
    # Where <model_name>_training_log.json is written.
    "logs_dir": os.path.join(_PROJECT_ROOT, "evaluation", "logs"),
}
# ==== ClearML (opt-in) =============================================
# `task` is the module-level handle main() checks before reporting scalars;
# it is None whenever tracking is disabled.
if USE_CLEARML:
    # Imported only here so the clearml package is needed only when tracking is on.
    from clearml import Task

    task = Task.init(
        project_name="Focus Guard",
        task_name="MLP Model Training",
        tags=["training", "mlp_model"],
    )
    # Register the training configuration dict with the task.
    task.connect(CFG)
else:
    task = None
# ==== Reproducibility & model =====================================
def set_seed(seed: int):
    """Seed every RNG the training run touches.

    Covers Python's ``random``, NumPy's global RNG and PyTorch's CPU generator.
    When CUDA is available, all CUDA devices are seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
class BaseModel(nn.Module):
    """Fully connected classifier ``num_features -> 64 -> 32 -> num_classes``.

    ``forward`` returns raw logits (no softmax applied). The class also carries
    the per-epoch training, validation and test loops used by ``main``.

    Args:
        num_features: Width of each input feature vector.
        num_classes: Number of output logits (one per class).
    """

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(num_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x):
        """Return class logits for ``x``; its last dimension must be ``num_features``."""
        return self.network(x)

    @staticmethod
    def _averages(total_loss, correct, total):
        """Convert running sums into ``(mean_loss, accuracy)``.

        Raises:
            ValueError: if ``total`` is 0, i.e. the loader yielded no samples.
        """
        if total == 0:
            raise ValueError("loader yielded no samples; cannot compute loss/accuracy")
        return total_loss / total, correct / total

    def training_step(self, loader, optimizer, criterion, device):
        """Run one optimisation pass over ``loader``.

        Args:
            loader: Iterable of ``(features, labels)`` tensor batches.
            optimizer: Optimizer over this model's parameters.
            criterion: Loss callable ``criterion(logits, labels)``.
            device: Device the batches are moved to before the forward pass.

        Returns:
            ``(mean_loss, accuracy)`` over every sample seen in this pass.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        self.train()
        total_loss = 0.0
        correct = 0
        total = 0
        for features, labels in loader:
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = self(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_size = features.size(0)
            # Weight by batch size so a smaller final batch does not skew the
            # mean (assumes criterion averages over the batch, the
            # CrossEntropyLoss default -- confirm if another criterion is used).
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
        return self._averages(total_loss, correct, total)

    @torch.no_grad()
    def _evaluate(self, loader, criterion, device, collect):
        """Shared no-grad evaluation loop for ``validation_step`` and ``test_step``.

        Returns:
            ``(mean_loss, accuracy, prob_chunks, pred_chunks, label_chunks)``.
            The three chunk lists hold one NumPy array per batch when
            ``collect`` is true and are empty otherwise.
        """
        self.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        prob_chunks, pred_chunks, label_chunks = [], [], []
        for features, labels in loader:
            features, labels = features.to(device), labels.to(device)
            outputs = self(features)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)
            batch_size = features.size(0)
            total_loss += loss.item() * batch_size
            correct += (preds == labels).sum().item()
            total += batch_size
            if collect:
                prob_chunks.append(torch.softmax(outputs, dim=1).cpu().numpy())
                pred_chunks.append(preds.cpu().numpy())
                label_chunks.append(labels.cpu().numpy())
        mean_loss, accuracy = self._averages(total_loss, correct, total)
        return mean_loss, accuracy, prob_chunks, pred_chunks, label_chunks

    def validation_step(self, loader, criterion, device):
        """Evaluate on ``loader`` in eval mode without gradient tracking.

        Returns:
            ``(mean_loss, accuracy)`` over every sample in ``loader``.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        mean_loss, accuracy, _, _, _ = self._evaluate(loader, criterion, device, collect=False)
        return mean_loss, accuracy

    def test_step(self, loader, criterion, device):
        """Evaluate on ``loader`` and also return per-sample outputs for metrics.

        Returns:
            ``(mean_loss, accuracy, probs, preds, labels)``:

            * ``probs``: softmax of the logits, one row per sample.
            * ``preds``: argmax predictions, in loader order.
            * ``labels``: the targets, in loader order.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        mean_loss, accuracy, prob_chunks, pred_chunks, label_chunks = self._evaluate(
            loader, criterion, device, collect=True
        )
        # _evaluate raised on an empty loader, so every chunk list is non-empty here.
        return (
            mean_loss,
            accuracy,
            np.concatenate(prob_chunks),
            np.concatenate(pred_chunks),
            np.concatenate(label_chunks),
        )
def _report_epoch_to_clearml(epoch, train_loss, train_acc, val_loss, val_acc, lr):
    """Send one epoch's scalars to the module-level ClearML ``task``; no-op when it is None."""
    if task is None:
        return
    logger = task.logger
    logger.report_scalar("Loss", "Train", float(train_loss), iteration=epoch)
    logger.report_scalar("Accuracy", "Train", float(train_acc), iteration=epoch)
    logger.report_scalar("Loss", "Val", float(val_loss), iteration=epoch)
    logger.report_scalar("Accuracy", "Val", float(val_acc), iteration=epoch)
    logger.report_scalar("Learning Rate", "LR", float(lr), iteration=epoch)
    logger.flush()


def _test_roc_auc(labels, probs, num_classes):
    """Return the test-split ROC-AUC, or ``None`` when it cannot be computed.

    Uses weighted one-vs-rest AUC for more than two classes, and the
    positive-class column otherwise. ``roc_auc_score`` raises ``ValueError``
    when the labels do not contain every class (e.g. a small or imbalanced test
    split). That should not discard an otherwise finished run, so the error is
    reported and ``None`` is returned.
    """
    try:
        if num_classes > 2:
            return roc_auc_score(labels, probs, multi_class="ovr", average="weighted")
        return roc_auc_score(labels, probs[:, 1])
    except ValueError as err:
        print(f"[TEST] ROC-AUC undefined: {err}")
        return None


def main():
    """Train the MLP, keep the best-validation checkpoint, test it and write a JSON log.

    All hyper-parameters and output locations come from ``CFG``. Writes:

    * ``<checkpoints_dir>/mlp_best.pt``: state dict of the epoch with the
      highest validation accuracy, reloaded for the test evaluation.
    * ``<logs_dir>/<model_name>_training_log.json``: per-epoch history plus
      test metrics.

    Raises:
        ValueError: if ``CFG["epochs"]`` is less than 1.
    """
    if CFG["epochs"] < 1:
        # With no epochs no checkpoint is written, so the reload below would fail.
        raise ValueError(f"CFG['epochs'] must be >= 1, got {CFG['epochs']}")

    set_seed(CFG["seed"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[TRAIN] Device: {device}")
    print(f"[TRAIN] Model: {CFG['model_name']}")

    # The sixth value returned by get_dataloaders (a scaler) is not used by this script.
    train_loader, val_loader, test_loader, num_features, num_classes, _scaler = get_dataloaders(
        model_name=CFG["model_name"],
        batch_size=CFG["batch_size"],
        split_ratios=CFG["split_ratios"],
        seed=CFG["seed"],
    )

    model = BaseModel(num_features, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=CFG["lr"])
    param_count = sum(p.numel() for p in model.parameters())
    print(f"[TRAIN] Parameters: {param_count:,}")

    ckpt_dir = CFG["checkpoints_dir"]
    os.makedirs(ckpt_dir, exist_ok=True)
    best_ckpt_path = os.path.join(ckpt_dir, "mlp_best.pt")

    history = {
        "model_name": CFG["model_name"],
        "param_count": param_count,
        "epochs": [],
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }

    # Start below any reachable accuracy so epoch 1 always writes a checkpoint.
    # With 0.0 and a strict ">", a run whose validation accuracy never rose
    # above 0 wrote nothing. The reload below then picked up a stale
    # mlp_best.pt from an earlier run, or crashed if none existed.
    best_val_acc = float("-inf")
    print(f"\n{'Epoch':>6} | {'Train Loss':>10} | {'Train Acc':>9} | {'Val Loss':>10} | {'Val Acc':>9}")
    print("-" * 60)
    for epoch in range(1, CFG["epochs"] + 1):
        train_loss, train_acc = model.training_step(train_loader, optimizer, criterion, device)
        val_loss, val_acc = model.validation_step(val_loader, criterion, device)

        history["epochs"].append(epoch)
        history["train_loss"].append(round(train_loss, 4))
        history["train_acc"].append(round(train_acc, 4))
        history["val_loss"].append(round(val_loss, 4))
        history["val_acc"].append(round(val_acc, 4))

        _report_epoch_to_clearml(
            epoch, train_loss, train_acc, val_loss, val_acc, optimizer.param_groups[0]["lr"]
        )

        marker = ""
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), best_ckpt_path)
            marker = " *"
        print(f"{epoch:>6} | {train_loss:>10.4f} | {train_acc:>8.2%} | {val_loss:>10.4f} | {val_acc:>8.2%}{marker}")

    print(f"\nBest validation accuracy: {best_val_acc:.2%}")
    print(f"Checkpoint saved to: {best_ckpt_path}")

    # Evaluate the best-validation weights, not the last epoch's.
    model.load_state_dict(torch.load(best_ckpt_path, weights_only=True))
    test_loss, test_acc, test_probs, test_preds, test_labels = model.test_step(test_loader, criterion, device)
    test_f1 = f1_score(test_labels, test_preds, average="weighted")
    test_auc = _test_roc_auc(test_labels, test_probs, num_classes)
    auc_text = "n/a" if test_auc is None else f"{test_auc:.4f}"

    print(f"\n[TEST] Loss: {test_loss:.4f} | Accuracy: {test_acc:.2%}")
    print(f"[TEST] F1: {test_f1:.4f} | ROC-AUC: {auc_text}")

    history["test_loss"] = round(test_loss, 4)
    history["test_acc"] = round(test_acc, 4)
    history["test_f1"] = round(test_f1, 4)
    # Serialised as JSON null when ROC-AUC is undefined for this test split.
    history["test_auc"] = None if test_auc is None else round(test_auc, 4)

    logs_dir = CFG["logs_dir"]
    os.makedirs(logs_dir, exist_ok=True)
    log_path = os.path.join(logs_dir, f"{CFG['model_name']}_training_log.json")
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    print(f"[LOG] Training history saved to: {log_path}")
if __name__ == "__main__":
    # Script entry point: run the full train / validate / test pipeline.
    main()