Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| from torch.utils.data.dataset import Dataset | |
| from torchvision import transforms | |
| from PIL import Image, ImageFilter | |
class Gaze360(Dataset):
    """Gaze360 face-crop dataset yielding continuous and binned gaze labels.

    Each annotation line is space-separated: field 0 is the face-image path
    (relative to ``root``) and field 5 is ``"pitch,yaw"`` in radians. The
    first line of every annotation file is a header and is skipped.

    Args:
        path: annotation file path, or a list of such paths.
        root: directory prepended to each image path.
        transform: optional callable applied to the loaded PIL image.
        angle: keep only samples with |pitch| and |yaw| <= angle (degrees)
            at train time; also defines the bin range in ``__getitem__``.
        binwidth: width of each classification bin, in degrees.
        train: when False, the filter threshold is relaxed to 90 degrees so
            evaluation keeps (nearly) the full label range.
    """

    def __init__(self, path, root, transform, angle, binwidth, train=True):
        self.transform = transform
        self.root = root
        self.orig_list_len = 0
        # Bin range (degrees) used by __getitem__; kept at the caller's value
        # even when the eval filter below is relaxed.
        self.angle = angle
        self.binwidth = binwidth
        if train == False:
            # Evaluate on (almost) the full range regardless of the training cut.
            angle = 90
        self.lines = []
        # NOTE(review): the original filtered only the single-path branch and
        # left list-of-paths input unfiltered (and orig_list_len untracked,
        # making the "removed" report below wrong). Both branches now share
        # one filtered loop.
        paths = path if isinstance(path, list) else [path]
        for p in paths:
            with open(p) as f:
                lines = f.readlines()
            lines.pop(0)  # drop header row
            self.orig_list_len += len(lines)
            for line in lines:
                gaze2d = line.strip().split(" ")[5]
                label = np.array(gaze2d.split(",")).astype("float")
                # Keep samples whose pitch/yaw (radians -> degrees) are within +/-angle.
                if abs(label[0] * 180 / np.pi) <= angle and abs(label[1] * 180 / np.pi) <= angle:
                    self.lines.append(line)
        print("{} items removed from dataset that have an angle > {}".format(
            self.orig_list_len - len(self.lines), angle))

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        """Return ``(img, binned_labels, cont_labels_deg, name)`` for sample *idx*."""
        line = self.lines[idx]
        line = line.strip().split(" ")
        face = line[0]
        name = line[3]
        gaze2d = line[5]
        label = np.array(gaze2d.split(",")).astype("float")
        label = torch.from_numpy(label).type(torch.FloatTensor)
        # Continuous targets in degrees.
        pitch = label[0] * 180 / np.pi
        yaw = label[1] * 180 / np.pi

        img = Image.open(os.path.join(self.root, face))
        if self.transform:
            img = self.transform(img)

        # Discretize pitch/yaw into bins of self.binwidth degrees spanning
        # [-self.angle, self.angle); np.digitize returns 1-based indices,
        # hence the -1.
        bins = np.array(range(-1 * self.angle, self.angle, self.binwidth))
        binned_pose = np.digitize([pitch, yaw], bins) - 1

        labels = binned_pose
        cont_labels = torch.FloatTensor([pitch, yaw])
        return img, labels, cont_labels, name
class Mpiigaze(Dataset):
    """MPIIGaze dataset with leave-one-subject-out cross-validation folds.

    Each annotation line is space-separated: field 0 is the face-image path
    (relative to ``root``), field 3 the sample name, and field 7 the
    ``"pitch,yaw"`` gaze in radians. The first line of every file is a
    header and is skipped.

    Args:
        pathorg: list of per-subject annotation file paths.
        root: directory prepended to each image path.
        transform: optional callable applied to the loaded PIL image.
        train: True trains on all subjects except ``fold``; False evaluates
            on subject ``fold`` only.
        angle: keep only samples with |pitch| and |yaw| <= angle (degrees).
        fold: index of the held-out subject.
    """

    def __init__(self, pathorg, root, transform, train, angle, fold=0):
        self.transform = transform
        self.root = root
        self.orig_list_len = 0
        self.lines = []
        path = pathorg.copy()  # never mutate the caller's list
        if train == True:
            path.pop(fold)  # train on every subject except the held-out fold
        else:
            path = path[fold]  # evaluate on the held-out subject only
        if not isinstance(path, list):
            path = [path]
        # NOTE(review): the original duplicated this loop per branch and
        # hard-coded 42 degrees in the eval branch where the train branch
        # used `angle`; both now use `angle` (identical behavior for the
        # conventional angle=42 call).
        for p in path:
            with open(p) as f:
                lines = f.readlines()
            lines.pop(0)  # drop header row
            self.orig_list_len += len(lines)
            for line in lines:
                gaze2d = line.strip().split(" ")[7]
                label = np.array(gaze2d.split(",")).astype("float")
                # Keep samples whose pitch/yaw (radians -> degrees) are within +/-angle.
                if abs(label[0] * 180 / np.pi) <= angle and abs(label[1] * 180 / np.pi) <= angle:
                    self.lines.append(line)
        print("{} items removed from dataset that have an angle > {}".format(
            self.orig_list_len - len(self.lines), angle))

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        """Return ``(img, binned_labels, cont_labels_deg, name)`` for sample *idx*."""
        line = self.lines[idx]
        line = line.strip().split(" ")
        face = line[0]
        name = line[3]
        gaze2d = line[7]
        label = np.array(gaze2d.split(",")).astype("float")
        label = torch.from_numpy(label).type(torch.FloatTensor)
        # Continuous targets in degrees.
        pitch = label[0] * 180 / np.pi
        yaw = label[1] * 180 / np.pi

        img = Image.open(os.path.join(self.root, face))
        if self.transform:
            img = self.transform(img)

        # Discretize into 3-degree bins over the fixed +/-42 degree MPIIGaze
        # range (28 bins); np.digitize is 1-based, hence the -1.
        bins = np.array(range(-42, 42, 3))
        binned_pose = np.digitize([pitch, yaw], bins) - 1

        labels = binned_pose
        cont_labels = torch.FloatTensor([pitch, yaw])
        return img, labels, cont_labels, name