import math
import os
import time
import traceback

import cv2
import numpy as np


def single_channel_to_multi_channel(arr):
    """Return ``arr`` as a 3-channel image.

    An array with fewer than 3 dimensions is replicated along a new last
    axis (e.g. (H, W) -> (H, W, 3)); anything else is returned unchanged.
    """
    if arr.ndim < 3:
        return np.concatenate([arr[..., None]] * 3, axis=-1)
    return arr


def _find_contours(binary, mode):
    """Return only the contour list from ``cv2.findContours``.

    OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    (contours, hierarchy); the contours are second-to-last in both.
    """
    return cv2.findContours(binary, mode, cv2.CHAIN_APPROX_SIMPLE)[-2]


def _triplet_pixel_count(triplet):
    """Total number of non-zero pixels over the three regions of a triplet."""
    return sum(int(np.count_nonzero(part)) for part in triplet)


def get_connections(img):
    """Split a binary image into its 8-connected components.

    Args:
        img: 8-bit image; non-zero pixels are foreground. Multi-channel
            input is converted to grayscale first.

    Returns:
        List of uint8 arrays shaped like the (grayscale) image, one per
        component, holding 1 inside the component and 0 elsewhere. The
        background label 0 is skipped.
    """
    if img.ndim > 2:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    num_labels, labels = cv2.connectedComponents(img, connectivity=8)
    return [(labels == label).astype(np.uint8) for label in range(1, num_labels)]


def show_connections(cons):
    """Display each component in the 'con' window, waiting for a key each time."""
    for con in cons:
        con_c3 = single_channel_to_multi_channel(con)
        _, con_img = cv2.threshold(con_c3, 0, 255, cv2.THRESH_BINARY)
        cv2.imshow('con', con_img)
        cv2.waitKey(0)
    cv2.destroyAllWindows()


def get_triplet_set(mask_cons, iris_cons, pupil_cons, debug=False, chessboard_distance_threshold=20):
    """Collect (mask, iris, pupil) component triplets that lie close together.

    A triplet is kept when the mask component is within
    ``chessboard_distance_threshold`` (chessboard distance between contours)
    of both the iris component and the pupil component.

    Args:
        mask_cons, iris_cons, pupil_cons: component lists as returned by
            ``get_connections``.
        debug: forwarded to ``get_chessboard_distance`` (shows contours).
        chessboard_distance_threshold: maximum allowed distance, in pixels.

    Returns:
        List of ``(mask_con, iris_con, pupil_con)`` tuples, ordered by mask,
        then iris, then pupil component.
    """
    def _distance(con_a, con_b):
        # A failed distance computation is reported and treated as "not
        # adjacent" so that one bad pair does not discard the others.
        try:
            return get_chessboard_distance(con_a, con_b, debug=debug,
                                           chessboard_distance_threshold=chessboard_distance_threshold)
        except Exception:
            print(traceback.format_exc())
            return None

    def _is_near(distance):
        return distance is not None and distance <= chessboard_distance_threshold

    triplet_set = []
    for mask_con in mask_cons:
        # The mask-pupil distance does not depend on the iris component,
        # so compute it once per mask component.
        near_pupils = [pupil_con for pupil_con in pupil_cons
                       if _is_near(_distance(mask_con, pupil_con))]
        if not near_pupils:
            continue
        for iris_con in iris_cons:
            if not _is_near(_distance(mask_con, iris_con)):
                continue
            triplet_set.extend((mask_con, iris_con, pupil_con) for pupil_con in near_pupils)
    return triplet_set


def _contour_points(con, window_name, debug):
    """Return all contour points of a binary region as an (N, 2) array of (x, y).

    Raises:
        ValueError: if the region has no foreground pixels (no contours).
    """
    _, con_th = cv2.threshold(con, 0, 255, cv2.THRESH_BINARY)
    contours = _find_contours(con_th, cv2.RETR_TREE)
    if debug:
        con_vis = single_channel_to_multi_channel(con_th)
        cv2.drawContours(con_vis, contours, -1, (0, 0, 255), 3)
        cv2.imshow(window_name, con_vis)
        cv2.waitKey(0)
    return np.squeeze(np.concatenate(contours), axis=1)


