"""Iris recognition evaluation script.

Segments every image in ``image_dir`` with ``UnetWithVGG16Attention`` via
``location()``, normalises and Gabor-encodes the iris, writes the intermediate
results to disk, then matches every ordered pair of encoded images and prints
FRR / FAR / CRR at a fixed masked-Hamming-distance threshold.
"""
import os
import time  # noqa: F401  (kept from the original file)

import cv2
import numpy as np
import torch
from torchvision.models.vgg import vgg16_bn, vgg16  # noqa: F401

from model.unet_vgg16_multitask_attention import UnetWithVGG16Attention
from iris_location import location
from iris_encode import add_borders, encode_image, get_gabor_filters
from post_processing import post_processing_image, normalize_image

# Resize factor passed to location(); process_eye() must use the same value so
# the detected circles line up with its resized image.
scale_factor = 0.5

# Size of the normalised (unwrapped) iris image / mask; the application-point
# mask must have the same size.
NORM_WIDTH = 512
NORM_HEIGHT = 64


def get_application_points(point_path, width, height):
    """Build a binary mask from a text file of point coordinates.

    The first line of the file is skipped (presumably a header/count -- TODO
    confirm). Every following non-blank line holds two whitespace-separated
    integers ``row col``; each such point is set to 255 in the mask.

    Args:
        point_path: path to the points file.
        width: number of mask columns.
        height: number of mask rows.

    Returns:
        ``(height, width)`` uint8 array, 255 at listed points and 0 elsewhere.
    """
    application_points = np.zeros((height, width), dtype=np.uint8)
    with open(point_path, 'r') as point_file:
        point_lines = point_file.readlines()
    for line in point_lines[1:]:
        fields = line.split()
        if not fields:
            continue  # tolerate blank / trailing lines
        application_points[int(fields[0]), int(fields[1])] = 255
    return application_points


def process_eye(image, mask, iris, pupil, scale_factor=0.5, out=False, out_path=None):
    """Refine the segmentation, normalise the iris and compute its binary code.

    Args:
        image: 3-channel image (e.g. from ``cv2.imread``); only channel 0 is
            normalised and encoded.
        mask, iris, pupil: raw outputs of ``location()``; refined by
            ``post_processing_image`` into a mask and ``(x, y, r)`` circles.
        scale_factor: factor the image is resized by before normalisation;
            presumably must equal the factor used by ``location()``.
        out: when True, write an overlay of mask and circles to ``out_path``.
        out_path: destination of the overlay; required when ``out`` is True.

    Returns:
        ``(norm_image, norm_mask, encode)``, or ``(None, None, None)`` when a
        radius is 0 or the pupil centre lies outside the iris bounding square.

    Raises:
        ValueError: if the scaled size is empty, or ``out`` is True without an
            ``out_path``.
    """
    if out and out_path is None:
        raise ValueError('out_path is required when out=True')

    h, w = image.shape[:2]
    new_w, new_h = int(scale_factor * w), int(scale_factor * h)
    if new_w <= 0 or new_h <= 0:
        raise ValueError('Scale is too small')
    image_resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    mask, iris_circle, pupil_circle = post_processing_image(mask, iris, pupil, debug=False)
    iris_x, iris_y, iris_r = iris_circle[0], iris_circle[1], iris_circle[2]
    pupil_x, pupil_y, pupil_r = pupil_circle[0], pupil_circle[1], pupil_circle[2]

    if out:
        mask_color = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
        show_image = cv2.addWeighted(image_resized, 0.7, mask_color, 0.3, 0)
        # cv2.circle requires integer centre/radius.
        cv2.circle(show_image, (int(round(iris_x)), int(round(iris_y))),
                   int(round(iris_r)), (0, 0, 255), 1)
        cv2.circle(show_image, (int(round(pupil_x)), int(round(pupil_y))),
                   int(round(pupil_r)), (0, 0, 255), 1)
        cv2.imwrite(out_path, show_image)

    if iris_r == 0 or pupil_r == 0:
        return None, None, None
    # Reject when the pupil centre falls outside the iris's bounding square.
    if (pupil_x < iris_x - iris_r or pupil_x > iris_x + iris_r
            or pupil_y < iris_y - iris_r or pupil_y > iris_y + iris_r):
        return None, None, None

    single_image = image_resized[:, :, 0]
    norm_image = normalize_image(single_image, iris_x, iris_y, iris_r,
                                 pupil_x, pupil_y, pupil_r, NORM_WIDTH, NORM_HEIGHT)
    norm_mask = normalize_image(mask, iris_x, iris_y, iris_r,
                                pupil_x, pupil_y, pupil_r, NORM_WIDTH, NORM_HEIGHT)
    encode = encode_image(norm_image, get_gabor_filters())
    return norm_image, norm_mask, encode


