| | import os |
| | import json |
| | import time |
| | import argparse |
| | import pathlib |
| | from tqdm import tqdm |
| | import matplotlib.pyplot as plt |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torchvision import datasets |
| | from torch.utils.data import DataLoader |
| | import torchvision.transforms as transforms |
| | from torch.optim.lr_scheduler import _LRScheduler |
| | import traceback |
| |
|
| | CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343) |
| | CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404) |
| | MILESTONES = [60, 120, 160] |
| |
|
| |
|
| | class WideBasicBlock(nn.Module): |
| | def __init__(self, in_planes, out_planes, dropout_rate, stride=1): |
| | super(WideBasicBlock, self).__init__() |
| | self.bn1 = nn.BatchNorm2d(in_planes) |
| | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) |
| | self.dropout = nn.Dropout(p=dropout_rate) |
| | self.bn2 = nn.BatchNorm2d(out_planes) |
| | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) |
| | self.relu = nn.ReLU(inplace=True) |
| |
|
| | if in_planes != out_planes: |
| | self.shortcut = nn.Conv2d( |
| | in_planes, |
| | out_planes, |
| | kernel_size=1, |
| | stride=stride, |
| | padding=0, |
| | bias=False, |
| | ) |
| | else: |
| | self.shortcut = nn.Identity() |
| |
|
| | def forward(self, x): |
| | out = self.relu(self.bn1(x)) |
| | skip_x = x if isinstance(self.shortcut, nn.Identity) else out |
| |
|
| | out = self.conv1(out) |
| | out = self.relu(self.bn2(out)) |
| | out = self.dropout(out) |
| | out = self.conv2(out) |
| | out += self.shortcut(skip_x) |
| |
|
| | return out |
| |
|
| |
|
| | class WideResNet(nn.Module): |
| | def __init__(self, depth, widen_factor, num_classes, dropout_rate): |
| | super(WideResNet, self).__init__() |
| |
|
| | assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4" |
| | n = (depth - 4) / 6 |
| |
|
| | n_stages = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] |
| |
|
| | self.conv1 = nn.Conv2d(3, n_stages[0], kernel_size=3, stride=1, padding=1, bias=False) |
| | self.stage1 = self._make_wide_stage(WideBasicBlock, n_stages[0], n_stages[1], n, dropout_rate, stride=1) |
| | self.stage2 = self._make_wide_stage(WideBasicBlock, n_stages[1], n_stages[2], n, dropout_rate, stride=2) |
| | self.stage3 = self._make_wide_stage(WideBasicBlock, n_stages[2], n_stages[3], n, dropout_rate, stride=2) |
| | self.bn1 = nn.BatchNorm2d(n_stages[3]) |
| | self.relu = nn.ReLU(inplace=True) |
| | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) |
| | self.linear = nn.Linear(n_stages[3], num_classes) |
| |
|
| | self._init_params() |
| |
|
| | @staticmethod |
| | def _make_wide_stage(block, in_planes, out_planes, num_blocks, dropout_rate, stride): |
| | stride_list = [stride] + [1] * (int(num_blocks) - 1) |
| | in_planes_list = [in_planes] + [out_planes] * (int(num_blocks) - 1) |
| | blocks = [] |
| |
|
| | for _in_planes, _stride in zip(in_planes_list, stride_list): |
| | blocks.append(block(_in_planes, out_planes, dropout_rate, _stride)) |
| |
|
| | return nn.Sequential(*blocks) |
| |
|
| | def _init_params(self): |
| | for m in self.modules(): |
| | if isinstance(m, nn.Conv2d): |
| | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") |
| | elif isinstance(m, nn.BatchNorm2d): |
| | if m.affine: |
| | m.weight.data.fill_(1) |
| | m.bias.data.zero_() |
| | elif isinstance(m, nn.Linear): |
| | if m.bias is not None: |
| | m.bias.data.zero_() |
| |
|
| | def forward(self, x): |
| | out = self.conv1(x) |
| | out = self.stage1(out) |
| | out = self.stage2(out) |
| | out = self.stage3(out) |
| | out = self.relu(self.bn1(out)) |
| | out = self.avg_pool(out) |
| | out = out.view(out.size(0), -1) |
| | out = self.linear(out) |
| |
|
| | return out |
| |
|
| |
|
| | def wide_resnet_28_10_old(): |
| | return WideResNet( |
| | depth=28, |
| | widen_factor=10, |
| | num_classes=100, |
| | dropout_rate=0.0, |
| | ) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--batch_size", type=int, default=128) |
| | parser.add_argument("--num_workers", type=int, default=4) |
| | parser.add_argument("--out_dir", type=str, default="run_1") |
| | parser.add_argument("--in_channels", type=int, default=3) |
| | parser.add_argument("--data_root", type=str, default='./datasets/cifar100/') |
| | parser.add_argument("--learning_rate", type=float, default=0.1) |
| | parser.add_argument("", type=int, default=200) |
| | parser.add_argument("--val_per_epoch", type=int, default=5) |
| | config = parser.parse_args() |
| |
|
| |
|
| | try: |
| | final_infos = {} |
| | all_results = {} |
| |
|
| | pathlib.Path(config.out_dir).mkdir(parents=True, exist_ok=True) |
| |
|
| | model = wide_resnet_28_10_old().cuda() |
| | transform_train = transforms.Compose([ |
| | transforms.ToTensor(), |
| | transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), |
| | (4, 4, 4, 4), mode='reflect').squeeze()), |
| | transforms.ToPILImage(), |
| | transforms.RandomCrop(32), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ToTensor(), |
| | transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD), |
| | ]) |
| |
|
| | transform_test = transforms.Compose([ |
| | transforms.ToTensor(), |
| | transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD) |
| | ]) |
| | train_dataset = datasets.CIFAR100(root=config.data_root, train=True, |
| | download=True, transform=transform_train) |
| | test_dataset = datasets.CIFAR100(root=config.data_root, train=False, |
| | download=True, transform=transform_test) |
| | train_loader = DataLoader(train_dataset, shuffle=True, num_workers=config.num_workers, batch_size=config.batch_size) |
| | test_loader = DataLoader(test_dataset, shuffle=True, num_workers=config.num_workers, batch_size=config.batch_size) |
| |
|
| | criterion = nn.CrossEntropyLoss().cuda() |
| | optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate, momentum=0.9, weight_decay=5e-4, |
| | nesterov=True) |
| | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader) * config.max_epoch) |
| |
|
| | best_acc = 0.0 |
| | start_time = time.time() |
| | for cur_epoch in tqdm(range(1, config.max_epoch + 1)): |
| | model.train() |
| | for batch_idx, (images, labels) in enumerate(tqdm(train_loader)): |
| | images, labels = images.cuda(), labels.cuda() |
| | optimizer.zero_grad() |
| | outputs = model(images) |
| | loss = criterion(outputs, labels) |
| | loss.backward() |
| | optimizer.step() |
| | scheduler.step() |
| |
|
| | print(f'Finished epoch {cur_epoch} training.') |
| |
|
| | if (cur_epoch % config.val_per_epoch == 0 and cur_epoch != 0) or cur_epoch == (config.max_epoch - 1): |
| | model.eval() |
| | correct = 0.0 |
| | for images, labels in tqdm(test_loader): |
| | images, labels = images.cuda(), labels.cuda() |
| | with torch.no_grad(): |
| | outputs = model(images) |
| |
|
| | _, preds = outputs.max(1) |
| | correct += preds.eq(labels).sum() |
| | cur_acc = correct.float() / len(test_loader.dataset) |
| | print(f"Epoch: {cur_epoch}, Accuracy: {correct.float() / len(test_loader.dataset)}") |
| |
|
| | if cur_acc > best_acc: |
| | best_acc = cur_acc |
| | best_epoch = cur_epoch |
| | torch.save(model.state_dict(), os.path.join(config.out_dir, 'best.pth')) |
| |
|
| | final_infos = { |
| | "cifar100": { |
| | "means": { |
| | "best_acc": best_acc.item(), |
| | "epoch": best_epoch |
| | } |
| | } |
| | } |
| |
|
| | with open(os.path.join(config.out_dir, "final_info.json"), "w") as f: |
| | json.dump(final_infos, f) |
| |
|
| | except Exception as e: |
| | print("Original error in subprocess:", flush=True) |
| | traceback.print_exc(file=open(os.path.join(config.out_dir, "traceback.log"), "w")) |
| | raise |