def get_chessboard_distance(con1, con2, debug=False, chessboard_distance_threshold=20):
    """Return the chessboard (L-infinity) distance between the contours of two regions.

    The distance is the minimum over all contour-point pairs of
    ``max(|x1 - x2|, |y1 - y2|)``.

    Args:
        con1, con2: single-channel 8-bit region masks (non-zero = region).
        debug: when True, show each region's contours and print the result.
        chessboard_distance_threshold: unused; kept for API compatibility.

    Raises:
        ValueError: if either region is empty.
    """
    c1 = _contour_points(con1, 'con1', debug)
    c2 = _contour_points(con2, 'con2', debug)
    distance = np.abs(c1[:, None, :] - c2[None, :, :]).max(axis=2).min()
    if debug:
        print('distance', distance)
    return distance


def show_triplet(triplet):
    """Display the union of a (mask, iris, pupil) triplet in the 'triplet' window."""
    combined = triplet[0] + triplet[1] + triplet[2]
    combined_c3 = single_channel_to_multi_channel(combined)
    _, triplet_img = cv2.threshold(combined_c3, 0, 255, cv2.THRESH_BINARY)
    cv2.imshow('triplet', triplet_img)
    cv2.waitKey(0)


def get_max_triplet(triplet_set):
    """Return the triplet with the most foreground pixels (first one on ties).

    Raises:
        ValueError: if ``triplet_set`` is empty.
    """
    return max(triplet_set, key=_triplet_pixel_count)


def least_square_circle_fitting(counter):
    """Fit a circle to 2-D points with an algebraic least-squares fit.

    Solves for ``a, b, c`` in ``x^2 + y^2 + a*x + b*y + c = 0`` in the
    least-squares sense, then converts to centre/radius.

    Args:
        counter: sequence of (x, y) points, e.g. a list of tuples or an
            (N, 2) integer array. Values are promoted to float64 so cubic
            terms cannot overflow for int32 contour coordinates.

    Returns:
        ``(center_x, center_y, radius)`` as floats. ``(0.0, 0.0, 0.0)`` is
        returned when there are 3 or fewer points, or when the points are
        degenerate (e.g. collinear) so that no finite circle exists.
    """
    if len(counter) <= 3:
        return 0.0, 0.0, 0.0

    pts = np.asarray(counter, dtype=np.float64).reshape(-1, 2)
    x, y = pts[:, 0], pts[:, 1]
    x2, y2 = x * x, y * y
    n = len(pts)

    sum_x, sum_y = x.sum(), y.sum()
    sum_x2, sum_y2 = x2.sum(), y2.sum()
    sum_x3, sum_y3 = (x2 * x).sum(), (y2 * y).sum()
    sum_xy = (x * y).sum()
    sum_x1y2 = (x * y2).sum()
    sum_x2y1 = (x2 * y).sum()

    C = n * sum_x2 - sum_x * sum_x
    D = n * sum_xy - sum_x * sum_y
    E = n * sum_x3 + n * sum_x1y2 - (sum_x2 + sum_y2) * sum_x
    G = n * sum_y2 - sum_y * sum_y
    H = n * sum_x2y1 + n * sum_y3 - (sum_x2 + sum_y2) * sum_y

    denom = C * G - D * D
    if denom == 0:
        # Singular normal equations (e.g. all points on one line).
        return 0.0, 0.0, 0.0
    a = (H * D - E * G) / denom
    b = (H * C - E * D) / (-denom)
    c = -(a * sum_x + b * sum_y + sum_x2 + sum_y2) / n

    radius_sq = a * a + b * b - 4 * c
    if not np.isfinite(radius_sq) or radius_sq < 0:
        return 0.0, 0.0, 0.0
    return float(a / -2), float(b / -2), float(math.sqrt(radius_sq) / 2)


