# casic_iris_recognize / iris_recognize.py
# zhangyingjie, 30 Sep 2021 (commit "socket"), 6 KB
import os
import time

import cv2
import numpy as np
import torch
from torchvision.models.vgg import vgg16_bn, vgg16
from model.unet_vgg16_multitask_attention import UnetWithVGG16Attention

from iris_location import location
from iris_encode import add_borders, encode_image,get_gabor_filters
from post_processing import post_processing_image, normalize_image

# Downscale factor handed to ``location`` when this file runs as a script.
scale_factor: float = 0.5

def get_application_points(point_path, width, height):
    """Build a binary mask of the positions listed in a point file.

    The first line of the file is skipped (treated as a header). Every
    following non-blank line must start with two whitespace-separated
    integers ``row col``; each listed position is set to 255 in the result,
    every other position stays 0.

    Args:
        point_path: path to the text file of points.
        width: number of columns of the returned mask.
        height: number of rows of the returned mask.

    Returns:
        ``np.ndarray`` of shape ``(height, width)`` and dtype ``uint8``.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a data line does not start with two integers.
        IndexError: if a position lies outside ``height`` x ``width``
            (negative values wrap around, as with any NumPy index).
    """
    application_points = np.zeros((height, width), dtype=np.uint8)
    # ``with`` guarantees the file is closed even if parsing fails.
    with open(point_path, 'r') as point_file:
        next(point_file, None)  # skip the header line
        for line in point_file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. a trailing newline at EOF.
                continue
            application_points[int(fields[0]), int(fields[1])] = 255

    return application_points


def process_eye(image, mask, iris, pupil, scale_factor=0.5, out=False, out_path=None):
    """Normalize and encode one eye from its segmentation outputs.

    The image is resized by ``scale_factor`` (nearest neighbour). The
    segmentation outputs are refined with ``post_processing_image`` into an
    iris circle and a pupil circle. The first channel of the resized image
    and the refined mask are then normalized with ``normalize_image``
    (width=512, height=64), and the normalized image is encoded with the
    filters from ``get_gabor_filters``.

    Args:
        image: input image with a channel axis (indexed as ``[:, :, 0]``);
            presumably BGR as returned by ``cv2.imread`` -- TODO confirm.
        mask, iris, pupil: segmentation outputs forwarded unchanged to
            ``post_processing_image``. The refined mask is blended with the
            resized image, so it is assumed to be at the resized scale.
        scale_factor: resize factor applied to ``image``.
        out: when True, write a debug overlay (JET-coloured mask blended
            0.7/0.3 with the resized image, plus both circles) to ``out_path``.
            The overlay is written before the validity checks, so rejected
            eyes can be inspected too.
        out_path: destination of the debug overlay; required when ``out``.

    Returns:
        ``(norm_image, norm_mask, encode)``, or ``(None, None, None)`` when
        either radius is 0 or the pupil centre lies outside the square that
        bounds the iris circle.

    Raises:
        ValueError: if ``out`` is True but ``out_path`` is None.
        AssertionError: if the resized image would have a zero dimension.
    """
    # Fail before doing any work instead of inside cv2.imwrite later on.
    if out and out_path is None:
        raise ValueError('out_path is required when out=True')

    h, w = image.shape[:2]
    new_w, new_h = int(scale_factor * w), int(scale_factor * h)
    assert new_w > 0 and new_h > 0, 'Scale is too small'
    image_resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    mask, iris_circle, pupil_circle = post_processing_image(mask, iris, pupil, debug=False)

    iris_x, iris_y, iris_r = iris_circle[0], iris_circle[1], iris_circle[2]
    pupil_x, pupil_y, pupil_r = pupil_circle[0], pupil_circle[1], pupil_circle[2]

    if out:
        mask_color = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
        show_image = cv2.addWeighted(image_resized, 0.7, mask_color, 0.3, 0)
        # cv2.circle expects integer centre/radius; cast in case the circles
        # come back as floats or NumPy scalars.
        cv2.circle(show_image, (int(iris_x), int(iris_y)), int(iris_r), (0, 0, 255), 1)
        cv2.circle(show_image, (int(pupil_x), int(pupil_y)), int(pupil_r), (0, 0, 255), 1)
        cv2.imwrite(out_path, show_image)

    # A zero radius means one of the circles was not found.
    if iris_r == 0 or pupil_r == 0:
        return None, None, None

    # Reject geometrically implausible results: pupil centre outside the
    # square bounding the iris circle.
    if pupil_x < iris_x - iris_r or pupil_x > iris_x + iris_r or pupil_y < iris_y - iris_r or pupil_y > iris_y + iris_r:
        return None, None, None

    width = 512
    height = 64
    single_image = image_resized[:, :, 0]  # first channel only
    norm_image = normalize_image(single_image, iris_x, iris_y, iris_r, pupil_x, pupil_y, pupil_r, width, height)
    norm_mask = normalize_image(mask, iris_x, iris_y, iris_r, pupil_x, pupil_y, pupil_r, width, height)

    gabor_filters = get_gabor_filters()
    encode = encode_image(norm_image, gabor_filters)

    return norm_image, norm_mask, encode