def match(code1, code2, norm_mask1, norm_mask2, application_points=None):
    """Return the minimum masked Hamming distance between two iris codes.

    ``code1`` is padded by ``add_borders(code1, 10)`` and compared with
    ``code2`` at every column offset in [-10, 10]. For each offset, the distance
    is the sum of the XOR of the codes over the valid region, divided by the sum
    of the valid-region mask. The valid region is ``norm_mask1 & norm_mask2``,
    restricted to ``application_points`` when given.

    Returns:
        The smallest distance over all offsets. Returns 1.0 (the worst score)
        when no pixel is valid, instead of dividing by zero.
    """
    shift = 10
    valid = cv2.bitwise_and(norm_mask1, norm_mask2, mask=application_points)
    # The code is 6x taller than the mask (presumably one band per Gabor
    # filter), so tile the mask vertically to cover it.
    total_mask = np.vstack([valid] * 6)
    valid_sum = cv2.sumElems(total_mask)[0]  # loop-invariant
    if valid_sum == 0:
        return 1.0

    # NOTE(review): only code1 is shifted; the mask stays aligned with code2.
    # Confirm this is intended.
    shifted = add_borders(code1, shift)
    width = code1.shape[1]
    score = 1.0
    for offset in range(-shift, shift + 1):
        roi = shifted[:, shift + offset:shift + offset + width]
        diff = cv2.bitwise_xor(roi, code2, mask=total_mask)
        score = min(score, cv2.sumElems(diff)[0] / valid_sum)
    return score


# Builds the segmentation network, loads its weights and switches to eval mode.
def _load_model(model_path, device):
    """Return ``UnetWithVGG16Attention`` with weights from ``model_path``, ready for inference."""
    encoder = vgg16()  # named so it does not shadow torchvision's vgg16 factory
    net = UnetWithVGG16Attention(encoder=encoder, n_classes=1, bilinear=True)
    net.to(device=device)
    net.load_state_dict(torch.load(model_path, map_location=device))
    # Inference only: layers such as BatchNorm/Dropout must use eval behaviour.
    net.eval()
    return net


# Returns the class label used to decide whether two images are the same eye.
def _subject_id(image_name):
    """Return the part of ``image_name`` before "eye", or the whole name if absent.

    The original slice ``name[0:name.find("eye")]`` silently became
    ``name[0:-1]`` when "eye" was missing.
    """
    pos = image_name.find("eye")
    return image_name if pos < 0 else image_name[:pos]


# Safe division for the final rate metrics.
def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    return numerator / denominator if denominator else float('nan')


# Segments, normalises and encodes every image, writing results to disk.
def _encode_images(net, device, image_dir, image_list, norm_img_dir, norm_mask_dir,
                   code_dir, out, out_dir):
    """Encode every image in ``image_list``; return the set of names encoded in this run."""
    encoded = set()
    for image_name in image_list:
        print("processing", image_name)
        image = cv2.imread(os.path.join(image_dir, image_name))
        if image is None:  # not decodable by OpenCV (e.g. non-image file)
            print("skip unreadable", image_name)
            continue
        with torch.no_grad():
            mask, iris, pupil = location(net, image, device, scale_factor=scale_factor)
        norm_image, norm_mask, encode = process_eye(
            image, mask, iris, pupil, scale_factor=scale_factor,
            out=out, out_path=os.path.join(out_dir, image_name))
        if encode is None:
            continue
        # NOTE(review): results keep the source image's extension. If it is a
        # lossy format (e.g. .jpg), the binary code and mask may be corrupted
        # on the imwrite/imread round trip -- confirm the dataset is lossless.
        cv2.imwrite(os.path.join(norm_img_dir, image_name), norm_image)
        cv2.imwrite(os.path.join(norm_mask_dir, image_name), norm_mask)
        cv2.imwrite(os.path.join(code_dir, image_name), encode)
        encoded.add(image_name)
    return encoded


