"""Debug script: locate the iris/pupil in one eye image and post-process it.

Loads a trained ``UnetWithVGG16Attention`` checkpoint, runs ``location`` on a
single image, shows the raw and the binarized pupil map in an OpenCV window
(press any key to advance), and then calls ``post_processing_image`` on the
localisation results.
"""
import os
import time

import cv2
import numpy as np
import torch
from torchvision.models.vgg import vgg16_bn, vgg16

from model.unet_vgg16_multitask_attention import UnetWithVGG16Attention
from iris_location import location
from iris_encode import add_borders, encode_image, get_gabor_filters
from post_processing import post_processing_image, normalize_image

DEFAULT_MODEL_PATH = 'checkpoints/unet_vgg16_multitask_attention_epoch600.pth'
DEFAULT_IMAGE_PATH = '/home/ubuntu/iris_image/debug_image/1-2-19-righteye2.bmp'


def _load_model(model_path, device):
    """Build the VGG16-encoder U-Net, load weights from *model_path*, and return it in eval mode."""
    # Distinct name so the imported ``vgg16`` factory is not shadowed.
    encoder = vgg16()
    net = UnetWithVGG16Attention(encoder=encoder, n_classes=1, bilinear=True)
    net.to(device=device)
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # checkpoints from trusted sources.
    net.load_state_dict(torch.load(model_path, map_location=device))
    # Inference only: put modules such as Dropout/BatchNorm into eval
    # behaviour. Harmless if ``location`` also does this (not visible here).
    net.eval()
    return net


def main(model_path=DEFAULT_MODEL_PATH, image_path=DEFAULT_IMAGE_PATH,
         scale_factor=0.5):
    """Run localisation and post-processing on a single image.

    Args:
        model_path: Path to the ``state_dict`` checkpoint for the U-Net.
        image_path: Path to the eye image, read with ``cv2.imread``.
        scale_factor: Forwarded unchanged to ``location``.

    Returns:
        Whatever ``post_processing_image`` returns. The original script
        unpacked it into ``(norm_image, norm_mask, encode)``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read *image_path*.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = _load_model(model_path, device)

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising, which would otherwise fail later inside location.
        raise FileNotFoundError(f'could not read image: {image_path}')

    # No gradients are needed for inference; saves memory and time.
    with torch.no_grad():
        mask, iris, pupil = location(net, image, device,
                                     scale_factor=scale_factor)

    cv2.imshow('pupil', pupil)
    cv2.waitKey(0)
    # Binarize the pupil map: pixels > 200 become 255, all others 0.
    _, pupil = cv2.threshold(pupil, 200, 255, cv2.THRESH_BINARY)
    cv2.imshow('pupil', pupil)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    return post_processing_image(mask, iris, pupil, debug=False)


if __name__ == '__main__':
    norm_image1, norm_mask1, encode1 = main()