Spaces:
Sleeping
Sleeping
File size: 5,079 Bytes
2eba0cc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | import sys
import os
import math
from math import cos, sin
from pathlib import Path
import subprocess
import re
import numpy as np
import torch
import torch.nn as nn
import scipy.io as sio
import cv2
import torchvision
from torchvision import transforms
from .model import L2CS
# Preprocessing pipeline applied to every image before it is fed to L2CS-Net:
# array -> PIL image -> Resize(448) (a single int rescales the shorter side)
# -> float tensor -> per-channel normalization. The mean/std values appear to
# be the standard ImageNet channel statistics used with torchvision ResNets.
transformations = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(448),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def atoi(text):
    """Return ``int(text)`` if *text* is a run of decimal digits, else *text*.

    Helper for ``natural_keys``: numeric chunks become ints so they compare
    numerically, while other chunks are returned unchanged.

    ``str.isdecimal`` is used rather than ``str.isdigit``. ``isdigit`` also
    accepts characters such as superscripts (``'²'``) that ``int()`` rejects,
    and those made the previous version raise ``ValueError``.
    """
    return int(text) if text.isdecimal() else text
def natural_keys(text):
    """Build a key that sorts strings in "human" (natural) order.

    Usage: ``alist.sort(key=natural_keys)``, which puts ``'img2'`` before
    ``'img10'``. The string is split into alternating non-digit / digit
    chunks (the capturing group keeps the digit runs), and every chunk is
    passed through ``atoi``.

    Reference: http://nedbatchelder.com/blog/200712/human_sorting.html
    (see Toothy's implementation in the comments).
    """
    pieces = re.split(r'(\d+)', text)
    return list(map(atoi, pieces))
def prep_input_numpy(img: np.ndarray, device: str):
    """Prepare a NumPy image, or a stack of images, as input to L2CS-Net.

    Args:
        img: a single image, or a 4-D array that is iterated along its first
            axis with each element treated as one image. Every image goes
            through the module-level ``transformations`` pipeline.
        device: target passed to ``Tensor.to``.

    Returns:
        The transformed tensor on *device*. If the result is 3-D (a single
        image), a leading batch axis of size 1 is added.
    """
    if len(img.shape) == 4:
        batch = torch.stack([transformations(frame) for frame in img])
    else:
        batch = transformations(img)
    batch = batch.to(device)
    # A single transformed image has no batch axis yet.
    return batch.unsqueeze(0) if batch.dim() == 3 else batch
def gazeto3d(gaze):
    """Map a gaze angle pair to a 3-D direction vector (NumPy).

    With ``a = gaze[0]`` and ``b = gaze[1]``, the result is the length-3
    float64 array ``[-cos(b)*sin(a), -sin(b), -cos(b)*cos(a)]``, which has
    unit norm for real-valued angles.

    NOTE(review): presumably ``gaze`` holds (yaw, pitch) in radians; confirm
    against the callers.
    """
    a, b = gaze[0], gaze[1]
    cos_b = np.cos(b)
    vec = np.zeros([3])
    vec[0], vec[1], vec[2] = -cos_b * np.sin(a), -np.sin(b), -cos_b * np.cos(a)
    return vec
def angular(gaze, label):
    """Return the angle, in degrees, between the vectors *gaze* and *label*.

    The cosine is clamped to ``[-1.0, 0.9999999]`` before ``np.arccos``:
    - The upper bound is unchanged from the original, so parallel vectors
      still report a tiny non-zero angle of about 0.0256 degrees.
    - The lower bound is new. For (anti-)parallel vectors, rounding can push
      the computed cosine slightly below -1, and ``np.arccos`` then returned
      NaN.

    Zero-length vectors still give NaN, because of the division by zero.
    """
    cosine = np.sum(gaze * label) / (np.linalg.norm(gaze) * np.linalg.norm(label))
    cosine = min(max(cosine, -1.0), 0.9999999)
    return np.arccos(cosine) * 180 / np.pi
def select_device(device='', batch_size=None):
    """Choose the torch device to run on and restrict the visible CUDA devices.

    Args:
        device: ``'cpu'``, ``''`` (all visible GPUs), or a comma-separated
            list of CUDA indices such as ``'0'`` or ``'0,1,2,3'``. The input
            is also tolerated in these forms:
            - with surrounding whitespace or a ``'cuda:'`` prefix
              (``'cuda:0'``);
            - as a ``torch.device``, an int, or a bare ``'cuda'``;
            - as ``None``, which is treated like ``''``.
        batch_size: when more than one GPU is selected, it must be divisible
            by the GPU count.

    Returns:
        ``torch.device('cuda:0')`` if CUDA is available after the selection,
        otherwise ``torch.device('cpu')``.

    Side effects:
        Sets ``CUDA_VISIBLE_DEVICES`` to ``'-1'`` for CPU, or to the requested
        indices. NOTE(review): presumably this only takes effect if CUDA has
        not been initialised yet in this process.

    Raises:
        AssertionError: if more GPUs are requested than are visible, or if
            *batch_size* is not a multiple of the GPU count.

    The previous version built an unused banner string that called the
    undefined ``date_modified()`` outside a git checkout (``NameError``) and
    spawned a ``git`` subprocess on every call. That dead code is removed.
    """
    # Normalise the spec: accept None, torch.device objects, ints,
    # 'cuda'/'cuda:N' and stray whitespace.
    device = '' if device is None else str(device).strip().lower().replace('cuda:', '')
    if device == 'cuda':
        device = ''
    cpu = device == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # specific GPU indices requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device
    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        n = len(device.split(',')) if device else torch.cuda.device_count()
        assert n <= torch.cuda.device_count(), (
            f'requested CUDA devices {device!r} but only '
            f'{torch.cuda.device_count()} are visible'
        )
        if n > 1 and batch_size:
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
    return torch.device('cuda:0' if cuda else 'cpu')
