Spaces:
Sleeping
Sleeping
| import os | |
| import argparse | |
| import time | |
| import torch.utils.model_zoo as model_zoo | |
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Variable | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| import torch.backends.cudnn as cudnn | |
| import torchvision | |
| from l2cs import L2CS, select_device, Gaze360, Mpiigaze | |
def parse_args():
    """Build the CLI parser and return the parsed training arguments."""
    parser = argparse.ArgumentParser(description='Gaze estimation using L2CSNet.')

    # Gaze360 dataset locations.
    parser.add_argument('--gaze360image_dir', dest='gaze360image_dir',
                        type=str, default='datasets/Gaze360/Image',
                        help='Directory path for gaze images.')
    parser.add_argument('--gaze360label_dir', dest='gaze360label_dir',
                        type=str, default='datasets/Gaze360/Label/train.label',
                        help='Directory path for gaze labels.')

    # MPIIFaceGaze dataset locations.
    parser.add_argument('--gazeMpiimage_dir', dest='gazeMpiimage_dir',
                        type=str, default='datasets/MPIIFaceGaze/Image',
                        help='Directory path for gaze images.')
    parser.add_argument('--gazeMpiilabel_dir', dest='gazeMpiilabel_dir',
                        type=str, default='datasets/MPIIFaceGaze/Label',
                        help='Directory path for gaze labels.')

    # Core training options.
    parser.add_argument('--dataset', dest='dataset',
                        type=str, default="gaze360",
                        help='mpiigaze, rtgene, gaze360, ethgaze')
    parser.add_argument('--output', dest='output',
                        type=str, default='output/snapshots/',
                        help='Path of output models.')
    parser.add_argument('--snapshot', dest='snapshot',
                        type=str, default='',
                        help='Path of model snapshot.')
    parser.add_argument('--gpu', dest='gpu_id',
                        type=str, default='0',
                        help='GPU device id to use [0] or multiple 0,1,2,3')
    parser.add_argument('--num_epochs', dest='num_epochs',
                        type=int, default=60,
                        help='Maximum number of training epochs.')
    parser.add_argument('--batch_size', dest='batch_size',
                        type=int, default=1,
                        help='Batch size.')
    parser.add_argument('--arch', dest='arch',
                        type=str, default='ResNet50',
                        help='Network architecture, can be: ResNet18, ResNet34, '
                             '[ResNet50], ResNet101, ResNet152, Squeezenet_1_0, '
                             'Squeezenet_1_1, MobileNetV2')
    parser.add_argument('--alpha', dest='alpha',
                        type=float, default=1,
                        help='Regression loss coefficient.')
    parser.add_argument('--lr', dest='lr',
                        type=float, default=0.00001,
                        help='Base learning rate.')

    return parser.parse_args()
def get_ignored_params(model, arch=None):
    """Yield stem parameters (conv1, bn1, fc_finetune) to train with lr=0.

    BatchNorm submodules (any nested module whose name contains 'bn') are
    switched to eval mode so their running statistics are frozen.

    ``arch`` is accepted and ignored: the mpiigaze branch of the training
    script calls this helper as ``get_ignored_params(model, args.arch)``,
    which would raise TypeError with a single-parameter signature.
    """
    # The mpiigaze branch wraps the model in nn.DataParallel; unwrap it so
    # direct attribute access (model.conv1, ...) still works.
    if isinstance(model, nn.DataParallel):
        model = model.module
    stems = [model.conv1, model.bn1, model.fc_finetune]
    for stem in stems:
        for module_name, module in stem.named_modules():
            if 'bn' in module_name:
                module.eval()
            for _, param in module.named_parameters():
                yield param
def get_non_ignored_params(model, arch=None):
    """Yield backbone parameters (layer1..layer4) to optimize at the base lr.

    Nested BatchNorm modules (name containing 'bn') are put in eval mode so
    their running statistics are frozen during fine-tuning.

    ``arch`` is accepted and ignored: the mpiigaze branch of the training
    script passes ``args.arch`` as a second argument, which would raise
    TypeError with a single-parameter signature.
    """
    # The mpiigaze branch wraps the model in nn.DataParallel; unwrap it so
    # direct attribute access (model.layer1, ...) still works.
    if isinstance(model, nn.DataParallel):
        model = model.module
    stages = [model.layer1, model.layer2, model.layer3, model.layer4]
    for stage in stages:
        for module_name, module in stage.named_modules():
            if 'bn' in module_name:
                module.eval()
            for _, param in module.named_parameters():
                yield param
