# Newer
# Older
# casic_iris_recognize / iris_test.py
# zhangyingjie on 23 Sep 2021 1 KB first commit
import os
import time

import cv2
import numpy as np
import torch
from torchvision.models.vgg import vgg16_bn, vgg16
from model.unet_vgg16_multitask_attention import UnetWithVGG16Attention

from iris_location import location
from iris_encode import add_borders, encode_image,get_gabor_filters
from post_processing import post_processing_image, normalize_image

if __name__ == '__main__':
    # Manual smoke test: segment one iris image with the trained
    # UnetWithVGG16Attention model, display the predicted pupil map before
    # and after binarization, then run post_processing_image on the result.
    model_path = 'checkpoints/unet_vgg16_multitask_attention_epoch600.pth'
    image_path = '/home/ubuntu/iris_image/debug_image/1-2-19-righteye2.bmp'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Keep the encoder in its own name so the imported `vgg16` factory is not
    # shadowed. NOTE(review): the encoder is presumably registered as a
    # submodule of the U-Net, so its weights are overwritten by the
    # checkpoint below. Confirm against UnetWithVGG16Attention.
    encoder = vgg16()
    net = UnetWithVGG16Attention(encoder=encoder, n_classes=1, bilinear=True)
    net.to(device=device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(model_path, map_location=device))
    # Inference only: put dropout/batch-norm layers (if any) into eval mode.
    net.eval()

    image1 = cv2.imread(image_path)
    if image1 is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising, so fail here with a clear message.
        raise FileNotFoundError(f'Could not read image: {image_path}')

    # No gradients are needed for inference.
    with torch.no_grad():
        mask1, iris1, pupil1 = location(net, image1, device, scale_factor=0.5)

    cv2.imshow('pupil', pupil1)
    cv2.waitKey(0)

    # Binarize the pupil map: pixels above 200 become 255, all others 0.
    # The threshold value returned by cv2.threshold is not needed.
    _, pupil1 = cv2.threshold(pupil1, 200, 255, cv2.THRESH_BINARY)

    cv2.imshow('pupil', pupil1)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    norm_image1, norm_mask1, encode1 = post_processing_image(
        mask1, iris1, pupil1, debug=False)