"""Train a small image classifier and write checkpoints, logs and reports.

The script reads a YAML config, lets command-line flags override it, trains
with early stopping, and finally renders a confusion matrix for the selected
checkpoint.
"""

import argparse
import json
import time
from copy import deepcopy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision as tv
import torchvision.models as models
import yaml
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.classification import MulticlassConfusionMatrix
from torchvision import transforms
from torchvision.datasets import ImageFolder


# ----------------- argparse -----------------
def build_argparser():
    """Build the command-line parser for the training script."""
    p = argparse.ArgumentParser(description="Train a small CNN on MNIST/Fashion-MNIST")
    p.add_argument(
        "--dataset",
        choices=["fashion-mnist", "mnist", "cifar10"],
        default="fashion-mnist",
    )
    p.add_argument("--data-dir", type=str, default="./data")
    p.add_argument("--batch-size", type=int, default=128)
    p.add_argument("--epochs", type=int, default=8)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--num-workers", type=int, default=2)
    p.add_argument("--seed", type=int, default=41)
    p.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")

    # Legacy path args (mapped onto the config's *_root keys when provided).
    p.add_argument("--logdir", type=str, default=None)
    p.add_argument("--ckpt", type=str, default=None)
    p.add_argument("--metrics", type=str, default=None)
    p.add_argument("--reports-dir", type=str, default=None)

    # Config file and model selection.
    p.add_argument(
        "--config",
        type=str,
        default="configs/baseline.yaml",
        help="Path to YAML config with defaults",
    )
    p.add_argument(
        "--model-name",
        type=str,
        default=None,
        choices=["smallcnn", "resnet18_cifar", "resnet18_imagenet"],
        help="Choose model architecture",
    )
    return p


# ----------------- small utils -----------------
def get_device(choice: str) -> str:
    """Resolve a device choice to a device string.

    ``"cpu"`` and ``"cuda"`` are returned unchanged; any other value picks
    ``"cuda"`` when available and ``"cpu"`` otherwise.
    """
    if choice == "cpu":
        return "cpu"
    if choice == "cuda":
        return "cuda"
    return "cuda" if torch.cuda.is_available() else "cpu"