def match(code1, code2, norm_mask1, norm_mask2, application_points=None, shift=10):
    """Return the best (lowest) masked XOR distance between two iris codes.

    ``code1`` is passed through ``add_borders(code1, shift)``. A window of
    ``code1``'s width is then taken at every column offset in
    ``[-shift, shift]`` and XOR-ed with ``code2``, counting only positions
    that are valid in both normalized masks (and in ``application_points``
    when given). Each offset is scored as the sum of the masked XOR image
    divided by the sum of the mask. That is a fractional Hamming distance
    when codes and masks use the same "on" value (e.g. 255) -- TODO confirm.

    Args:
        code1, code2: iris codes of equal shape. Their height must be an
            integer multiple of the normalized-mask height (bands stacked
            vertically).
        norm_mask1, norm_mask2: normalized masks belonging to each code.
        application_points: optional 8-bit mask restricting the compared
            positions (same shape as the normalized masks).
        shift: largest column offset tried in each direction.

    Returns:
        The minimum score over all offsets (lower = more similar). Returns
        1.0 when no position is valid in both masks.

    Raises:
        ValueError: if the code height is not a positive multiple of the
            mask height.
    """
    valid = cv2.bitwise_and(norm_mask1, norm_mask2, mask=application_points)

    # The code stacks several bands vertically (6 in this project, presumably
    # one per Gabor filter -- TODO confirm). Repeat the mask once per band so
    # it covers the whole code.
    bands, remainder = divmod(code1.shape[0], valid.shape[0])
    if bands == 0 or remainder:
        raise ValueError('code height %d is not a multiple of mask height %d'
                         % (code1.shape[0], valid.shape[0]))
    total_mask = np.vstack([valid] * bands)

    # The denominator does not depend on the offset, so compute it once.
    valid_total = cv2.sumElems(total_mask)[0]
    if valid_total == 0:
        # Nothing comparable: report the worst score instead of dividing by 0.
        return 1.0

    # NOTE(review): only code1 is shifted. The masks are not shifted with it,
    # so at non-zero offsets code1's own mask no longer lines up with the
    # shifted code1 columns -- confirm this is intended.
    shifted = add_borders(code1, shift)
    width = code1.shape[1]
    score = 1
    for offset in range(-shift, shift + 1):
        roi = shifted[:, shift + offset:shift + offset + width]
        result = cv2.bitwise_xor(roi, code2, mask=total_mask)
        score = min(score, cv2.sumElems(result)[0] / valid_total)

    return score

if __name__ == '__main__':
    # load application points
    point_path = '/home/ubuntu/points.txt'
    points = get_application_points(point_path,512,64)
    # load model
    model_path = ('checkpoints/unet_vgg16_multitask_attention_epoch600.pth')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vgg16 = vgg16()
    net = UnetWithVGG16Attention(encoder=vgg16, n_classes=1, bilinear=True)
    net.to(device=device)
    net.load_state_dict(torch.load(model_path, map_location=device))

    inclass = 0
    outclass = 0

    true_accept = 0
    false_accept = 0
    true_reject = 0
    false_reject = 0

    score_threshold = 0.32

    image_dir = '/home/ubuntu/iris_image/casic_iris_origin/image/'
    image_list = os.listdir(image_dir)

    norm_img_dir = '/home/ubuntu/iris_image/casic_iris_origin/norm_image/'
    if not os.path.exists(norm_img_dir):
        os.makedirs(norm_img_dir)

    norm_mask_dir = '/home/ubuntu/iris_image/casic_iris_origin/norm_mask/'
    if not os.path.exists(norm_mask_dir):
        os.makedirs(norm_mask_dir)

    code_dir = '/home/ubuntu/iris_image/casic_iris_origin/code/'
    if not os.path.exists(code_dir):
        os.makedirs(code_dir)

    out = True
    out_dir = '/home/ubuntu/iris_image/casic_iris_origin/output/'
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    for image_name in image_list:
        print("processing", image_name)
        image1 = cv2.imread(image_dir + image_name)
        mask1, iris1, pupil1 = location(net, image1, device, scale_factor=scale_factor)
        out_path = out_dir + image_name
        norm_image1, norm_mask1, encode1 = process_eye(image1, mask1, iris1, pupil1, out=out, out_path=out_path)
        if encode1 is not None:
            cv2.imwrite(norm_img_dir + image_name, norm_image1)
            cv2.imwrite(norm_mask_dir + image_name, norm_mask1)
            cv2.imwrite(code_dir + image_name, encode1)

    for image_name1 in image_list:
        print("compare with " + image_name1 + "===========================================")


        for image_name2 in image_list:
            if image_name1 == image_name2:
                continue

            encode1 = cv2.imread(code_dir + image_name1, 0)
            encode2 = cv2.imread(code_dir + image_name2, 0)

            norm_mask1 = cv2.imread(norm_mask_dir + image_name1, 0)
            norm_mask2 = cv2.imread(norm_mask_dir + image_name2, 0)

            if encode1 is None or encode2 is None:
                continue

            score = match(encode1,encode2,norm_mask1,norm_mask2,points)
            # e_time = time.time()

            out_info = image_name2 + "  " + str(score) + "  "

            # print(image_name2, score)

            file_match = (image_name1[0:image_name1.find("eye")] == image_name2[0:image_name2.find("eye")])
            # file_match = (image_name1[0:6] == image_name2[0:6])
            # file_match = (image_name1[0:7]==image_name2[0:7]) and (image_name1[16:17]==image_name2[16:17])

            if file_match:
                inclass += 1
                if score <= score_threshold:
                    true_accept += 1
                else:
                    false_reject += 1
                    out_info += "   " + "false_reject"
                    print(out_info)
            else:
                outclass += 1
                if score <= score_threshold:
                    false_accept += 1
                    out_info += "   " + "false_accept"
                    print(out_info)
                else:
                    true_reject += 1



    print('inclass',inclass)
    print('outclass', outclass)
    print('false_reject',false_reject)
    print('false_accept',false_accept)
    print('true_accept',true_accept)
    print('true_reject',true_reject)
    print('frr', false_reject/ inclass)
    print('far', false_accept/outclass)
    print('crr',(true_reject+true_accept)/(inclass+outclass))