# casic_iris_recognize/iris_location.py
import os
import logging
import time
import cv2
import torch
import numpy as np

import torch.nn.functional as F
from PIL import Image
from torchvision import transforms


def location(net, img, device, scale_factor=1):
    """Run the segmentation network on a single image and return 8-bit
    (0/255) maps for the eye mask, the iris region and the pupil region."""
    mask, iris, pupil = predict_img(net=net, image=img, device=device, scale_factor=scale_factor)
    mask = np.array(mask * 255, dtype=np.uint8)
    iris = np.array(iris * 255, dtype=np.uint8)
    pupil = np.array(pupil * 255, dtype=np.uint8)
    return mask, iris, pupil

def predict_img(net,
                image,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """Preprocess an OpenCV (BGR or grayscale) image, run the three-head
    segmentation network and return binary mask/iris/pupil maps thresholded
    at `out_threshold`."""
    net.eval()

    full_img = image
    h, w = image.shape[:2]
    new_w, new_h = int(scale_factor * w), int(scale_factor * h)
    assert new_w > 0 and new_h > 0, 'Scale is too small'
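    # Rescale the input by scale_factor before inference; nearest-neighbor
    # interpolation avoids introducing interpolated pixel values.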
    image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # Grayscale images get a channel axis; broadcasting against the
    # 3-element mean below turns them into a 3-channel input.
    if image.ndim == 2:
        image = image[:, :, None]

    # Subtract the ImageNet BGR channel means (VGG-style preprocessing).
    mean = np.array([103.939, 116.779, 123.68])
    image = image - mean

    # HWC -> CHW, then to a float tensor with a batch dimension.
    # Pixel values stay in the 0-255 range (no division by 255).
    image = image.transpose((2, 0, 1))
    image = torch.from_numpy(image)
    image = image.float()

    image = image.unsqueeze(0)
    image = image.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        mask_output, iris_output, pupil_output = net(image)

        # Multi-class heads use a softmax over the class dimension,
        # single-class heads a sigmoid.
        if net.n_classes > 1:
            mask_probs = F.softmax(mask_output, dim=1)
            iris_probs = F.softmax(iris_output, dim=1)
            pupil_probs = F.softmax(pupil_output, dim=1)
        else:
            mask_probs = torch.sigmoid(mask_output)
            iris_probs = torch.sigmoid(iris_output)
            pupil_probs = torch.sigmoid(pupil_output)


        # Drop the batch dimension.
        mask_probs = mask_probs.squeeze(0)
        iris_probs = iris_probs.squeeze(0)
        pupil_probs = pupil_probs.squeeze(0)

        # tf = transforms.Compose(
        #     [
        #         transforms.ToPILImage(),
        #         transforms.Resize(full_img.shape[:2]),
        #         transforms.ToTensor()
        #     ]
        # )

        # mask_probs = tf(mask_probs.cpu())
        # iris_probs = tf(iris_probs.cpu())
        # pupil_probs = tf(pupil_probs.cpu())

        # Move to CPU at the network input resolution; the disabled transform
        # above would have resized the maps back to the original image size.
        mask_probs = mask_probs.cpu()
        iris_probs = iris_probs.cpu()
        pupil_probs = pupil_probs.cpu()

        full_mask = mask_probs.squeeze().numpy()
        full_iris = iris_probs.squeeze().numpy()
        full_pupil = pupil_probs.squeeze().numpy()

    # Threshold all three probability maps to binary masks.
    return full_mask > out_threshold, full_iris > out_threshold, full_pupil > out_threshold



def mask_to_image(mask):
    """Convert a boolean or [0, 1] float mask to an 8-bit PIL image."""
    return Image.fromarray((mask * 255).astype(np.uint8))
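
# Example usage of location() (a minimal sketch; the checkpoint and image
# paths below are assumptions, and the model classes are those used in the
# commented-out script further down):
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     net = UnetWithVGG16Attention(encoder=vgg16(), n_classes=1, bilinear=True)
#     net.to(device=device)
#     net.load_state_dict(torch.load('checkpoints/iris_seg.pth', map_location=device))
#
#     img = cv2.imread('eye.bmp')
#     mask, iris, pupil = location(net, img, device, scale_factor=1)
#     cv2.imwrite('mask.bmp', mask)
#     cv2.imwrite('iris.bmp', iris)
#     cv2.imwrite('pupil.bmp', pupil)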


# if __name__ == '__main__':
#
#     # torch.set_num_threads(1)
#     print(torch.get_num_threads())
#     print(torch.get_num_interop_threads())
#
#     logging.getLogger().setLevel(logging.INFO)
#
#     if not os.path.exists(output_mask_dir):
#         os.makedirs(output_mask_dir)
#
#     if not os.path.exists(output_pupil_dir):
#         os.makedirs(output_pupil_dir)
#
#     if not os.path.exists(output_iris_dir):
#         os.makedirs(output_iris_dir)
#
#     vgg16 = vgg16()
#     net = UnetWithVGG16Attention(encoder=vgg16, n_classes=1, bilinear=True)
#
#     logging.info("Loading model {}".format(model_path))
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     logging.info('Using device {}'.format(device))
#     net.to(device=device)
#     net.load_state_dict(torch.load(model_path, map_location=device))
#
#     logging.info("Model loaded !")
#
#     in_files = os.listdir(input_dir)
#     all_start_time = time.time()
#
#     for i, fn in enumerate(in_files):
#
#         filename = fn
#         fn = input_dir + filename
#         logging.info("\nPredicting image {} ...".format(fn))
#
#         img = cv2.imread(fn)
#
#         start_time = time.time()
#         mask, iris, pupil = predict_img(net=net, image=img, device=device, scale_factor=scale)
#
#         out_mask_path = output_mask_dir + os.path.splitext(filename)[0] + '.bmp'
#         mask = mask_to_image(mask)
#         mask.save(out_mask_path)
#
#         out_iris_path = output_iris_dir + os.path.splitext(filename)[0] + '.bmp'
#         iris = mask_to_image(iris)
#         iris.save(out_iris_path)
#
#         out_pupil_path = output_pupil_dir + os.path.splitext(filename)[0] + '.bmp'
#         pupil = mask_to_image(pupil)
#         pupil.save(out_pupil_path)
#
#         end_time = time.time()
#         print(" time:" + str(end_time - start_time) + " sec")
#
#         logging.info("Mask saved to {}".format(out_mask_path))
#
#     all_end_time = time.time()
#     avg_time = (all_end_time - all_start_time) / len(in_files)
#     print("average time is %f seconds" % avg_time)