"""
OPTIMIZED Food101 + ResNet50 with major speed improvements:
- Mixed precision training (2x faster)
- Better data loading (persistent workers)
- Progress bars and better logging
- Robust error handling and checkpointing
"""
|
|
| import os |
| import time |
| import copy |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from tqdm import tqdm |
| import logging |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import torchvision |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader |
| from torch.cuda.amp import autocast, GradScaler |
|
|
| |
# Configure module-wide logging once at import time: timestamped INFO-level
# messages. Must run before any logger.info(...) call below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
|
| |
| |
| |
def get_food101_loaders(batch_size=64, num_workers=8):
    """Build optimized Food101 train/val/test DataLoaders.

    Args:
        batch_size: samples per batch for all three loaders.
        num_workers: worker processes per loader.

    Returns:
        (train_loader, val_loader, test_loader, class_names) where
        class_names is the Food101 category list.

    Raises:
        Re-raises any exception from dataset download/loading after logging it.
    """
    # ImageNet normalization statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    try:
        full_train = torchvision.datasets.Food101(
            root='./data', split='train', download=True, transform=transform_train
        )
        # BUG FIX: the validation subset previously came from `full_train` and
        # therefore received the *training* augmentations (random crop/jitter),
        # making validation metrics noisy. Use a second view of the same split
        # with the deterministic eval transform for validation.
        full_train_eval = torchvision.datasets.Food101(
            root='./data', split='train', download=False, transform=transform_test
        )

        # Deterministic 90/10 split. An explicit generator avoids clobbering
        # the global RNG state (the old code called torch.manual_seed(42)).
        train_size = int(0.9 * len(full_train))
        val_size = len(full_train) - train_size
        train_dataset, _ = torch.utils.data.random_split(
            full_train, [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )
        # Re-split with the identical seed so the index partition matches,
        # but backed by the eval-transform dataset.
        _, val_dataset = torch.utils.data.random_split(
            full_train_eval, [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )

        test_dataset = torchvision.datasets.Food101(
            root='./data', split='test', download=True, transform=transform_test
        )

        logger.info(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

        # persistent_workers=True is invalid with num_workers=0, so gate it.
        persistent = num_workers > 0
        train_loader = DataLoader(
            train_dataset, batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, persistent_workers=persistent, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, persistent_workers=persistent
        )
        test_loader = DataLoader(
            test_dataset, batch_size, shuffle=False, num_workers=num_workers,
            pin_memory=True, persistent_workers=persistent
        )

        return train_loader, val_loader, test_loader, full_train.classes

    except Exception as e:
        logger.error(f"Error loading data: {e}")
        raise
|
|
|
|
| |
| |
| |
class BasicBlock(nn.Module):
    """Two-conv residual block (ResNet-18/34 style); output channels == planes."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # First 3x3 conv carries the stride; the second preserves resolution.
        self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection applied to the identity when shapes differ.
        self.downsample = downsample

    def forward(self, x):
        shortcut = self.downsample(x) if self.downsample else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
|
|
|
|
class Bottleneck(nn.Module):
    """1x1-reduce / 3x3 / 1x1-expand residual block; widens channels by `expansion`."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        out_ch = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)          # reduce
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)  # spatial (strided)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_ch, 1, bias=False)            # expand
        self.bn3 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection applied to the identity when shapes differ.
        self.downsample = downsample

    def forward(self, x):
        shortcut = self.downsample(x) if self.downsample else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
|
|
|
|
class ResNet50(nn.Module):
    """ResNet-50 (Bottleneck blocks, stage depths [3, 4, 6, 3]) classifier.

    Takes (N, 3, H, W) images and returns (N, num_classes) logits.
    """

    def __init__(self, num_classes=101):
        super().__init__()
        self.inplanes = 64

        # Stem: 7x7/2 conv + 3x3/2 max-pool => 4x spatial reduction.
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        # Residual stages; stages 2-4 halve resolution via a stride-2 first block.
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, 2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, 2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, 2)

        # Head: global average pool feeding one linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual units; only the first may downsample/project."""
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            # 1x1 projection so the identity matches the block's output shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        units = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels
        units.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*units)

    def _initialize_weights(self):
        """He-init conv kernels; start batch norms as identity (gamma=1, beta=0)."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
|
|
|
|
| |
| |
| |
def train_model(model, train_loader, val_loader, test_loader, device, num_epochs=100, resume_from=None):
    """Optimized training loop with mixed precision and better checkpointing.

    Trains with SGD + cosine warm restarts under AMP, validates every epoch,
    checkpoints the best/periodic/final states, then evaluates on the test
    set and writes training-curve plots.

    Args:
        model: network already moved to `device`.
        train_loader, val_loader, test_loader: DataLoaders for each split.
        device: torch.device to run on (AMP is enabled only for CUDA).
        num_epochs: total epoch count, including epochs completed before resume.
        resume_from: optional checkpoint path to resume training from.

    Returns:
        (best_val_acc, train_losses, val_accuracies)
    """
    os.makedirs('./outputs', exist_ok=True)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    # FIX: enable AMP only on CUDA; torch.cuda.amp is a warn/no-op on CPU.
    use_amp = device.type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    best_val_acc = 0.0
    train_losses, val_accuracies, learning_rates = [], [], []
    start_epoch = 0

    if resume_from and os.path.exists(resume_from):
        logger.info(f"Resuming from {resume_from}")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # FIX: the scaler state was saved but never restored, and the LR
        # scheduler was not restored at all — the cosine warm-restart schedule
        # silently restarted from epoch 0 on every resume.
        if 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Older checkpoints lack scheduler state: fast-forward the schedule.
            for _ in range(checkpoint['epoch']):
                scheduler.step()
        start_epoch = checkpoint['epoch']
        best_val_acc = checkpoint.get('best_val_accuracy', 0.0)
        train_losses = checkpoint.get('train_losses', [])
        val_accuracies = checkpoint.get('val_accuracies', [])
        learning_rates = checkpoint.get('learning_rates', [])

    logger.info(f"π Starting training from epoch {start_epoch+1} for {num_epochs} total epochs...")

    total_train_time = 0

    for epoch in range(start_epoch, num_epochs):
        epoch_start = time.time()

        # ---- training phase ----
        model.train()
        running_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)

        for batch_idx, (images, labels) in enumerate(train_pbar):
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            # Forward in mixed precision; backward/step through the scaler.
            with autocast(enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.6f}'})

        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        learning_rates.append(optimizer.param_groups[0]['lr'])

        # ---- validation phase ----
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]', leave=False)

        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

                with autocast(enabled=use_amp):
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                val_pbar.set_postfix({'acc': f'{100.*correct/total:.2f}%'})

        val_acc = 100. * correct / total
        val_accuracies.append(val_acc)
        avg_val_loss = val_loss / len(val_loader)

        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc

        # Checkpoint every 10 epochs, on a new best, and on the final epoch.
        if (epoch + 1) % 10 == 0 or is_best or epoch == num_epochs - 1:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),  # FIX: persist LR schedule
                'scaler_state_dict': scaler.state_dict(),
                'best_val_accuracy': best_val_acc,
                'current_val_accuracy': val_acc,
                'train_losses': train_losses,
                'val_accuracies': val_accuracies,
                'learning_rates': learning_rates,
            }

            if is_best:
                torch.save(checkpoint, './outputs/food101_resnet50_best.pth')
                torch.save(model.state_dict(), './outputs/food101_resnet50_best_weights.pth')

            if (epoch + 1) % 10 == 0:
                torch.save(checkpoint, f'./outputs/food101_resnet50_epoch_{epoch+1}.pth')

        # Cosine warm restarts: advance once per epoch.
        scheduler.step()
        epoch_time = time.time() - epoch_start
        total_train_time += epoch_time

        logger.info(f"Epoch {epoch+1:3d}/{num_epochs} | "
                    f"Train Loss: {avg_train_loss:.4f} | "
                    f"Val Loss: {avg_val_loss:.4f} | "
                    f"Val Acc: {val_acc:.2f}% | "
                    f"Best: {best_val_acc:.2f}% | "
                    f"LR: {optimizer.param_groups[0]['lr']:.6f} | "
                    f"Time: {epoch_time:.1f}s")

    final_checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        # FIX: guard the empty-history case (resume with no epochs left ran)
        # instead of crashing on val_accuracies[-1].
        'final_val_accuracy': val_accuracies[-1] if val_accuracies else best_val_acc,
        'best_val_accuracy': best_val_acc,
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'learning_rates': learning_rates,
        'total_train_time': total_train_time,
    }
    torch.save(final_checkpoint, './outputs/food101_resnet50_final.pth')
    torch.save(model.state_dict(), './outputs/food101_resnet50_final_weights.pth')

    logger.info(f"π Total training time: {total_train_time/3600:.2f} hours")

    # Final held-out evaluation.
    test_acc = evaluate_model(model, test_loader, device, "Test")
    logger.info(f"π― Final Test Accuracy: {test_acc:.2f}%")

    # Persist training-curve plots (skip when nothing was trained/recorded).
    if train_losses:
        plot_training_curves(train_losses, val_accuracies, learning_rates)

    return best_val_acc, train_losses, val_accuracies
|
|
|
|
def evaluate_model(model, test_loader, device, dataset_name="Test"):
    """Return top-1 accuracy (%) of `model` over `test_loader`.

    Runs under autocast with gradients disabled and shows a tqdm progress
    bar labelled with `dataset_name`.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    progress = tqdm(test_loader, desc=f'{dataset_name} Evaluation', leave=False)

    with torch.no_grad():
        for images, labels in progress:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with autocast():
                outputs = model(images)

            preds = outputs.argmax(dim=1)
            n_seen += labels.size(0)
            n_correct += (preds == labels).sum().item()

            progress.set_postfix({'acc': f'{100.*n_correct/n_seen:.2f}%'})

    return 100. * n_correct / n_seen