def post_processing_image(mask, iris, pupil, debug=False, chessboard_distance_threshold=20):
    """Clean up predicted mask/iris/pupil maps and fit the iris and pupil circles.

    Steps:
      1. Binarize each map (fixed thresholds 127 / 90 / 200).
      2. Split each into 8-connected components.
      3. Form (mask, iris, pupil) triplets of mutually close components.
      4. Keep the triplet with the most foreground pixels.
      5. Fit the outer (iris) and inner (pupil) circles by least squares.
      6. Keep only mask pixels inside the iris circle and outside the pupil circle.

    Args:
        mask: 8-bit mask image (1 or 3 channels).
        iris: 8-bit single-channel iris map.
        pupil: 8-bit single-channel pupil map.
        debug: show intermediate components and triplets.
        chessboard_distance_threshold: adjacency threshold for triplets, in pixels.

    Returns:
        ``(mask, (iris_x, iris_y, iris_r), (pupil_x, pupil_y, pupil_r))``,
        with integer circle parameters. When no valid triplet exists the
        binarized mask is returned with both circles set to ``(0, 0, 0)``.
    """
    no_circle = (0, 0, 0)

    # 1. Binarize.
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    _, iris = cv2.threshold(iris, 90, 255, cv2.THRESH_BINARY)
    _, pupil = cv2.threshold(pupil, 200, 255, cv2.THRESH_BINARY)

    # Closing (dilate then erode) on the outer iris map to bridge small breaks.
    kernel = np.ones((5, 5), dtype=np.uint8)
    iris = cv2.erode(cv2.dilate(iris, kernel), kernel)

    # 2. 8-connected components.
    mask_cons = get_connections(mask)
    iris_cons = get_connections(iris)
    pupil_cons = get_connections(pupil)
    if debug:
        show_connections(mask_cons)
        show_connections(iris_cons)
        show_connections(pupil_cons)
    if not (mask_cons and iris_cons and pupil_cons):
        return mask, no_circle, no_circle

    # 3. Form triplets.
    triplet_set = get_triplet_set(mask_cons, iris_cons, pupil_cons, debug=debug,
                                  chessboard_distance_threshold=chessboard_distance_threshold)
    if not triplet_set:
        return mask, no_circle, no_circle
    if debug:
        for triplet in triplet_set:
            show_triplet(triplet)

    # 4. Rank triplets by pixel count (stable sort: ties keep input order).
    ranked = sorted(triplet_set, key=_triplet_pixel_count)
    max_triplet = ranked[-1]

    # 5a. Outer circle: fit all pixels of the best iris component, OR-ed with
    # the runner-up's iris component when there is one (presumably to recover
    # an outer boundary split across components - TODO confirm intent).
    best_iris = max_triplet[1]
    if len(ranked) > 1:
        best_iris = cv2.bitwise_or(best_iris, ranked[-2][1])
    _, best_iris = cv2.threshold(best_iris, 0, 255, cv2.THRESH_BINARY)
    ys, xs = np.nonzero(best_iris)
    iris_x, iris_y, iris_r = (int(v) for v in least_square_circle_fitting(np.column_stack((xs, ys))))

    # 5b. Inner circle: fit the external contour of the best pupil component.
    _, best_pupil = cv2.threshold(max_triplet[2], 0, 255, cv2.THRESH_BINARY)
    pupil_contours = _find_contours(best_pupil, cv2.RETR_EXTERNAL)
    inner_points = np.squeeze(pupil_contours[0], axis=1)
    pupil_x, pupil_y, pupil_r = (int(v) for v in least_square_circle_fitting(inner_points))

    # 6. Refine the mask: keep it only between the inner and outer circles.
    circle_mask = np.zeros(mask.shape, dtype=np.uint8)
    cv2.circle(circle_mask, (iris_x, iris_y), iris_r, (255, 255, 255), -1)
    cv2.circle(circle_mask, (pupil_x, pupil_y), pupil_r, (0, 0, 0), -1)
    mask = cv2.bitwise_and(mask, circle_mask)

    return mask, (iris_x, iris_y, iris_r), (pupil_x, pupil_y, pupil_r)


