"""Prediction utilities for a three-head (mask / iris / pupil) segmentation network."""

import os
import logging
import time

import cv2
import torch
import numpy as np
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms


def location(net, img, device, scale_factor=1):
    """Run the network on one image and return the mask, iris and pupil maps as uint8 arrays."""
    mask, iris, pupil = predict_img(net=net, image=img, device=device, scale_factor=scale_factor)
    mask = np.array(mask * 255, dtype=np.uint8)
    iris = np.array(iris * 255, dtype=np.uint8)
    pupil = np.array(pupil * 255, dtype=np.uint8)
    return mask, iris, pupil


def predict_img(net, image, device, scale_factor=1, out_threshold=0.5):
    """Preprocess an OpenCV (BGR) image, run the network and post-process its three outputs."""
    net.eval()
    full_img = image
    h, w = image.shape[:2]
    new_w, new_h = int(scale_factor * w), int(scale_factor * h)
    assert new_w > 0 and new_h > 0, 'Scale is too small'
    image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # Grayscale input gets a trailing channel axis; the mean subtraction below then
    # broadcasts it to three channels.
    if image.ndim == 2:
        image = image[:, :, None]

    # VGG-style per-channel means (BGR order, matching cv2.imread).
    mean = np.array([103.939, 116.779, 123.68])
    image = image - mean

    # HWC -> CHW, add a batch dimension and move to the target device.
    image = image.transpose((2, 0, 1))
    image = torch.from_numpy(image)
    # if isinstance(image, torch.ByteTensor):
    #     image = image.float().div(255)
    image = image.float()
    image = image.unsqueeze(0)
    image = image.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        mask_output, iris_output, pupil_output = net(image)

        # Multi-class heads use softmax over the class dimension; binary heads use sigmoid.
        if net.n_classes > 1:
            mask_probs = F.softmax(mask_output, dim=1)
            iris_probs = F.softmax(iris_output, dim=1)
            pupil_probs = F.softmax(pupil_output, dim=1)
        else:
            mask_probs = torch.sigmoid(mask_output)
            iris_probs = torch.sigmoid(iris_output)
            pupil_probs = torch.sigmoid(pupil_output)

        mask_probs = mask_probs.squeeze(0)
        iris_probs = iris_probs.squeeze(0)
        pupil_probs = pupil_probs.squeeze(0)

        # Optional resize back to the original resolution (outputs stay at the scaled
        # size when this block is disabled; with scale_factor=1 the sizes coincide).
        # tf = transforms.Compose(
        #     [
        #         transforms.ToPILImage(),
        #         transforms.Resize(full_img.shape[:2]),
        #         transforms.ToTensor()
        #     ]
        # )
        # mask_probs = tf(mask_probs.cpu())
        # iris_probs = tf(iris_probs.cpu())
        # pupil_probs = tf(pupil_probs.cpu())

        mask_probs = mask_probs.cpu()
        iris_probs = iris_probs.cpu()
        pupil_probs = pupil_probs.cpu()

        full_mask = mask_probs.squeeze().numpy()
        full_iris = iris_probs.squeeze().numpy()
        full_pupil = pupil_probs.squeeze().numpy()

    # Mask and iris are thresholded to booleans; the pupil map is returned as raw
    # probabilities and only scaled to uint8 by location().
    return full_mask > out_threshold, full_iris > out_threshold, full_pupil


def mask_to_image(mask):
    """Convert a boolean or [0, 1] float mask to an 8-bit PIL image."""
    return Image.fromarray((mask * 255).astype(np.uint8))


# if __name__ == '__main__':
#     # torch.set_num_threads(1)
#     print(torch.get_num_threads())
#     print(torch.get_num_interop_threads())
#
#     logging.getLogger().setLevel(logging.INFO)
#
#     if not os.path.exists(output_mask_dir):
#         os.makedirs(output_mask_dir)
#     if not os.path.exists(output_pupil_dir):
#         os.makedirs(output_pupil_dir)
#     if not os.path.exists(output_iris_dir):
#         os.makedirs(output_iris_dir)
#
#     vgg16 = vgg16()
#     net = UnetWithVGG16Attention(encoder=vgg16, n_classes=1, bilinear=True)
#
#     logging.info("Loading model {}".format(model_path))
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     logging.info('Using device {}'.format(device))
#     net.to(device=device)
#     net.load_state_dict(torch.load(model_path, map_location=device))
#
#     logging.info("Model loaded !")
#
#     in_files = os.listdir(input_dir)
#     all_start_time = time.time()
#
#     for i, fn in enumerate(in_files):
#         filename = fn
#         fn = input_dir + filename
#         logging.info("\nPredicting image {} ...".format(fn))
#
#         img = cv2.imread(fn)
#
#         start_time = time.time()
#         mask, iris, pupil = predict_img(net=net, image=img, device=device, scale_factor=scale)
#
#         out_mask_path = output_mask_dir + os.path.splitext(filename)[0] + '.bmp'
#         mask = mask_to_image(mask)
#         mask.save(out_mask_path)
#
#         out_iris_path = output_iris_dir + os.path.splitext(filename)[0] + '.bmp'
#         iris = mask_to_image(iris)
#         iris.save(out_iris_path)
#
#         out_pupil_path = output_pupil_dir + os.path.splitext(filename)[0] + '.bmp'
#         pupil = mask_to_image(pupil)
#         pupil.save(out_pupil_path)
#
#         end_time = time.time()
#         print(" time:" + str(end_time - start_time) + " sec")
#
#         logging.info("Mask saved to {}".format(out_mask_path))
#
#     all_end_time = time.time()
#     avg_time = (all_end_time - all_start_time) / len(in_files)
#     print("average time is %f seconds" % avg_time)
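
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original pipeline): `DummyNet` is a
# purely illustrative stand-in with the interface predict_img() assumes -- an
# `n_classes` attribute and a forward pass returning mask/iris/pupil logits --
# so the pre/post-processing above can be exercised without a trained model.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    class DummyNet(nn.Module):
        def __init__(self, n_classes=1):
            super().__init__()
            self.n_classes = n_classes
            self.mask_head = nn.Conv2d(3, n_classes, kernel_size=1)
            self.iris_head = nn.Conv2d(3, n_classes, kernel_size=1)
            self.pupil_head = nn.Conv2d(3, n_classes, kernel_size=1)

        def forward(self, x):
            return self.mask_head(x), self.iris_head(x), self.pupil_head(x)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = DummyNet(n_classes=1).to(device)

    # Random uint8 image standing in for a real BGR eye image.
    img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    mask, iris, pupil = location(net, img, device, scale_factor=1)
    print(mask.shape, mask.dtype, iris.shape, pupil.shape)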