def get_fc_params(model, arch=None):
    """Yield the gaze classification head parameters (fc_yaw_gaze, fc_pitch_gaze).

    ``arch`` is accepted and ignored: the mpiigaze branch of the training
    script calls this helper as ``get_fc_params(model, args.arch)``, which
    would raise TypeError with a single-parameter signature.
    """
    # The mpiigaze branch wraps the model in nn.DataParallel; unwrap it so
    # direct attribute access (model.fc_yaw_gaze, ...) still works.
    if isinstance(model, nn.DataParallel):
        model = model.module
    for head in (model.fc_yaw_gaze, model.fc_pitch_gaze):
        for module_name, module in head.named_modules():
            for _, param in module.named_parameters():
                yield param
def load_filtered_state_dict(model, snapshot):
    """Load into ``model`` only the ``snapshot`` entries whose keys it has.

    Keys present in the snapshot but absent from the model (and vice versa)
    are left untouched, so a partially matching checkpoint loads cleanly.
    """
    # By user apaszke from discuss.pytorch.org
    own_state = model.state_dict()
    for key, value in snapshot.items():
        if key in own_state:
            own_state[key] = value
    model.load_state_dict(own_state)
def getArch_weights(arch, bins):
    """Construct an L2CS model for ``arch`` with ``bins`` output bins.

    Returns a (model, url) pair where ``url`` points at the matching
    torchvision ImageNet-pretrained checkpoint. Any unrecognised ``arch``
    falls back to ResNet50 (with a warning printed).
    """
    # (residual block type, stage depths, pretrained checkpoint URL)
    table = {
        'ResNet18': (torchvision.models.resnet.BasicBlock, [2, 2, 2, 2],
                     'https://download.pytorch.org/models/resnet18-5c106cde.pth'),
        'ResNet34': (torchvision.models.resnet.BasicBlock, [3, 4, 6, 3],
                     'https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
        'ResNet50': (torchvision.models.resnet.Bottleneck, [3, 4, 6, 3],
                     'https://download.pytorch.org/models/resnet50-19c8e357.pth'),
        'ResNet101': (torchvision.models.resnet.Bottleneck, [3, 4, 23, 3],
                      'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
        'ResNet152': (torchvision.models.resnet.Bottleneck, [3, 8, 36, 3],
                      'https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
    }
    if arch not in table:
        print('Invalid value for architecture is passed! '
              'The default value of ResNet50 will be used instead!')
        arch = 'ResNet50'
    block, layers, pre_url = table[arch]
    model = L2CS(block, layers, bins)
    return model, pre_url
if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    # select_device resolves '--gpu' ('0' or '0,1,2,3') to a torch device.
    gpu = select_device(args.gpu_id, batch_size=args.batch_size)
    data_set=args.dataset
    alpha = args.alpha
    output=args.output

    # Standard ImageNet preprocessing at the 448x448 input size the model expects.
    transformations = transforms.Compose([
        transforms.Resize(448),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    if data_set=="gaze360":
        # Gaze360: 90 angle bins per axis (bin width 4 deg, range -180..180,
        # matching the Gaze360(...) constructor arguments and the
        # `* 4 - 180` expectation decoding below).
        model, pre_url = getArch_weights(args.arch, 90)
        if args.snapshot == '':
            # Fresh run: warm-start from ImageNet weights, keeping only the
            # checkpoint keys the L2CS model actually has.
            load_filtered_state_dict(model, model_zoo.load_url(pre_url))
        else:
            # Resume from an explicit snapshot path.
            saved_state_dict = torch.load(args.snapshot)
            model.load_state_dict(saved_state_dict)
        model.cuda(gpu)
        dataset=Gaze360(args.gaze360label_dir, args.gaze360image_dir, transformations, 180, 4)
        print('Loading data.')
        train_loader_gaze = DataLoader(
            dataset=dataset,
            batch_size=int(batch_size),
            shuffle=True,
            num_workers=0,
            pin_memory=True)
        torch.backends.cudnn.benchmark = True

        # Time-stamped run directory under --output for the epoch snapshots.
        summary_name = '{}_{}'.format('L2CS-gaze360-', int(time.time()))
        output=os.path.join(output, summary_name)
        if not os.path.exists(output):
            os.makedirs(output)

        criterion = nn.CrossEntropyLoss().cuda(gpu)
        reg_criterion = nn.MSELoss().cuda(gpu)
        softmax = nn.Softmax(dim=1).cuda(gpu)
        # Bin indices 0..89 used to take the expectation of the softmax output.
        idx_tensor = [idx for idx in range(90)]
        idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)

        # Optimizer gaze: three parameter groups — frozen stem (lr=0),
        # backbone, and the two fc heads (both at the base lr).
        optimizer_gaze = torch.optim.Adam([
            {'params': get_ignored_params(model), 'lr': 0},
            {'params': get_non_ignored_params(model), 'lr': args.lr},
            {'params': get_fc_params(model), 'lr': args.lr}
        ], args.lr)

        # NOTE(review): the banner says "Start testing" but this branch trains.
        configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\nStart testing dataset={data_set}, loader={len(train_loader_gaze)}------------------------- \n"
        print(configuration)

        for epoch in range(num_epochs):
            sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0

            for i, (images_gaze, labels_gaze, cont_labels_gaze,name) in enumerate(train_loader_gaze):
                images_gaze = Variable(images_gaze).cuda(gpu)

                # Binned labels (class indices for the cross-entropy term).
                label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
                label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)

                # Continuous labels (degrees, for the regression term).
                label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
                label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)

                pitch, yaw = model(images_gaze)

                # Cross entropy loss over the 90 bins.
                loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
                loss_yaw_gaze = criterion(yaw, label_yaw_gaze)

                # MSE loss: softmax expectation over bin indices mapped back
                # to degrees (bin width 4, offset -180).
                pitch_predicted = softmax(pitch)
                yaw_predicted = softmax(yaw)
                pitch_predicted = \
                    torch.sum(pitch_predicted * idx_tensor, 1) * 4 - 180
                yaw_predicted = \
                    torch.sum(yaw_predicted * idx_tensor, 1) * 4 - 180

                loss_reg_pitch = reg_criterion(
                    pitch_predicted, label_pitch_cont_gaze)
                loss_reg_yaw = reg_criterion(
                    yaw_predicted, label_yaw_cont_gaze)

                # Total loss = classification + alpha * regression.
                loss_pitch_gaze += alpha * loss_reg_pitch
                loss_yaw_gaze += alpha * loss_reg_yaw

                sum_loss_pitch_gaze += loss_pitch_gaze
                sum_loss_yaw_gaze += loss_yaw_gaze

                # Backprop both axis losses in one call (unit grad for each).
                loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
                grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
                optimizer_gaze.zero_grad(set_to_none=True)
                torch.autograd.backward(loss_seq, grad_seq)
                optimizer_gaze.step()
                # scheduler.step()

                iter_gaze += 1
                if (i+1) % 100 == 0:
                    # Running-average losses since the start of the epoch.
                    print('Epoch [%d/%d], Iter [%d/%d] Losses: '
                          'Gaze Yaw %.4f,Gaze Pitch %.4f' % (
                              epoch+1,
                              num_epochs,
                              i+1,
                              len(dataset)//batch_size,
                              sum_loss_pitch_gaze/iter_gaze,
                              sum_loss_yaw_gaze/iter_gaze
                          )
                          )

            # Snapshot after every epoch (the `% 1` makes the guard always true).
            if epoch % 1 == 0 and epoch < num_epochs:
                print('Taking snapshot...',
                      torch.save(model.state_dict(),
                                 output +'/'+
                                 '_epoch_' + str(epoch+1) + '.pkl')
                      )

    elif data_set=="mpiigaze":
        # MPIIFaceGaze: leave-one-person-out training over 15 folds.
        folder = os.listdir(args.gazeMpiilabel_dir)
        folder.sort()
        testlabelpathombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]

        for fold in range(15):
            # 28 angle bins per axis (bin width 3 deg, range -42..42,
            # matching the `* 3 - 42` expectation decoding below).
            model, pre_url = getArch_weights(args.arch, 28)
            load_filtered_state_dict(model, model_zoo.load_url(pre_url))
            model = nn.DataParallel(model)
            model.to(gpu)
            print('Loading data.')
            dataset=Mpiigaze(testlabelpathombined,args.gazeMpiimage_dir, transformations, True, fold)
            train_loader_gaze = DataLoader(
                dataset=dataset,
                batch_size=int(batch_size),
                shuffle=True,
                num_workers=4,
                pin_memory=True)
            torch.backends.cudnn.benchmark = True

            summary_name = '{}_{}'.format('L2CS-mpiigaze', int(time.time()))
            # NOTE(review): the directory created here is
            # output/{summary_name}/fold{fold}, but the torch.save call at the
            # bottom writes to output/fold{fold} — the two paths disagree;
            # confirm which location is intended.
            if not os.path.exists(os.path.join(output+'/{}'.format(summary_name),'fold' + str(fold))):
                os.makedirs(os.path.join(output+'/{}'.format(summary_name),'fold' + str(fold)))

            criterion = nn.CrossEntropyLoss().cuda(gpu)
            reg_criterion = nn.MSELoss().cuda(gpu)
            softmax = nn.Softmax(dim=1).cuda(gpu)
            idx_tensor = [idx for idx in range(28)]
            idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)

            # Optimizer gaze.
            # NOTE(review): these helpers are called with a second argument
            # (args.arch) here, while the definitions above take a single
            # parameter, and `model` is a DataParallel wrapper without .conv1
            # etc. — both would fail at runtime unless the helpers accept an
            # arch argument and unwrap DataParallel; verify.
            optimizer_gaze = torch.optim.Adam([
                {'params': get_ignored_params(model, args.arch), 'lr': 0},
                {'params': get_non_ignored_params(model, args.arch), 'lr': args.lr},
                {'params': get_fc_params(model, args.arch), 'lr': args.lr}
            ], args.lr)

            configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\n Start training dataset={data_set}, loader={len(train_loader_gaze)}, fold={fold}--------------\n"
            print(configuration)

            for epoch in range(num_epochs):
                sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0

                for i, (images_gaze, labels_gaze, cont_labels_gaze,name) in enumerate(train_loader_gaze):
                    images_gaze = Variable(images_gaze).cuda(gpu)

                    # Binned labels (class indices for the cross-entropy term).
                    label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
                    label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)

                    # Continuous labels (degrees, for the regression term).
                    label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
                    label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)

                    pitch, yaw = model(images_gaze)

                    # Cross entropy loss over the 28 bins.
                    loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
                    loss_yaw_gaze = criterion(yaw, label_yaw_gaze)

                    # MSE loss: softmax expectation over bin indices mapped
                    # back to degrees (bin width 3, offset -42).
                    pitch_predicted = softmax(pitch)
                    yaw_predicted = softmax(yaw)
                    pitch_predicted = \
                        torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 42
                    yaw_predicted = \
                        torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 42

                    loss_reg_pitch = reg_criterion(
                        pitch_predicted, label_pitch_cont_gaze)
                    loss_reg_yaw = reg_criterion(
                        yaw_predicted, label_yaw_cont_gaze)

                    # Total loss = classification + alpha * regression.
                    loss_pitch_gaze += alpha * loss_reg_pitch
                    loss_yaw_gaze += alpha * loss_reg_yaw

                    sum_loss_pitch_gaze += loss_pitch_gaze
                    sum_loss_yaw_gaze += loss_yaw_gaze

                    # Backprop both axis losses in one call (unit grad each).
                    loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
                    grad_seq = \
                        [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
                    optimizer_gaze.zero_grad(set_to_none=True)
                    torch.autograd.backward(loss_seq, grad_seq)
                    optimizer_gaze.step()

                    iter_gaze += 1
                    if (i+1) % 100 == 0:
                        # Running-average losses since the start of the epoch.
                        print('Epoch [%d/%d], Iter [%d/%d] Losses: '
                              'Gaze Yaw %.4f,Gaze Pitch %.4f' % (
                                  epoch+1,
                                  num_epochs,
                                  i+1,
                                  len(dataset)//batch_size,
                                  sum_loss_pitch_gaze/iter_gaze,
                                  sum_loss_yaw_gaze/iter_gaze
                              )
                              )

                # Save models at numbered epochs.
                if epoch % 1 == 0 and epoch < num_epochs:
                    print('Taking snapshot...',
                          torch.save(model.state_dict(),
                                     output+'/fold' + str(fold) +'/'+
                                     '_epoch_' + str(epoch+1) + '.pkl')
                          )