sapbot's picture
Upload 3 files
016eee7 verified
"""
train_mnist_1k_tqdm.py
Trains a tiny MNIST model (<1000 params) until convergence,
using tqdm progress bars and early stopping.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import numpy as np
import os
import sys
# -------------------------------
# 0. Automatic device fallback
# -------------------------------
def get_device():
    """Pick the torch device to run on, preferring CUDA when it actually works.

    When CUDA reports itself as available, a tiny average-pooling op is run
    on the GPU as a smoke test. Any exception raised by that probe is printed
    and the CPU is used instead.

    Returns:
        torch.device: ``cuda`` if the probe succeeds, otherwise ``cpu``.
    """
    cpu = torch.device('cpu')
    if not torch.cuda.is_available():
        return cpu
    try:
        # Allocate on the GPU and run one kernel so a broken driver/runtime
        # surfaces here rather than in the middle of training.
        probe = torch.randn(1, 1, 28, 28).cuda()
        torch.nn.functional.avg_pool2d(probe, 4)
    except Exception as e:
        print(f"GPU error: {e}\nFalling back to CPU.")
        return cpu
    return torch.device('cuda')


# Resolved once at import time; the rest of the script moves tensors here.
device = get_device()
print(f"Using device: {device}")
# -------------------------------
# 1. Model (970 parameters)
# -------------------------------
class TinyMNISTModel(nn.Module):
    """Small classifier: 4x4 average pooling followed by a two-layer MLP.

    Layers: AvgPool2d(4, stride 4) -> Linear(49, 16) -> ReLU -> Dropout(0.2)
    -> Linear(16, 10). The two linear layers hold
    49*16 + 16 + 16*10 + 10 = 970 trainable parameters.
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=4, stride=4)
        self.fc1 = nn.Linear(in_features=7 * 7, out_features=16)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(in_features=16, out_features=10)

    def forward(self, x):
        """Return class logits, one row of 10 values per input sample.

        ``fc1`` takes 7*7 inputs, so the pooled map must flatten to 49 values
        per sample (e.g. a single-channel 28x28 image).
        """
        pooled = self.pool(x)
        # Collapse everything except the batch dimension.
        flat = pooled.view(pooled.size(0), -1)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.fc2(hidden)
# -------------------------------
# 2. Data
# -------------------------------
# Convert images to tensors, then standardize with a fixed mean/std
# (0.1307 / 0.3081 -- presumably the MNIST training-set statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Both splits share the same storage location, download flag and transform.
_mnist_kwargs = {"root": './data', "download": True, "transform": transform}
full_train = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

# Hold out 10% of the training images for validation; train on the rest.
val_size = int(0.1 * len(full_train))
train_size = len(full_train) - val_size
train_dataset, val_dataset = random_split(full_train, [train_size, val_size])

batch_size = 64
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)
# -------------------------------
# 3. Training with early stopping + tqdm
# -------------------------------
# Model, loss and optimizer used for the whole run.
model = TinyMNISTModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Early-stopping bookkeeping: training stops once the validation loss has
# failed to improve for `patience` consecutive epochs.
patience = 5
best_val_loss = float('inf')
epochs_no_improve = 0
best_model_state = None

print(f"\n🏋️ Training until convergence (early stopping patience = {patience})\n")

