---
license: apache-2.0
tags:
- pytorch
- lightning
- cifar10
---

# CIFAR10 Classifier trained with PyTorch Lightning

## Introduction

A ResNet18 model that reaches 94% accuracy on CIFAR-10. The main contributors to this result, with the improvement each one brought, are:

1. Data normalization and randomization (10% improvement)
2. Dropout before the fully connected (FC) classifier (1% improvement)
3. Batch normalization in each ResNet block (2% improvement)
4. Cosine learning-rate schedule (1% improvement)
5. The ResNet18 architecture itself, which is deeper than a simple CNN
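
The card does not show the training configuration. As a hedged illustration of item 4, the sketch below implements the closed-form cosine annealing formula (the one documented for PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR`, without warm restarts), which decays the learning rate from `base_lr` to `min_lr` over `total_steps`. The function name `cosine_lr` and the example values `base_lr=0.1` and `total_steps=100` are illustrative assumptions, not the settings used for this checkpoint.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Learning rate at `step` under cosine annealing (no restarts).

    lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * step / total_steps))
    `step` is clamped to [0, total_steps] in this sketch for simplicity.
    """
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    t = min(max(step, 0), total_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t / total_steps))


if __name__ == "__main__":
    # Hypothetical values only; the real hyperparameters are not stated in this card.
    for step in (0, 25, 50, 75, 100):
        print(step, round(cosine_lr(step, total_steps=100, base_lr=0.1), 4))
```

In a Lightning module this is normally delegated to `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=...)`, returned from `configure_optimizers`, rather than computed by hand.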

## Usage

### Approach 1: Predict with PyTorch
```python
import torch
from model import CIFARCNN  # CIFARCNN is the LightningModule defined in model.py

# Load the checkpoint and switch to inference mode (disables dropout, uses BN running stats)
model = CIFARCNN.load_from_checkpoint("model.ckpt")
model.eval()

# A random batch of 4 RGB 32x32 images, standing in for real CIFAR-10 inputs
x = torch.randn(4, 3, 32, 32).to(model.device)

with torch.no_grad():
    predictions = model(x)  # requires CIFARCNN to implement forward()
print(predictions.shape)  # expected: torch.Size([4, 10])
```
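
`model(x)` returns one row of 10 class scores per image. A minimal, dependency-free sketch of turning one row into a class name and a softmax probability follows. It assumes the rows were obtained with `predictions.tolist()` and that the output indices follow the standard CIFAR-10 class order; the helper names `softmax` and `decode` are hypothetical.

```python
import math

# Standard CIFAR-10 class order (index -> name); assumed to match the model's outputs.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def decode(row):
    """Return (class_name, probability) for the highest-scoring class in one row of scores."""
    probs = softmax(row)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CIFAR10_CLASSES[best], probs[best]


if __name__ == "__main__":
    # Hypothetical scores for a single image; in practice use predictions.tolist()[i].
    row = [0.1, 2.5, -1.0, 0.0, 0.3, -0.2, 0.4, 0.0, 1.2, 0.9]
    name, p = decode(row)
    print(name, round(p, 3))
```

With tensors, the equivalent is `probs, idx = predictions.softmax(dim=1).max(dim=1)`.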
### Approach 2: Use Lightning to test
```python
from lightning import Trainer
from torch.utils.data import DataLoader
from model import CIFARCNN

test_dataloader = DataLoader(...)  # CIFAR-10 test set with the same preprocessing used in training
model = CIFARCNN.load_from_checkpoint("model.ckpt")
trainer = Trainer()  # moves the model to the available accelerator when testing

# Runs CIFARCNN.test_step over every batch and reports the logged metrics
trainer.test(model, dataloaders=test_dataloader)
```
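
`trainer.test` reports whatever metrics the module's `test_step` logs. The card does not show `CIFARCNN.test_step`, so the metric names are unknown. Assuming the 94% figure is top-1 accuracy, it is the fraction of images whose highest-scoring class equals the true label. The plain-Python sketch below computes it; `top1_accuracy` and `running_accuracy` are hypothetical helpers, not part of this repository.

```python
def top1_accuracy(predicted, targets):
    """Fraction of positions where the predicted class index equals the target index."""
    if len(predicted) != len(targets):
        raise ValueError("predicted and targets must have the same length")
    if not targets:
        raise ValueError("cannot compute accuracy of an empty batch")
    correct = sum(int(p == t) for p, t in zip(predicted, targets))
    return correct / len(targets)


def running_accuracy(batches):
    """Accuracy over (predicted, targets) batches, counting every sample equally."""
    correct = total = 0
    for predicted, targets in batches:
        if len(predicted) != len(targets):
            raise ValueError("predicted and targets must have the same length")
        correct += sum(int(p == t) for p, t in zip(predicted, targets))
        total += len(targets)
    return correct / total if total else 0.0
```

`running_accuracy` accumulates raw counts instead of averaging per-batch accuracies, so a smaller final batch does not skew the result. Per batch, the inputs would be `logits.argmax(dim=1).tolist()` and `y.tolist()`.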
### Visualize predictions
```python
import matplotlib.pyplot as plt
import torch

cifar10_labels = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

# One batch (at least 10 images); assumes `model` and `test_dataloader` from Approach 2
samples, labels = next(iter(test_dataloader))

# Call the model directly: trainer.predict expects dataloaders, not a tensor,
# and returns a list of per-batch outputs rather than a single tensor.
model.eval()
with torch.no_grad():
    logits = model(samples.to(model.device))
preds = logits.argmax(dim=1).cpu()  # predicted class index per image

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(samples[i].permute(1, 2, 0))  # CHW -> HWC for Matplotlib
    ax.set_title(f"pred: {cifar10_labels[preds[i].item()]}\ntrue: {cifar10_labels[labels[i].item()]}")
    ax.axis("off")
plt.tight_layout()
plt.show()
```
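
If the dataloader applies the normalization mentioned in the introduction, `imshow` receives values outside `[0, 1]` and Matplotlib clips them, which distorts the colors. A minimal sketch of undoing a per-channel `(x - mean) / std` normalization before plotting follows. The card does not state the mean and std that were used: the values below are commonly quoted CIFAR-10 statistics and are an assumption, so substitute the ones from the actual training transform. `unnormalize` is a hypothetical helper working on a `[C][H][W]` nested list.

```python
# Assumed normalization statistics; replace with the values from the training transform.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2470, 0.2435, 0.2616)


def unnormalize(image, mean=MEAN, std=STD):
    """Invert (x - mean) / std per channel on a [C][H][W] nested list and clamp to [0, 1]."""
    if len(image) != len(mean) or len(image) != len(std):
        raise ValueError("channel count must match mean/std length")
    out = []
    for channel, m, s in zip(image, mean, std):
        # x = x_norm * std + mean, then clamp into Matplotlib's valid RGB float range
        out.append([[min(1.0, max(0.0, v * s + m)) for v in row] for row in channel])
    return out
```

With a tensor, the same operation is `(img * torch.tensor(STD).view(3, 1, 1) + torch.tensor(MEAN).view(3, 1, 1)).clamp(0, 1)`, applied to `samples[i]` before `.permute(1, 2, 0)`.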