def normalize_image(image, irisCenterX, irisCenterY, irisR, pupilCenterX, pupilCenterY, pupilR, width, height):
    """Unwrap the annulus between the pupil and iris circles into a rectangle.

    For each of ``width`` angles in [0, 2*pi], samples are taken along the ray
    from the pupil centre, between the pupil circle and the iris circle.
    The ray's end on the iris circle accounts for the offset between the two
    centres. Sampling is nearest-pixel: coordinates are truncated toward zero,
    then clamped to the image.

    Args:
        image: 2-D image to sample.
        irisCenterX, irisCenterY, irisR: iris (outer) circle.
        pupilCenterX, pupilCenterY, pupilR: pupil (inner) circle.
        width: number of angular samples (output columns).
        height: number of radial samples (output rows).

    Returns:
        Array of shape (height, width) with ``image``'s dtype.
    """
    real_height = height + 2
    angle_divisions = width - 1
    rows, cols = image.shape[0:2]

    # linspace gives exactly `width` angles; np.arange with a float step can
    # produce one extra sample because of rounding.
    thetas = np.linspace(0.0, 2 * np.pi, angle_divisions + 1)

    # Offset of the pupil centre from the iris centre.
    ox = pupilCenterX - irisCenterX
    oy = pupilCenterY - irisCenterY
    sgn = 1 if ox > 0 or (ox == 0 and oy > 0) else -1
    phi = np.pi / 2 if ox == 0 else math.atan(oy / ox)

    # Distance from the pupil centre to the iris circle along each angle
    # (law of cosines). The discriminant is negative only when the pupil
    # centre lies outside the iris circle; clamp it to avoid NaN radii.
    a = float(ox * ox + oy * oy)
    b = sgn * np.cos(np.pi - phi - thetas)
    disc = np.maximum(a * b ** 2 - (a - irisR ** 2), 0.0)
    r = np.sqrt(a) * b + np.sqrt(disc)

    # Radial fractions k / real_height for k = 0..real_height-1. They are
    # built from integers so the length is exact.
    fractions = np.arange(real_height) / real_height
    r_matrix = np.outer(fractions, r - pupilR) + pupilR
    # Drop the first row (on the pupil circle) and the last row (nearest the iris circle).
    r_matrix = r_matrix[1:real_height - 1, :]

    xo = pupilCenterX + r_matrix * np.cos(thetas)
    yo = pupilCenterY - r_matrix * np.sin(thetas)

    # Clamp indices so off-image samples do not wrap around via negative
    # indexing and do not raise IndexError.
    xo = np.clip(xo.astype(int), 0, cols - 1)
    yo = np.clip(yo.astype(int), 0, rows - 1)
    return image[yo, xo]


if __name__ == '__main__':
    chessboard_distance_threshold = 20
    debug = False
    image_dir = '/home/ubuntu/iris_image/iris-test/image/'
    mask_dir = '/home/ubuntu/iris_image/iris-test/vgg16_unet_multitask_attention/epoch400_mask/'
    iris_dir = '/home/ubuntu/iris_image/iris-test/vgg16_unet_multitask_attention/epoch400_iris/'
    pupil_dir = '/home/ubuntu/iris_image/iris-test/vgg16_unet_multitask_attention/epoch400_pupil/'
    mask_extension = 'bmp'
    width = 512
    height = 64

    for filename in os.listdir(image_dir):
        print(filename)
        label_name = os.path.splitext(filename)[0] + '.' + mask_extension

        image = cv2.imread(os.path.join(image_dir, filename), 0)
        mask = cv2.imread(os.path.join(mask_dir, label_name))
        iris = cv2.imread(os.path.join(iris_dir, label_name), 0)
        pupil = cv2.imread(os.path.join(pupil_dir, label_name), 0)
        if image is None or mask is None or iris is None or pupil is None:
            print('skipping', filename, ': could not read the image or one of its predicted maps')
            continue

        mask, iris_circle, pupil_circle = post_processing_image(
            mask, iris, pupil, debug=debug,
            chessboard_distance_threshold=chessboard_distance_threshold)
        iris_x, iris_y, iris_r = iris_circle
        pupil_x, pupil_y, pupil_r = pupil_circle

        # Draw on a copy so the circles are not baked into the pixels that
        # normalize_image samples below.
        preview = image.copy()
        cv2.circle(preview, (iris_x, iris_y), iris_r, (0, 0, 255), 3)
        cv2.circle(preview, (pupil_x, pupil_y), pupil_r, (0, 0, 255), 3)
        cv2.imshow('111', preview)
        cv2.waitKey(0)

        # Skip when the pupil centre falls outside the iris circle's bounding square.
        if abs(pupil_x - iris_x) > iris_r or abs(pupil_y - iris_y) > iris_r:
            continue

        mask = mask[:, :, 0]
        norm_image = normalize_image(image, iris_x, iris_y, iris_r, pupil_x, pupil_y, pupil_r, width, height)
        norm_mask = normalize_image(mask, iris_x, iris_y, iris_r, pupil_x, pupil_y, pupil_r, width, height)