# `epoch` is zero-based; after the loop `epoch + 1` is the number of epochs run.
epoch = 0
while True:
    # ---- Training phase (dropout active), with a per-batch progress bar ----
    model.train()
    train_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False)
    for images, labels in train_bar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_bar.set_postfix(loss=loss.item())
    # Mean of the per-batch losses.
    train_loss /= len(train_loader)

    # ---- Validation phase (dropout disabled, no gradients) ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False)
    with torch.no_grad():
        for images, labels in val_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, pred = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
            val_bar.set_postfix(loss=loss.item())
    val_loss /= len(val_loader)
    val_acc = 100.0 * correct / total

    # Print progress line (outside tqdm to keep clean)
    print(f"Epoch {epoch+1:3d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    # ---- Early stopping logic ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        # state_dict() hands back tensors that share storage with the live
        # parameters, and dict.copy() is shallow, so a plain copy would keep
        # changing as training continues. Clone every tensor to freeze the
        # snapshot of the best epoch.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"\n🛑 Early stopping after {epoch+1} epochs (no improvement for {patience} epochs).")
            break
    epoch += 1

# Restore best model. The snapshot is still None if the validation loss never
# beat the initial +inf (e.g. it was NaN every epoch); keep the current
# weights in that case instead of crashing in load_state_dict.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
# -------------------------------
# 4. Final evaluation on full test set
# -------------------------------
def evaluate(loader, name="Test"):
    """Measure the accuracy of the module-level ``model`` on ``loader``.

    Puts the model in eval mode, runs every batch without gradients on the
    module-level ``device``, prints the accuracy and returns it.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        name: Label used in the progress bar and in the printed result.

    Returns:
        Accuracy as a percentage (0-100).
    """
    model.eval()
    hits = 0
    seen = 0
    progress = tqdm(loader, desc=f"Evaluating on {name}", leave=False)
    with torch.no_grad():
        for images, labels in progress:
            images = images.to(device)
            labels = labels.to(device)
            # Predicted class = index of the largest logit in each row.
            predictions = torch.max(model(images), 1).indices
            seen += labels.size(0)
            hits += (predictions == labels).sum().item()
    acc = 100.0 * hits / seen
    print(f"{name} accuracy: {acc:.2f}%")
    return acc
test_acc = evaluate(test_loader, "full test set")
total_params = sum(p.numel() for p in model.parameters())
# -------------------------------
# 5. TL;DR summary
# -------------------------------
# The box is assembled programmatically so every row is padded to the same
# width and the borders always line up, whatever the values' lengths are.
_tldr_title = "TL;DR – Tiny MNIST"
_tldr_rows = [
    f"Parameters: {total_params}",
    f"Training epochs until convergence: {epoch+1}",
    f"Best validation loss: {best_val_loss:.4f}",
    f"Final test accuracy: {test_acc:.2f}%",
    f"Early stopping patience: {patience} epochs",
]
# Inner width between the vertical borders: at least 58 columns, widened if
# a row (plus one space of padding on each side) would not fit.
_tldr_width = max(58, len(_tldr_title) + 2, *(len(row) + 2 for row in _tldr_rows))
_tldr_rule = "═" * _tldr_width
tldr = "\n".join([
    "",
    f"╔{_tldr_rule}╗",
    f"║{_tldr_title:^{_tldr_width}}║",
    f"╠{_tldr_rule}╣",
    *(f"║ {row:<{_tldr_width - 2}} ║" for row in _tldr_rows),
    f"╚{_tldr_rule}╝",
    "",
])
print(tldr)
# Save model
torch.save(model.state_dict(), "mnist_1k_best.pth")
# -------------------------------
# 6. Generate README.md (HF style)
# -------------------------------
# Model card text, filled in with the results of this run. The leading `---`
# block is YAML front matter; the rest is Markdown.
readme_content = f"""---
language: en
license: apache-2.0
tags:
- mnist
- tiny-model
- tqdm
- early-stopping
---
# Tiny MNIST Classifier – with tqdm progress bars
- **Parameters**: {total_params} (<1000)
- **Test accuracy**: {test_acc:.2f}%
- **Epochs trained**: {epoch+1} (early stopping after {patience} epochs without improvement)
This script trains until convergence and shows **tqdm** progress bars for each batch.
## TL;DR
```bash
python train_mnist_1k_tqdm.py
```
## Full results
| Metric | Value |
|---------------------------|-----------------|
| Total parameters | {total_params} |
| Best validation loss | {best_val_loss:.4f} |
| Final test accuracy | {test_acc:.2f}% |
| Early stopping patience | {patience} |
| Training epochs | {epoch+1} |
## Model architecture
AvgPool(4x4) → Linear(49→16) → ReLU → Dropout(0.2) → Linear(16→10)
## How to use
```python
import torch
from train_mnist_1k_tqdm import TinyMNISTModel
model = TinyMNISTModel()
model.load_state_dict(torch.load("mnist_1k_best.pth"))
model.eval()
```
"""
# The text contains non-ASCII characters (en dash, arrows), so the encoding
# is pinned to UTF-8; relying on the locale default can raise
# UnicodeEncodeError on platforms whose default codec cannot encode them.
with open("README.md", "w", encoding="utf-8") as f:
    f.write(readme_content)
print("✅ README.md generated. Model saved as mnist_1k_best.pth")