|
|
|
|
def plot_training_curves(train_losses, val_accuracies, learning_rates):
    """Save training-curve figures to ./outputs.

    Writes two PNGs: a 2x2 panel (loss, val accuracy, LR schedule, and a
    dual-axis loss-vs-accuracy overlay) and a standalone accuracy plot.
    All three inputs are per-epoch lists and are assumed to be the same
    length — TODO confirm with caller.
    """
    epochs = np.arange(1, len(train_losses) + 1)

    plt.style.use('default')
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Food101 ResNet50 Training Analysis', fontsize=16, fontweight='bold')

    # Panel (0,0): training loss on a log scale.
    axes[0, 0].plot(epochs, train_losses, 'b-', linewidth=2, alpha=0.8)
    axes[0, 0].set_title('Training Loss Over Time', fontweight='bold')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].set_yscale('log')

    # Panel (0,1): validation accuracy, with a dashed line marking the peak.
    axes[0, 1].plot(epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    axes[0, 1].set_title('Validation Accuracy Over Time', fontweight='bold')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].axhline(y=max(val_accuracies), color='r', linestyle='--', alpha=0.7,
                       label=f'Best: {max(val_accuracies):.2f}%')
    axes[0, 1].legend()

    # Panel (1,0): learning-rate schedule on a log scale.
    axes[1, 0].plot(epochs, learning_rates, 'g-', linewidth=2, alpha=0.8)
    axes[1, 0].set_title('Learning Rate Schedule', fontweight='bold')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].set_yscale('log')

    # Panel (1,1): loss and accuracy overlaid via twin y-axes.
    ax_combined = axes[1, 1]
    ax_combined.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2, alpha=0.8)
    ax_combined.set_xlabel('Epoch')
    ax_combined.set_ylabel('Loss', color='b')
    ax_combined.tick_params(axis='y', labelcolor='b')
    ax_combined.set_yscale('log')

    ax2 = ax_combined.twinx()
    ax2.plot(epochs, val_accuracies, 'r-', label='Val Accuracy', linewidth=2, alpha=0.8)
    ax2.set_ylabel('Accuracy (%)', color='r')
    ax2.tick_params(axis='y', labelcolor='r')

    ax_combined.set_title('Loss vs Accuracy', fontweight='bold')
    ax_combined.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('./outputs/training_analysis.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Second figure: detailed standalone accuracy plot with filled area.
    plt.figure(figsize=(12, 6))
    plt.plot(epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    plt.fill_between(epochs, val_accuracies, alpha=0.3)
    plt.title('Validation Accuracy Progress', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=max(val_accuracies), color='r', linestyle='--', alpha=0.7,
                label=f'Peak Accuracy: {max(val_accuracies):.2f}%')
    plt.legend()
    plt.tight_layout()
    plt.savefig('./outputs/accuracy_detail.png', dpi=300, bbox_inches='tight')
    plt.close()

    logger.info("π Saved enhanced training visualizations")
|
|
|
|
def save_classes(classes):
    """Write the Food101 class list to ./outputs in two formats.

    Produces a human-readable numbered list and a plain one-name-per-line
    file; both are sorted alphabetically.
    """
    os.makedirs('./outputs', exist_ok=True)

    ordered = sorted(classes)

    # Formatted, numbered listing with display-friendly names.
    with open('./outputs/food101_classes.txt', 'w') as f:
        f.write("Food101 Classes (101 total)\n")
        f.write("=" * 30 + "\n\n")
        for idx, name in enumerate(ordered, 1):
            f.write(f"{idx:3d}. {name.replace('_', ' ').title()}\n")

    # Machine-friendly plain listing, one raw class name per line.
    with open('./outputs/food101_classes_simple.txt', 'w') as f:
        f.writelines(f"{name}\n" for name in ordered)

    logger.info("π Saved class names to ./outputs/")
|
|
|
|
def print_system_info():
    """Log PyTorch/CUDA/hardware details to aid debugging."""
    logger.info("π₯οΈ System Information:")
    logger.info(f"PyTorch version: {torch.__version__}")
    cuda_ok = torch.cuda.is_available()
    logger.info(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU: {torch.cuda.get_device_name()}")
        gpu_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"GPU memory: {gpu_gb:.1f} GB")
    logger.info(f"Number of CPU cores: {os.cpu_count()}")
|
|
|
|
| |
| |
| |
def main():
    """Entry point: load data, build the model, and run the training pipeline."""
    print_system_info()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        # Data.
        logger.info("π₯ Loading Food101 dataset...")
        train_loader, val_loader, test_loader, classes = get_food101_loaders(batch_size=64, num_workers=8)
        save_classes(classes)

        # Model.
        logger.info("ποΈ Building ResNet50...")
        model = ResNet50(num_classes=101).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params/1e6:.1f}M")
        logger.info(f"Trainable parameters: {trainable_params/1e6:.1f}M")

        # torch.compile (PyTorch 2.x) when the runtime provides it.
        if hasattr(torch, 'compile'):
            logger.info("π Compiling model for faster training...")
            model = torch.compile(model)

        # Training (auto-resumes from the best checkpoint when present).
        best_ckpt = './outputs/food101_resnet50_best.pth'
        best_val_acc, losses, accuracies = train_model(
            model, train_loader, val_loader, test_loader, device,
            num_epochs=100, resume_from=best_ckpt if os.path.exists(best_ckpt) else None
        )

        # Summary.
        logger.info(f"\nπ TRAINING COMPLETE!")
        logger.info(f"π Best Validation Accuracy: {best_val_acc:.2f}%")
        logger.info(f"\nπ SAVED FILES:")
        for artifact in (
            " β’ ./outputs/food101_resnet50_best.pth (best checkpoint)",
            " β’ ./outputs/food101_resnet50_best_weights.pth (best weights only)",
            " β’ ./outputs/food101_resnet50_final.pth (final checkpoint)",
            " β’ ./outputs/food101_resnet50_final_weights.pth (final weights only)",
            " β’ ./outputs/training_analysis.png (comprehensive plots)",
            " β’ ./outputs/accuracy_detail.png (detailed accuracy)",
            " β’ ./outputs/food101_classes.txt (formatted class list)",
            " β’ ./outputs/food101_classes_simple.txt (simple class list)",
        ):
            logger.info(artifact)

    except Exception as e:
        logger.error(f"β Training failed with error: {e}")
        raise
|
|
|
|
# Run the full pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()