# casic_iris_recognize / post_processing.py
# Author: zhangyingjie (30 Sep 2021)
import time

import cv2
import math
import numpy as np
import os
import traceback

def single_channel_to_multi_channel(arr):
    """Replicate a single-channel image into three identical channels.

    Arrays that already have three or more dimensions are returned unchanged
    (the very same object). Otherwise a trailing channel axis is appended and
    the data is repeated three times along it.
    """
    if len(arr.shape) >= 3:
        return arr
    with_channel_axis = np.expand_dims(arr, axis=2)
    return np.repeat(with_channel_axis, 3, axis=-1)


def get_connections(img):
    """Return the 8-connected foreground components of a binary image.

    Inputs with more than two dimensions are first converted with
    ``cv2.COLOR_RGB2GRAY``. Each component is returned as its own uint8 array
    (1 inside the component, 0 elsewhere), in increasing label order. The
    background label 0 is not included.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) if img.ndim > 2 else img
    labels = cv2.connectedComponents(gray, connectivity=8)[1]
    label_count = labels.max() + 1
    return [(labels == label).astype(np.uint8) for label in range(1, label_count)]


def show_connections(cons):
    """Show every connected component, one at a time, in the 'con' window.

    Each component is expanded to three channels and binarised to 0/255
    before display. A key press moves on to the next component. All OpenCV
    windows are destroyed once the last component has been shown.
    """
    for region in cons:
        three_channel = single_channel_to_multi_channel(region)
        display = cv2.threshold(three_channel, 0, 255, cv2.THRESH_BINARY)[1]
        cv2.imshow('con', display)
        cv2.waitKey(0)
    cv2.destroyAllWindows()


def get_triplet_set(mask_cons, iris_cons, pupil_cons, debug=False, chessboard_distance_threshold=20):
    """Collect adjacent (mask, iris, pupil) component triplets.

    A combination is kept when the mask component lies within
    ``chessboard_distance_threshold`` (chessboard distance, in pixels) of both
    the iris component and the pupil component.

    Args:
        mask_cons, iris_cons, pupil_cons: lists of single-channel region
            images, as returned by ``get_connections``.
        debug: forwarded to ``get_chessboard_distance`` (shows the regions).
        chessboard_distance_threshold: maximum allowed distance for both the
            mask-iris and the mask-pupil pair.

    Returns:
        List of ``(mask_con, iris_con, pupil_con)`` tuples, ordered by mask,
        then iris, then pupil position in the input lists.
    """
    threshold = chessboard_distance_threshold

    def _is_close(con_a, con_b):
        # Best effort: a pair whose distance cannot be computed (the traceback
        # is printed) is treated as "not close" rather than aborting the search.
        try:
            distance = get_chessboard_distance(
                con_a, con_b, debug=debug, chessboard_distance_threshold=threshold)
        except Exception:
            print(traceback.format_exc())
            return False
        return distance <= threshold

    triplet_set = []
    for mask_con in mask_cons:
        # The mask-pupil test does not depend on the iris component, so it is
        # evaluated once per mask instead of once per (mask, iris) pair. A
        # failure on one pupil excludes only that pupil.
        near_pupils = [pupil_con for pupil_con in pupil_cons
                       if _is_close(mask_con, pupil_con)]
        if not near_pupils:
            continue
        for iris_con in iris_cons:
            if _is_close(mask_con, iris_con):
                triplet_set.extend(
                    (mask_con, iris_con, pupil_con) for pupil_con in near_pupils)
    return triplet_set

def get_chessboard_distance(con1, con2, debug=False, chessboard_distance_threshold=20):
    """Return the chessboard (L-infinity) distance between the contours of two regions.

    The result is ``min over contour points p1 of con1, p2 of con2 of
    max(|x1 - x2|, |y1 - y2|)``. Only contour points are compared, so for
    overlapping or nested regions the result is the distance between their
    boundaries and need not be 0.

    Args:
        con1, con2: single-channel region images (non-zero = region), e.g.
            from ``get_connections``.
        debug: if True, show each region with its contours drawn and print
            the distance.
        chessboard_distance_threshold: unused; kept so existing callers that
            pass it keep working.

    Returns:
        The minimum distance in pixels (a NumPy integer).

    Raises:
        ValueError: if either image has no contour (an empty region), because
            there are no points to concatenate.
    """
    points1 = _contour_points(con1, 'con1', debug)
    points2 = _contour_points(con2, 'con2', debug)

    # Broadcast (N1, 1, 2) - (1, N2, 2) -> (N1, N2, 2). The max over the x/y axis
    # gives the pairwise chessboard distance; the global min is the answer.
    distance = np.abs(points1[:, None, :] - points2[None, :, :]).max(axis=2).min()

    if debug:
        print('distance', distance)
    return distance


def _contour_points(con, window_name, debug):
    """Binarise ``con`` and return all its contour points as an (N, 2) array of (x, y)."""
    _, binary = cv2.threshold(con, 0, 255, cv2.THRESH_BINARY)
    # [-2] is the contour list for both OpenCV 3.x (3 return values) and 4.x (2).
    contours = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    if debug:
        preview = single_channel_to_multi_channel(binary)
        cv2.drawContours(preview, contours, -1, (0, 0, 255), 3)
        cv2.imshow(window_name, preview)
        cv2.waitKey(0)

    # Each contour is (K, 1, 2); stack them and drop the singleton axis.
    return np.squeeze(np.concatenate(contours), axis=1)

def show_triplet(triplet):
    """Display the union of a (mask, iris, pupil) triplet in the 'triplet' window.

    Blocks until a key is pressed. The window is left open.
    """
    # Union via logical OR instead of summing: a uint8 sum can wrap around past
    # 255 and turn covered pixels back into 0. This also avoids shadowing the
    # builtin ``all``.
    combined = np.logical_or.reduce(triplet[:3]).astype(np.uint8)
    combined_c3 = single_channel_to_multi_channel(combined)
    _, triplet_img = cv2.threshold(combined_c3, 0, 255, cv2.THRESH_BINARY)
    cv2.imshow('triplet', triplet_img)
    cv2.waitKey(0)

def get_max_triplet(triplet_set):
    """Return the triplet whose three regions together cover the most pixels.

    Args:
        triplet_set: non-empty sequence of (mask_con, iris_con, pupil_con)
            arrays.

    Returns:
        The first triplet with the largest total number of non-zero pixels.
        When every triplet is empty, the first triplet is returned.

    Raises:
        IndexError: if ``triplet_set`` is empty.
    """
    if not triplet_set:
        raise IndexError('get_max_triplet: triplet_set is empty')
    # max() returns the first of several equally large triplets.
    return max(
        triplet_set,
        key=lambda triplet: sum(np.count_nonzero(region) for region in triplet[:3]),
    )

def least_square_circle_fitting(counter):
    """Fit a circle to 2-D points by algebraic least squares.

    Solves the normal equations for ``x^2 + y^2 + a*x + b*y + c = 0``. The
    circle then has centre ``(-a/2, -b/2)`` and radius
    ``sqrt(a^2 + b^2 - 4c) / 2``.

    Args:
        counter: the (x, y) points, either as a sequence of pairs or as an
            array reshapeable to (N, 2), e.g. an OpenCV contour.

    Returns:
        ``(center_x, center_y, radius)`` as floats. Returns ``(0.0, 0.0, 0.0)``
        when there are 3 or fewer points or the fit is degenerate (e.g. all
        points collinear).
    """
    if len(counter) <= 3:
        return 0.0, 0.0, 0.0

    # Work in float64. With integer inputs (e.g. int32 contour points) the cubic
    # terms x**3 and x*y**2 overflow for coordinates above roughly 1290.
    points = np.asarray(counter, dtype=np.float64).reshape(-1, 2)
    x = points[:, 0]
    y = points[:, 1]
    n = points.shape[0]
    x2 = x * x
    y2 = y * y

    sum_x, sum_y = x.sum(), y.sum()
    sum_x2, sum_y2 = x2.sum(), y2.sum()
    sum_x3, sum_y3 = (x2 * x).sum(), (y2 * y).sum()
    sum_xy = (x * y).sum()
    sum_x1y2 = (x * y2).sum()
    sum_x2y1 = (x2 * y).sum()

    C = n * sum_x2 - sum_x * sum_x
    D = n * sum_xy - sum_x * sum_y
    E = n * sum_x3 + n * sum_x1y2 - (sum_x2 + sum_y2) * sum_x
    G = n * sum_y2 - sum_y * sum_y
    H = n * sum_x2y1 + n * sum_y3 - (sum_x2 + sum_y2) * sum_y

    denominator = C * G - D * D
    if denominator == 0:
        # Singular system, e.g. every point on one line.
        return 0.0, 0.0, 0.0
    a = (H * D - E * G) / denominator
    b = (H * C - E * D) / -denominator
    c = -(a * sum_x + b * sum_y + sum_x2 + sum_y2) / n

    radius_term = a * a + b * b - 4 * c
    if not np.isfinite(radius_term) or radius_term < 0:
        return 0.0, 0.0, 0.0

    return float(a / -2), float(b / -2), float(math.sqrt(radius_term) / 2)

def _fit_circle_int(points):
    """Fit a circle to ``points`` and truncate it to ints; (0, 0, 0) if the fit fails."""
    try:
        center_x, center_y, radius = least_square_circle_fitting(points)
    except (ValueError, ZeroDivisionError):
        # e.g. collinear points (zero denominator) or a negative value under sqrt.
        return 0, 0, 0
    if not all(math.isfinite(value) for value in (center_x, center_y, radius)):
        return 0, 0, 0
    return int(center_x), int(center_y), int(radius)


def post_processing_image(mask, iris, pupil, debug=False):
    """Fit the iris (outer) and pupil (inner) circles and restrict the mask to the ring between them.

    Steps:
      1. Binarise the maps: mask > 127, iris > 90, pupil > 200.
      2. Close small gaps in the iris map (5x5 dilate then erode).
      3. Split each map into 8-connected components.
      4. Keep (mask, iris, pupil) component triplets that lie close together.
      5. Fit the outer circle to all pixels of the largest triplet's iris
         region, merged with the second largest triplet's iris region when
         there is one.
      6. Fit the inner circle to the pupil contour of the largest triplet.
      7. Zero the mask outside the outer circle and inside the inner circle.

    Args:
        mask, iris, pupil: segmentation maps. In this file's ``__main__`` they
            are loaded with ``cv2.imread``: mask as 3-channel, iris and pupil
            as grayscale.
        debug: show intermediate components and triplets in OpenCV windows.

    Returns:
        ``(mask, (iris_x, iris_y, iris_r), (pupil_x, pupil_y, pupil_r))`` with
        integer circle parameters.

        - When no components or no triplets are found, the binarised mask and
          two ``(0, 0, 0)`` circles are returned.
        - A circle whose fit is degenerate is returned as ``(0, 0, 0)``.
    """
    no_circle = (0, 0, 0)

    # 1. Binarise.
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    _, iris = cv2.threshold(iris, 90, 255, cv2.THRESH_BINARY)
    _, pupil = cv2.threshold(pupil, 200, 255, cv2.THRESH_BINARY)

    # Morphological closing on the iris (outer circle) map bridges small breaks.
    kernel = np.ones((5, 5), dtype=np.uint8)
    iris = cv2.erode(cv2.dilate(iris, kernel), kernel)

    # 2. 8-connected components.
    mask_cons = get_connections(mask)
    iris_cons = get_connections(iris)
    pupil_cons = get_connections(pupil)

    if debug:
        show_connections(mask_cons)
        show_connections(iris_cons)
        show_connections(pupil_cons)

    if not mask_cons or not iris_cons or not pupil_cons:
        return mask, no_circle, no_circle

    # 3. Form adjacent triplets.
    triplet_set = get_triplet_set(mask_cons, iris_cons, pupil_cons, debug=debug)
    if not triplet_set:
        return mask, no_circle, no_circle

    if debug:
        for triplet in triplet_set:
            show_triplet(triplet)

    # 4. Rank triplets by total pixel count. sorted() is stable, so among equal
    #    counts the triplet appearing later in triplet_set ranks higher.
    ranked = sorted(
        triplet_set,
        key=lambda triplet: (np.count_nonzero(triplet[0])
                             + np.count_nonzero(triplet[1])
                             + np.count_nonzero(triplet[2])))
    max_triplet = ranked[-1]

    # 5. Outer circle: least-squares fit over every pixel of the iris region.
    best_iris = max_triplet[1]
    if len(ranked) > 1:
        best_iris = cv2.bitwise_or(best_iris, ranked[-2][1])
    _, best_iris = cv2.threshold(best_iris, 0, 255, cv2.THRESH_BINARY)
    rows_idx, cols_idx = np.nonzero(best_iris)
    # Points as (x, y) = (column, row).
    iris_x, iris_y, iris_r = _fit_circle_int(np.column_stack((cols_idx, rows_idx)))

    # Inner circle: least-squares fit over the pupil region's outer contour.
    best_pupil = max_triplet[2]
    _, best_pupil = cv2.threshold(best_pupil, 0, 255, cv2.THRESH_BINARY)
    # [-2] is the contour list for both OpenCV 3.x (3 return values) and 4.x (2).
    contours = cv2.findContours(best_pupil, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    inner_points = np.squeeze(contours[0], axis=1)
    pupil_x, pupil_y, pupil_r = _fit_circle_int(inner_points)

    # 6. Refine the mask: keep only pixels inside the outer circle and outside the inner one.
    circle_mask = np.zeros(mask.shape, dtype=np.uint8)
    cv2.circle(circle_mask, (iris_x, iris_y), iris_r, (255, 255, 255), -1)
    cv2.circle(circle_mask, (pupil_x, pupil_y), pupil_r, (0, 0, 0), -1)
    mask = cv2.bitwise_and(mask, circle_mask)

    return mask, (iris_x, iris_y, iris_r), (pupil_x, pupil_y, pupil_r)

def normalize_image(image, irisCenterX, irisCenterY, irisR, pupilCenterX, pupilCenterY, pupilR, width, height):
    """Unwrap the ring between the pupil circle and the iris circle into a height x width grid.

    Sampling scheme:
      - Columns: ``width`` angles, evenly spaced over [0, 2*pi] with both ends
        included.
      - Rows: along each angle, the segment from the pupil circle to the iris
        circle (radii measured from the pupil centre) is sampled at fractions
        ``k / (height + 2)`` for k = 0 .. height + 1. The first and last
        samples are dropped, leaving ``height`` rows.
      - Sample coordinates are truncated to integers and clipped to the image
        bounds.

    Args:
        image: array indexed as ``image[row, col]``.
        irisCenterX, irisCenterY, irisR: the outer (iris) circle, in pixels.
        pupilCenterX, pupilCenterY, pupilR: the inner (pupil) circle, in pixels.
        width: number of angular samples (columns). Must be >= 2.
        height: number of radial samples (rows).

    Returns:
        Array of shape (height, width), followed by any trailing channel axes
        of ``image``, containing the sampled pixels.
    """
    realHeight = height + 2
    angledivisions = width - 1
    rows, cols = image.shape[0:2]

    # Build the grids from integer ranges. np.arange with a float step can
    # return one element too many through rounding, which breaks the
    # broadcasting below for some width/height values.
    thetas = np.arange(width) * (2 * np.pi / angledivisions)
    fractions = np.arange(realHeight) * (1 / realHeight)

    # Offset of the pupil centre from the iris centre.
    ox = pupilCenterX - irisCenterX
    oy = pupilCenterY - irisCenterY

    if ox > 0 or (ox == 0 and oy > 0):
        sgn = 1
    else:
        sgn = -1

    if ox == 0:
        phi = np.pi / 2
    else:
        phi = math.atan(oy / ox)

    # Distance from the pupil centre to the iris circle along each angle (ray/circle
    # intersection). NOTE: the square-root argument can be negative (-> NaN) when the
    # pupil centre lies outside the iris circle.
    dist_sq = ox * ox + oy * oy
    b = sgn * np.cos(np.pi - phi - thetas)
    boundary = np.sqrt(dist_sq) * b + np.sqrt(dist_sq * b ** 2 - (dist_sq - irisR ** 2))

    # Radius of every sample: pupilR + fraction * (boundary - pupilR).
    r_matrix = np.outer(fractions, boundary - pupilR) + pupilR
    r_matrix = r_matrix[1:realHeight - 1, :]

    xo = (pupilCenterX + r_matrix * np.cos(thetas)).astype(int)
    yo = (pupilCenterY - r_matrix * np.sin(thetas)).astype(int)

    # Keep samples inside the image. Without clipping, negative indices would
    # wrap to the opposite edge and indices past the edge would raise IndexError.
    xo = np.clip(xo, 0, cols - 1)
    yo = np.clip(yo, 0, rows - 1)

    return image[yo, xo]

if __name__ == '__main__':

    debug = False

    image_dir = '/home/ubuntu/iris_image/iris-test/image/'
    mask_dir = '/home/ubuntu/iris_image/iris-test/vgg16_unet_multitask_attention/epoch400_mask/'
    iris_dir = '/home/ubuntu/iris_image/iris-test/vgg16_unet_multitask_attention/epoch400_iris/'
    pupil_dir = '/home/ubuntu/iris_image/iris-test/vgg16_unet_multitask_attention/epoch400_pupil/'
    mask_extension = 'bmp'

    # Size of the normalised (unwrapped) iris image and mask.
    width = 512
    height = 64

    for filename in os.listdir(image_dir):

        print(filename)

        # The three segmentation maps share the image's base name but use mask_extension.
        map_name = os.path.splitext(filename)[0] + '.' + mask_extension

        image = cv2.imread(os.path.join(image_dir, filename), 0)
        mask = cv2.imread(os.path.join(mask_dir, map_name))
        iris = cv2.imread(os.path.join(iris_dir, map_name), 0)
        pupil = cv2.imread(os.path.join(pupil_dir, map_name), 0)

        # cv2.imread returns None (it does not raise) for missing or unreadable files.
        if image is None or mask is None or iris is None or pupil is None:
            print('skip', filename, '- could not read the image or one of its maps')
            continue

        mask, iris_circle, pupil_circle = post_processing_image(mask, iris, pupil, debug=debug)
        iris_x, iris_y, iris_r = iris_circle
        pupil_x, pupil_y, pupil_r = pupil_circle

        # Draw on a colour copy. Drawing on `image` itself would burn the circles
        # into the pixels normalised below, and (0, 0, 255) is only red in BGR.
        preview = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        cv2.circle(preview, (iris_x, iris_y), iris_r, (0, 0, 255), 3)
        cv2.circle(preview, (pupil_x, pupil_y), pupil_r, (0, 0, 255), 3)
        cv2.imshow('111', preview)
        cv2.waitKey(0)

        # Skip failed fits (zero radius) and pupils centred outside the iris bounding box.
        if iris_r <= 0 or pupil_r <= 0:
            continue
        if (pupil_x < iris_x - iris_r or pupil_x > iris_x + iris_r
                or pupil_y < iris_y - iris_r or pupil_y > iris_y + iris_r):
            continue

        norm_image = normalize_image(image, iris_x, iris_y, iris_r,
                                     pupil_x, pupil_y, pupil_r, width, height)
        norm_mask = normalize_image(mask[:, :, 0], iris_x, iris_y, iris_r,
                                    pupil_x, pupil_y, pupil_r, width, height)