"""Real-time webcam gaze-estimation demo using a pretrained L2CS-Net (Gaze360).

Opens the requested camera, runs each frame through the L2CS `Pipeline`,
overlays the gaze rendering plus an FPS counter, and quits on 'q'.
"""

import argparse
import pathlib
import time

import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision
from PIL import Image, ImageOps  # was imported twice; consolidated to one line

from face_detection import RetinaFace
from l2cs import select_device, draw_gaze, getArch, Pipeline, render

CWD = pathlib.Path.cwd()


def parse_args():
    """Parse input arguments.

    Returns:
        argparse.Namespace with `device`, `snapshot`, `cam_id`, and `arch`.
    """
    parser = argparse.ArgumentParser(
        description='Gaze evaluation using model pretrained with L2CS-Net on Gaze360.')
    parser.add_argument(
        '--device', dest='device', help='Device to run model: cpu or gpu:0',
        default="cpu", type=str)
    parser.add_argument(
        '--snapshot', dest='snapshot', help='Path of model snapshot.',
        default='output/snapshots/L2CS-gaze360-_loader-180-4/_epoch_55.pkl', type=str)
    parser.add_argument(
        '--cam', dest='cam_id', help='Camera device id to use [0]',
        default=0, type=int)
    parser.add_argument(
        '--arch', dest='arch',
        help='Network architecture, can be: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152',
        default='ResNet50', type=str)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    arch = args.arch
    cam = args.cam_id

    gaze_pipeline = Pipeline(
        weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
        # BUG FIX: was hard-coded to 'ResNet50', silently ignoring --arch.
        # Backward-compatible: the option's default is still 'ResNet50'.
        arch=arch,
        device=select_device(args.device, batch_size=1),
    )

    cap = cv2.VideoCapture(cam)

    # Check if the webcam is opened correctly
    if not cap.isOpened():
        raise IOError("Cannot open webcam")

    try:
        with torch.no_grad():
            while True:
                # Get frame
                success, frame = cap.read()
                start_fps = time.time()

                if not success:
                    print("Failed to obtain frame")
                    time.sleep(0.1)
                    # BUG FIX: previously fell through and ran the pipeline
                    # on an invalid frame, crashing on a dropped capture.
                    continue

                # Process frame
                results = gaze_pipeline.step(frame)

                # Visualize output
                frame = render(frame, results)

                # Guard against a zero elapsed time (ZeroDivisionError on
                # very fast iterations / coarse clocks).
                myFPS = 1.0 / max(time.time() - start_fps, 1e-6)
                cv2.putText(frame, 'FPS: {:.1f}'.format(myFPS), (10, 20),
                            cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 255, 0), 1,
                            cv2.LINE_AA)

                cv2.imshow("Demo", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
                # BUG FIX: removed redundant trailing `cap.read()` that
                # discarded one frame per iteration with no success check.
    finally:
        # Release the camera and close windows even on error/interrupt.
        cap.release()
        cv2.destroyAllWindows()