def spherical2cartesial(x):
    """Convert a batch of angle pairs into 3-D direction vectors (torch).

    For each row ``(a, b) = (x[i, 0], x[i, 1])``, the output row is
    ``[cos(b)*sin(a), sin(b), -cos(b)*cos(a)]``. The result has shape
    ``(x.size(0), 3)``.

    NOTE(review): the x and y signs are the opposite of ``gazeto3d``. The
    angle between two vectors is unaffected by that reflection.
    NOTE(review): the result is allocated with ``torch.zeros`` defaults, so it
    is a CPU tensor of the default dtype regardless of ``x``'s device/dtype.
    """
    a, b = x[:, 0], x[:, 1]
    cos_b = torch.cos(b)
    out = torch.zeros(x.size(0), 3)
    out[:, 0] = cos_b * torch.sin(a)
    out[:, 1] = torch.sin(b)
    out[:, 2] = -cos_b * torch.cos(a)
    return out
def compute_angular_error(input, target):
    """Return the mean angular error, in degrees, between two batches of angle pairs.

    Both arguments are converted row-wise with ``spherical2cartesial``. The
    per-sample angle is ``acos`` of the dot product of the resulting unit
    vectors.

    The dot product is clamped to ``[-1, 1]`` before ``torch.acos``. For
    (nearly) identical or opposite directions, rounding can push it just
    outside that range, and a single NaN would otherwise poison the mean.

    Returns:
        A 0-dim tensor that does not track gradients.
    """
    pred_vec = spherical2cartesial(input)
    target_vec = spherical2cartesial(target)
    # Row-wise dot product of the two (N, 3) direction batches.
    cosine = (pred_vec * target_vec).sum(dim=1)
    angles = torch.acos(torch.clamp(cosine, -1.0, 1.0)).detach()
    return 180 * torch.mean(angles) / math.pi
def softmax_temperature(tensor, temperature):
    """Return the softmax of ``tensor / temperature`` along dimension 1.

    Lower temperatures sharpen the distribution; higher ones flatten it.

    This uses ``torch.softmax``, which is computed in a numerically stable
    way. The previous hand-written ``exp(x) / sum(exp(x))`` overflowed to
    ``inf / inf = nan`` for large logits or small temperatures.
    """
    return torch.softmax(tensor / temperature, dim=1)
def git_describe(path=Path(__file__).parent):  # path must be a directory
    """Return ``git describe --tags --long --always`` for *path*, or ``''``.

    Example output: ``v5.0-5-g3e25f1e`` (https://git-scm.com/docs/git-describe).

    The command is passed as an argument list (no shell), so directory names
    containing spaces or shell metacharacters reach git intact. An empty
    string is returned when any of the following holds:
    - *path* is not inside a git repository;
    - the command fails;
    - ``git`` is not installed or cannot be started.
    stderr is discarded, so git warnings never end up in the returned text.
    """
    cmd = ['git', '-C', str(path), 'describe', '--tags', '--long', '--always']
    try:
        out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        return ''  # not a git repository, or git unavailable
    return out.decode(errors='replace').strip()
def getArch(arch, bins):
    """Build an ``L2CS`` model on the requested ResNet backbone configuration.

    Args:
        arch: ``'ResNet18'``, ``'ResNet34'``, ``'ResNet50'``, ``'ResNet101'``
            or ``'ResNet152'``. Any other value prints a warning and falls
            back to the ResNet50 configuration.
        bins: forwarded unchanged as the third argument of ``L2CS``.

    Returns:
        The constructed ``L2CS`` model.
    """
    resnet = torchvision.models.resnet
    # (name, residual block type, blocks per stage). Names are tested with
    # ``==`` in this order; ResNet50 is the fallback handled below.
    variants = (
        ('ResNet18', resnet.BasicBlock, [2, 2, 2, 2]),
        ('ResNet34', resnet.BasicBlock, [3, 4, 6, 3]),
        ('ResNet101', resnet.Bottleneck, [3, 4, 23, 3]),
        ('ResNet152', resnet.Bottleneck, [3, 8, 36, 3]),
    )
    for name, block, layers in variants:
        if arch == name:
            return L2CS(block, layers, bins)
    if arch != 'ResNet50':
        print('Invalid value for architecture is passed! '
              'The default value of ResNet50 will be used instead!')
    return L2CS(resnet.Bottleneck, [3, 4, 6, 3], bins)
|