# Matches every ordered pair of encoded images and tallies the outcomes.
def _evaluate(image_list, encoded, code_dir, norm_mask_dir, points, score_threshold):
    """Return a dict of confusion counts from all-pairs matching."""
    counts = dict(inclass=0, outclass=0, true_accept=0, false_accept=0,
                  true_reject=0, false_reject=0)

    def _read(directory, name):
        return cv2.imread(os.path.join(directory, name), cv2.IMREAD_GRAYSCALE)

    for image_name1 in image_list:
        print("compare with " + image_name1 + "===========================================")
        if image_name1 not in encoded:  # avoid stale files from earlier runs
            continue
        # Loaded once per outer image instead of once per pair.
        encode1 = _read(code_dir, image_name1)
        norm_mask1 = _read(norm_mask_dir, image_name1)
        if encode1 is None or norm_mask1 is None:
            continue
        subject1 = _subject_id(image_name1)

        for image_name2 in image_list:
            if image_name2 == image_name1 or image_name2 not in encoded:
                continue
            encode2 = _read(code_dir, image_name2)
            norm_mask2 = _read(norm_mask_dir, image_name2)
            if encode2 is None or norm_mask2 is None:
                continue

            score = match(encode1, encode2, norm_mask1, norm_mask2, points)
            out_info = image_name2 + " " + str(score) + " "
            accepted = score <= score_threshold
            if _subject_id(image_name2) == subject1:
                counts['inclass'] += 1
                if accepted:
                    counts['true_accept'] += 1
                else:
                    counts['false_reject'] += 1
                    print(out_info + " " + "false_reject")
            else:
                counts['outclass'] += 1
                if accepted:
                    counts['false_accept'] += 1
                    print(out_info + " " + "false_accept")
                else:
                    counts['true_reject'] += 1
    return counts


# Prints raw counts followed by FRR, FAR and CRR.
def _print_report(counts):
    """Print the confusion counts and the derived rates."""
    for key in ('inclass', 'outclass', 'false_reject', 'false_accept',
                'true_accept', 'true_reject'):
        print(key, counts[key])
    print('frr', _ratio(counts['false_reject'], counts['inclass']))
    print('far', _ratio(counts['false_accept'], counts['outclass']))
    print('crr', _ratio(counts['true_reject'] + counts['true_accept'],
                        counts['inclass'] + counts['outclass']))


def main():
    """Encode the whole dataset, run all-pairs matching and print the metrics."""
    # load application points
    points = get_application_points('/home/ubuntu/points.txt', NORM_WIDTH, NORM_HEIGHT)

    # load model
    model_path = 'checkpoints/unet_vgg16_multitask_attention_epoch600.pth'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = _load_model(model_path, device)

    score_threshold = 0.32
    image_dir = '/home/ubuntu/iris_image/casic_iris_origin/image/'
    norm_img_dir = '/home/ubuntu/iris_image/casic_iris_origin/norm_image/'
    norm_mask_dir = '/home/ubuntu/iris_image/casic_iris_origin/norm_mask/'
    code_dir = '/home/ubuntu/iris_image/casic_iris_origin/code/'
    out = True
    out_dir = '/home/ubuntu/iris_image/casic_iris_origin/output/'
    for directory in (norm_img_dir, norm_mask_dir, code_dir, out_dir):
        os.makedirs(directory, exist_ok=True)

    image_list = sorted(os.listdir(image_dir))  # sorted for reproducible output
    encoded = _encode_images(net, device, image_dir, image_list, norm_img_dir,
                             norm_mask_dir, code_dir, out, out_dir)
    counts = _evaluate(image_list, encoded, code_dir, norm_mask_dir, points, score_threshold)
    _print_report(counts)


if __name__ == '__main__':
    main()