def seed_all(seed: int):
    """Seed Python, NumPy and torch RNGs and request deterministic cuDNN."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def accuracy(logits, targets):
    """Return the fraction of rows whose argmax over dim 1 equals the target."""
    preds = logits.argmax(dim=1)
    return (preds == targets).float().mean().item()


def load_yaml(path: str) -> dict:
    """Load a YAML file, returning an empty dict for an empty document."""
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load yields None for an empty file; callers index the result.
    return {} if data is None else data


def merge_cli_over_config_with_defaults(cfg, args, parser):
    """Overlay command-line values on a config dict.

    Precedence per key: a CLI value that differs from its argparse default
    wins; otherwise the config value is kept; otherwise, if the config lacks
    the key and the argparse default is not ``None``, that default is used.

    A flag explicitly set to its own default value cannot be told apart from
    an omitted flag, so it does not override the config.

    Args:
        cfg: Config mapping (not mutated; a deep copy is returned).
        args: Parsed CLI namespace.
        parser: The parser that produced ``args``; used to read its defaults.

    Returns:
        The merged config, with ``_config_path`` set to ``args.config``.
    """
    cfg = deepcopy(cfg)
    defaults = parser.parse_args([])  # argparse defaults only

    arg_to_cfg_key = [
        ("dataset", "dataset"),
        ("data_dir", "data_dir"),
        ("batch_size", "batch_size"),
        ("epochs", "epochs"),
        ("lr", "lr"),
        ("weight_decay", "weight_decay"),
        ("num_workers", "num_workers"),
        ("seed", "seed"),
        ("device", "device"),
        ("logdir", "log_root"),
        ("ckpt", "ckpt_root"),
        ("metrics", "reports_root"),
        ("reports_dir", "reports_root"),
        ("model_name", "model_name"),
    ]
    # Legacy path flags point below the root, so the parent directory is used.
    legacy_path_args = {"ckpt", "metrics", "reports_dir"}

    for arg_name, cfg_key in arg_to_cfg_key:
        val = getattr(args, arg_name)
        defval = getattr(defaults, arg_name)
        if val is not None and val != defval:
            if arg_name in legacy_path_args:
                cfg[cfg_key] = str(Path(val).parent)
            else:
                cfg[cfg_key] = val
        elif defval is not None:
            # Fill keys the config omitted so later lookups do not KeyError.
            cfg.setdefault(cfg_key, defval)

    cfg["_config_path"] = args.config
    return cfg


def is_improved(best_value, current, mode: str, min_delta: float) -> bool:
    """Return True when ``current`` beats ``best_value`` by more than ``min_delta``.

    ``mode == "min"`` means lower is better; any other mode means higher is
    better.
    """
    if mode == "min":
        return current < (best_value - min_delta)
    return current > (best_value + min_delta)


def save_checkpoint(payload: dict, path: Path):
    """Serialize ``payload`` to ``path`` with ``torch.save``."""
    torch.save(payload, str(path))


# ----------------- model -----------------
class SmallCNN(nn.Module):
    """Two conv/pool stages followed by one linear classifier.

    The first conv takes 1 input channel and the classifier expects a
    64 x 7 x 7 feature map, i.e. 28 x 28 inputs after two 2x2 poolings.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(64 * 7 * 7, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = torch.flatten(x, 1)
        return self.fc(x)  # logits


def build_model(model_name: str, num_classes: int, img_size: int):
    """Construct a model by name.

    Args:
        model_name: ``"smallcnn"``, ``"resnet18_cifar"`` or
            ``"resnet18_imagenet"``.
        num_classes: Size of the final classification layer.
        img_size: Accepted for interface compatibility; not used here.

    Returns:
        ``(model, default_target_layer)`` where the second item is the name
        of a layer attribute on the model.

    Raises:
        ValueError: If ``model_name`` is not recognized.
    """
    if model_name == "smallcnn":
        return SmallCNN(num_classes=num_classes), "conv2"

    if model_name == "resnet18_cifar":
        # Vanilla resnet18 adapted for small inputs: 3x3/stride-1 stem
        # instead of 7x7/stride-2, and no max-pool.
        m = models.resnet18(weights=None)
        m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        m.maxpool = nn.Identity()
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m, "layer4"

    if model_name == "resnet18_imagenet":
        # Use ImageNet weights when this torchvision exposes the weights enum.
        try:
            w = models.ResNet18_Weights.IMAGENET1K_V1
        except AttributeError:
            w = None
        m = models.resnet18(weights=w)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m, "layer4"

    raise ValueError(f"Unknown model_name: {model_name}")


# ----------------- data -----------------
def get_transforms_for(dataset_name: str, img_size: int, mean, std, train: bool):
    """Build the torchvision transform pipeline for a dataset split.

    For ``"cifar10"`` the training pipeline adds light augmentation (random
    crop + flip at 32px, resize + flip otherwise) and the eval pipeline only
    resizes. Every other dataset name is treated as grayscale: tensor
    conversion plus normalization with the first entries of ``mean``/``std``;
    ``img_size`` and ``train`` are not used on that path.
    """
    if dataset_name in {"cifar10"}:
        tfms = []
        if train:
            # Light augmentation for CIFAR.
            if img_size == 32:
                tfms += [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                ]
            else:
                tfms += [
                    transforms.Resize((img_size, img_size)),
                    transforms.RandomHorizontalFlip(),
                ]
        else:
            tfms += [transforms.Resize((img_size, img_size))]
        tfms += [transforms.ToTensor(), transforms.Normalize(mean, std)]
        return transforms.Compose(tfms)

    # fashion-mnist / mnist (grayscale)
    m, s = float(mean[0]), float(std[0])
    return transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((m,), (s,))]
    )


