| | from utils import CustomDataset, transform, preproc, Convert_ONNX |
| | from torch.utils.data import Dataset, DataLoader |
| | import torch |
| | import numpy as np |
| | from resnet_model import ResidualBlock, ResNet |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import tqdm |
| | import torch.nn.functional as F |
| | from torch.optim.lr_scheduler import ReduceLROnPlateau |
| | import pickle |
| | import sys |
| |
|
# Deterministic run selection: the first CLI argument indexes into a fixed
# pool of RNG seeds (e.g. `python train.py 2` -> seed 7109).
SEED_POOL = [1, 42, 7109, 2002, 32]
ind = int(sys.argv[1])
seed = SEED_POOL[ind]
print("using seed: ", seed)
torch.manual_seed(seed)
| |
|
| |
|
# Compute-device selection: prefer the first CUDA GPU, fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Report the number of visible GPUs (all of them are used via DataParallel).
num_gpus = torch.cuda.device_count()
print(num_gpus)
| |
|
| | |
# --- Data pipeline ------------------------------------------------------
# Training and validation splits live in fixed, pre-processed directories;
# both use the same `transform` imported from utils.
data_dir = '/mnt/buf1/pma/frbnn/train_ready'
valid_data_dir = '/mnt/buf1/pma/frbnn/valid_ready'
dataset = CustomDataset(data_dir, transform=transform)
valid_dataset = CustomDataset(valid_data_dir, transform=transform)

# Binary classification task.
num_classes = 2

# NOTE(review): shuffle=True on the validation loader is unusual but kept for
# behavioural parity; the accuracy/loss computed below are order-independent.
trainloader = DataLoader(dataset, batch_size=420, shuffle=True, num_workers=32)
validloader = DataLoader(valid_dataset, batch_size=512, shuffle=True, num_workers=32)
| |
|
# --- Model --------------------------------------------------------------
# ResNet-34-style stack ([3, 4, 6, 3] residual blocks) over 24 input
# channels, wrapped in DataParallel so every visible GPU is used.
model = ResNet(24, ResidualBlock, [3, 4, 6, 3], num_classes=num_classes).to(device)
model = nn.DataParallel(model).to(device)

params = sum(p.numel() for p in model.parameters())
print("num params ", params)

# Checkpoint smoke test: prove a save/load round trip works before any
# training time is spent.
torch.save(model.state_dict(), 'models/test.pt')
model.load_state_dict(torch.load('models/test.pt'))
| |
|
# Export both networks to ONNX before training starts.
# NOTE(review): the mock input shapes (1, 24, 192, 256) for the classifier
# and (32, 192, 2048) for the preprocessor presumably mirror the deployed
# pipeline — confirm against utils.preproc / utils.Convert_ONNX.
preproc_model = preproc()
Convert_ONNX(model.module, 'models/test.onnx',
             input_data_mock=torch.randn(1, 24, 192, 256).to(device))
Convert_ONNX(preproc_model, 'models/preproc.onnx',
             input_data_mock=torch.randn(32, 192, 2048).to(device))
| |
|
| | |
| |
|
# --- Loss / optimiser / LR schedule -------------------------------------
# BUG FIX: the class-weight tensor must be floating point. The original
# `torch.tensor([1, 1])` is int64, and nll_loss rejects integer weights
# ("expected scalar type Float") on the first loss computation.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(device))
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Halve the learning rate whenever the monitored validation loss fails to
# improve for 10 consecutive epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
| |
|
| |
|
| | from tqdm import tqdm |
| | |
# ========================================================================
# Training loop: up to `epochs` passes over the data, early-stopped once
# the learning rate has decayed below 1e-6. Each epoch runs a full
# training pass, a full validation pass, an LR-schedule step, a checkpoint
# + ONNX export, and a metric report.
# ========================================================================
epochs = 10000
for epoch in range(epochs):
    # ---- Training pass ----
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    model.train()
    with tqdm(trainloader, unit="batch") as tepoch:
        for images, labels in tepoch:  # unused enumerate index removed
            inputs, labels = images.to(device), labels.to(device).float()
            optimizer.zero_grad()
            # FIX: dropped the redundant `.to(device)` on the output —
            # the (DataParallel) model already returns a device tensor.
            outputs = model(inputs, return_mask=False)
            # CrossEntropyLoss accepts class-probability targets
            # (torch >= 1.10); labels are one-hot encoded to that form.
            new_label = F.one_hot(labels.type(torch.int64), num_classes=2).type(torch.float32).to(device)
            loss = criterion(outputs, new_label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

    # ---- Validation pass ----
    val_loss = 0.0
    correct_valid = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in validloader:
            inputs, labels = images.to(device), labels.to(device).float()
            # FIX: removed the optimizer.zero_grad() that sat here — no
            # gradients are ever produced under no_grad(), so it was dead
            # (and misleading) code.
            outputs = model(inputs, return_mask=False)
            new_label = F.one_hot(labels.type(torch.int64), num_classes=2).type(torch.float32)
            val_loss += criterion(outputs, new_label).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct_valid += (predicted == labels).sum().item()

    # Step the plateau scheduler on the MEAN validation loss. Under the
    # default threshold_mode='rel' this is equivalent to stepping on the
    # sum (relative comparisons are scale-invariant), but the mean does
    # not change meaning if the validation set is resized.
    scheduler.step(val_loss / max(len(validloader), 1))

    train_accuracy = 100 * correct_train / total_train
    val_accuracy = correct_valid / total * 100.0

    # Checkpoint every epoch, tagging filenames with epoch index and
    # validation accuracy; export ONNX with the last validation batch as
    # the mock input.
    torch.save(model.state_dict(), 'models/model-' + str(epoch) + '-' + str(val_accuracy) + '.pt')
    Convert_ONNX(model.module, 'models/model-' + str(epoch) + '-' + str(val_accuracy) + '.onnx', input_data_mock=inputs)

    # FIX: ReduceLROnPlateau did not expose get_last_lr() in older torch
    # releases; read the live learning rate from the optimizer instead.
    # Printed as a one-element list to keep the original output shape.
    current_lr = optimizer.param_groups[0]['lr']
    print("===========================")
    print('accuracy: ', epoch, train_accuracy, val_accuracy)
    print('learning rate: ', [current_lr])
    print("===========================")
    # Early stop once the schedule has effectively bottomed out.
    if current_lr < 1e-6:
        break
| |
|