def get_dataloaders(
    dataset_name: str,
    data_dir: str,
    batch_size: int,
    num_workers: int,
    seed: int,
    img_size: int,
    mean,
    std,
):
    """Create train/val/test loaders for a supported torchvision dataset.

    Returns:
        ``(train_loader, val_loader, test_loader, classes)``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    dataset_classes = {
        "fashion-mnist": tv.datasets.FashionMNIST,
        "mnist": tv.datasets.MNIST,
        "cifar10": tv.datasets.CIFAR10,
    }
    try:
        dataset_cls = dataset_classes[dataset_name]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {dataset_name}") from None

    root = Path(data_dir)
    # Dedicated generator so the training shuffle order follows ``seed``.
    g = torch.Generator().manual_seed(seed)

    train_tf = get_transforms_for(dataset_name, img_size, mean, std, train=True)
    eval_tf = get_transforms_for(dataset_name, img_size, mean, std, train=False)

    train_ds = dataset_cls(root=root, train=True, download=True, transform=train_tf)
    test_ds = dataset_cls(root=root, train=False, download=True, transform=eval_tf)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        generator=g,
    )
    # NOTE(review): validation and test both iterate the held-out test split,
    # so early stopping / model selection sees the data used for the report.
    val_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, val_loader, test_loader, train_ds.classes


# ----------------- train/eval -----------------
def train_one_epoch(model, loader, device, optimizer, loss_fn):
    """Run one optimization pass over ``loader``.

    Returns:
        ``(mean_loss, mean_accuracy)``, both weighted by batch size.
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    n = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = model(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        optimizer.step()

        b = yb.size(0)
        loss_sum += loss.item() * b
        acc_sum += accuracy(logits, yb) * b
        n += b
    return loss_sum / n, acc_sum / n


@torch.no_grad()
def eval_one_epoch(model, loader, device, loss_fn):
    """Evaluate ``model`` over ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, mean_accuracy)``, both weighted by batch size.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    n = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = loss_fn(logits, yb)

        b = yb.size(0)
        loss_sum += loss.item() * b
        acc_sum += accuracy(logits, yb) * b
        n += b
    return loss_sum / n, acc_sum / n


@torch.no_grad()
def confusion_matrix_report(
    model,
    test_loader,
    device,
    classes,
    reports_dir: Path,
    metrics_path: Path,
    title_prefix: str,
):
    """Compute, plot and persist the confusion matrix for ``test_loader``.

    Writes ``confusion_matrix.png``, ``confusion_matrix_counts.npy`` and
    ``confusion_matrix_norm.npy`` into ``reports_dir`` and records their
    paths in the JSON file at ``metrics_path`` (existing keys are kept).
    """
    model.eval()
    all_preds, all_targets = [], []
    for xb, yb in test_loader:
        xb = xb.to(device)
        logits = model(xb)
        all_preds.append(logits.argmax(dim=1).cpu())
        all_targets.append(yb)
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)

    num_classes = len(classes)
    cm_metric = MulticlassConfusionMatrix(num_classes=num_classes)
    cm = cm_metric(all_preds, all_targets).numpy()

    # Row-normalize; rows with no samples stay all-zero instead of NaN.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(
        cm,
        row_sums,
        out=np.zeros(cm.shape, dtype=float),
        where=row_sums != 0,
    )

    reports_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(7, 6))
    im = ax.imshow(cm_norm, interpolation="nearest")
    ax.figure.colorbar(im, ax=ax)
    ax.set(
        xticks=np.arange(num_classes),
        yticks=np.arange(num_classes),
        xticklabels=classes,
        yticklabels=classes,
        ylabel="True label",
        xlabel="Predicted label",
        title=f"{title_prefix} Confusion Matrix (row-normalized)",
    )
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    for i in range(num_classes):
        for j in range(num_classes):
            ax.text(
                j,
                i,
                f"{cm_norm[i, j]*100:.1f}%",
                ha="center",
                va="center",
                fontsize=8,
            )

    fig.tight_layout()
    fig_path = reports_dir / "confusion_matrix.png"
    fig.savefig(fig_path, dpi=200)
    plt.close(fig)
    print("Saved figure to:", fig_path)

    counts_path = reports_dir / "confusion_matrix_counts.npy"
    norm_path = reports_dir / "confusion_matrix_norm.npy"
    np.save(counts_path, cm)
    np.save(norm_path, cm_norm)

    try:
        with open(metrics_path, encoding="utf-8") as f:
            metrics = json.load(f)
    except FileNotFoundError:
        metrics = {}

    metrics.update(
        {
            "confusion_matrix_counts_path": str(counts_path),
            "confusion_matrix_norm_path": str(norm_path),
            "confusion_matrix_figure": str(fig_path),
        }
    )
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)


# ----------------- main -----------------
def main():
    """Parse arguments, train with early stopping, and write reports."""
    parser = build_argparser()
    args = parser.parse_args()

    base_cfg = load_yaml(args.config)
    cfg = merge_cli_over_config_with_defaults(base_cfg, args, parser)

    # Seed from the merged config so a seed set in the YAML is honored.
    seed_all(cfg["seed"])

    dataset = cfg["dataset"]
    model_name = cfg.get("model_name", "smallcnn")
    is_grayscale = dataset in ["fashion-mnist", "mnist"]
    img_size = int(cfg.get("img_size", 28 if is_grayscale else 32))
    mean = cfg.get("mean", None)
    std = cfg.get("std", None)

    # Fall back to built-in normalization statistics when not configured.
    if mean is None or std is None:
        if dataset == "fashion-mnist":
            mean, std = [0.2860], [0.3530]
        elif dataset == "mnist":
            mean, std = [0.1307], [0.3081]
        elif dataset == "cifar10":
            mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]

    device = get_device(cfg["device"])
    print("device:", device)

    run_id = f'{cfg["dataset"]}_{int(time.time())}'
    LOG_DIR = Path(cfg["log_root"]) / run_id
    CKPTS_DIR = Path(cfg["ckpt_root"]) / run_id
    REPORTS_DIR = Path(cfg["reports_root"]) / run_id
    for d in (LOG_DIR, CKPTS_DIR, REPORTS_DIR):
        d.mkdir(parents=True, exist_ok=True)

    effective_cfg = deepcopy(cfg)
    effective_cfg["run_id"] = run_id
    with open(REPORTS_DIR / "config_effective.yaml", "w", encoding="utf-8") as f:
        yaml.safe_dump(effective_cfg, f)

    train_loader, val_loader, test_loader, classes = get_dataloaders(
        dataset,
        cfg["data_dir"],
        cfg["batch_size"],
        cfg["num_workers"],
        cfg["seed"],
        img_size,
        mean,
        std,
    )

    loss_fn = nn.CrossEntropyLoss()
    model, default_target_layer = build_model(
        model_name, num_classes=len(classes), img_size=img_size
    )
    model = model.to(device)

    opt_name = str(cfg.get("optimizer", "adam")).lower()
    if opt_name == "sgd":
        optimizer = optim.SGD(
            model.parameters(),
            lr=cfg["lr"],
            momentum=float(cfg.get("momentum", 0.9)),
            weight_decay=cfg["weight_decay"],
            nesterov=True,
        )
    else:
        optimizer = optim.Adam(
            model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"]
        )

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=2
    )

    monitor = cfg["early_stop"]["monitor"]
    mode = cfg["early_stop"]["mode"]
    patience = int(cfg["early_stop"]["patience"])
    min_delta = float(cfg["early_stop"]["min_delta"])

    best_val = float("inf") if mode == "min" else -float("inf")
    epochs_no_improve = 0
    epochs_ran = 0  # stays 0 when the loop body never executes

    ckpt_last = CKPTS_DIR / "last.ckpt"
    ckpt_best = CKPTS_DIR / "best.ckpt"

    # Metadata stored alongside both the checkpoint and best.json.
    meta = {
        "dataset": dataset,
        "model_name": model_name,
        "img_size": img_size,
        "mean": mean,
        "std": std,
        "default_target_layer": default_target_layer,
    }

    writer = SummaryWriter(log_dir=str(LOG_DIR))
    try:
        for epoch in range(1, cfg["epochs"] + 1):
            epochs_ran = epoch
            tr_loss, tr_acc = train_one_epoch(
                model, train_loader, device, optimizer, loss_fn
            )
            va_loss, va_acc = eval_one_epoch(model, val_loader, device, loss_fn)
            scheduler.step(va_loss)

            writer.add_scalar("Loss/train", tr_loss, epoch)
            writer.add_scalar("Loss/val", va_loss, epoch)
            writer.add_scalar("Acc/train", tr_acc, epoch)
            writer.add_scalar("Acc/val", va_acc, epoch)
            writer.add_scalar("LR", optimizer.param_groups[0]["lr"], epoch)

            print(
                f"Epoch {epoch:02d} | train_loss={tr_loss:.4f} acc={tr_acc:.4f}"
                + f" | val_loss={va_loss:.4f} acc={va_acc:.4f}"
            )

            mon_value = va_loss if monitor == "val_loss" else va_acc

            payload = {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "val_acc": va_acc,
                "val_loss": va_loss,
                "dataset": cfg["dataset"],
                "classes": classes,
                "config_path": cfg.get("_config_path"),
                "meta": meta,
            }
            save_checkpoint(payload, ckpt_last)

            if is_improved(best_val, mon_value, mode, min_delta):
                best_val = mon_value
                epochs_no_improve = 0
                save_checkpoint(payload, ckpt_best)
                best_json = {
                    "epoch": epoch,
                    "monitor": monitor,
                    "mode": mode,
                    "best_value": float(best_val),
                    "val_acc": float(va_acc),
                    "val_loss": float(va_loss),
                    "ckpt_path": str(ckpt_best),
                    "meta": meta,
                }
                with open(REPORTS_DIR / "best.json", "w", encoding="utf-8") as f:
                    json.dump(best_json, f, indent=2)
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"Early stopping: no improvement in {patience} epochs.")
                    break
    finally:
        # Flush TensorBoard events even if training raises.
        writer.close()

    print(f"Best {monitor}: {best_val:.4f}")

    # Prefer the best checkpoint for reports; fall back to the last one when
    # no epoch ever counted as an improvement (e.g. NaN validation loss), and
    # to the in-memory weights when no epoch ran at all.
    if ckpt_best.exists():
        report_ckpt = ckpt_best
    elif ckpt_last.exists():
        report_ckpt = ckpt_last
    else:
        report_ckpt = None

    if report_ckpt is not None:
        state = torch.load(str(report_ckpt), map_location=device)
        model.load_state_dict(state["model_state"])
    model.eval()

    metrics_path = REPORTS_DIR / "metrics.json"
    confusion_matrix_report(
        model,
        test_loader,
        device,
        classes,
        reports_dir=REPORTS_DIR / "figures",
        metrics_path=metrics_path,
        title_prefix=cfg["dataset"].replace("-", " ").title(),
    )

    run_metrics = {
        "dataset": cfg["dataset"],
        "epochs_ran": epochs_ran,
        "batch_size": cfg["batch_size"],
        "lr": cfg["lr"],
        "weight_decay": cfg["weight_decay"],
        "best_monitor": monitor,
        "best_mode": mode,
        "best_value": float(best_val),
        "logs_dir": str(LOG_DIR),
        "ckpts_dir": str(CKPTS_DIR),
        "reports_dir": str(REPORTS_DIR),
    }

    # Merge into what the confusion-matrix report already wrote rather than
    # overwriting it, so the confusion-matrix paths survive.
    try:
        with open(metrics_path, encoding="utf-8") as f:
            metrics = json.load(f)
    except FileNotFoundError:
        metrics = {}
    metrics.update(run_metrics)
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print("Saved metrics to:", metrics_path)
    print("Best checkpoint:", ckpt_best)


if __name__ == "__main__":
    main()