# casic_iris_recognize / iris_server.py
# Last commit: zhangyingjie, 30 Sep 2021 (commit message: "socket"; file size 3 KB).
import socket
import sys
import time

import cv2
import numpy as np
import torch
from torchvision.models.vgg import vgg16_bn, vgg16
from model.unet_vgg16_multitask_attention import UnetWithVGG16Attention
from iris_location import location
from iris_recognize import process_eye

# Listening endpoint of the TCP server ('0.0.0.0' = every IPv4 interface).
HOST, PORT = '0.0.0.0', 50007

# Geometry of each incoming frame: the client sends hig * wid * channels raw
# uint8 bytes, which the server reshapes to (hig, wid, channels).
wid, hig, channels = 640, 480, 3

# Forwarded as ``scale_factor`` to location() and process_eye(); its exact
# meaning lives in those helpers (presumably a resize ratio -- TODO confirm).
scale_factor = 0.5

# Weights checkpoint handed to load_model() at start-up.
model_path = 'checkpoints/unet_vgg16_multitask_attention_epoch600.pth'


def open_socket_server(host, port, listen=10):
    """Create a TCP/IPv4 server socket bound to ``(host, port)`` and listening.

    Args:
        host: Interface address to bind, e.g. ``'0.0.0.0'`` for all IPv4
            interfaces.
        port: TCP port to bind; ``0`` lets the OS choose a free port.
        listen: Backlog passed to ``socket.listen``.

    Returns:
        The listening ``socket.socket``.

    On ``socket.error`` (``OSError``) the message is printed, the partially
    created socket is closed, and the process exits with status 1; the
    function therefore never returns ``None``.
    """
    server = None
    try:
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Use the arguments (not module-level globals) so callers control
        # which endpoint is bound.
        server.bind((host, port))
        server.listen(listen)
        return server
    except socket.error as msg:
        print(msg)
        if server is not None:
            # Release the file descriptor before exiting.
            server.close()
        sys.exit(1)


def load_model(model_dict_path, device):
    """Build the UNet/VGG16 attention network and load trained weights into it.

    Args:
        model_dict_path: Path to a ``state_dict`` checkpoint readable by
            ``torch.load``.
        device: ``torch.device`` on which the network and weights are placed.

    Returns:
        The network on ``device``, switched to inference (``eval``) mode.
    """
    # The VGG16 encoder is created without pretrained weights; the checkpoint
    # below supplies all parameters.
    encode_vgg16 = vgg16()
    net = UnetWithVGG16Attention(encoder=encode_vgg16, n_classes=1, bilinear=True)
    net.to(device=device)
    # Load from the caller-supplied path. map_location remaps tensors saved on
    # another device (e.g. GPU) onto ``device``.
    state_dict = torch.load(model_dict_path, map_location=device)
    net.load_state_dict(state_dict)
    # The server only runs inference: use running BatchNorm statistics and
    # disable dropout, if the architecture contains them.
    net.eval()
    return net


def _recv_exact(conn, size):
    """Read exactly ``size`` bytes from ``conn`` into a new ``bytearray``.

    Raises:
        ConnectionError: if the peer closes the connection (``recv_into``
            returns 0) before ``size`` bytes have arrived.
    """
    buf = bytearray(size)
    received = 0
    with memoryview(buf) as view:
        while received < size:
            n = conn.recv_into(view[received:], size - received)
            if n == 0:
                # EOF: without this check an early disconnect loops forever.
                raise ConnectionError(
                    'peer closed after %d of %d bytes' % (received, size))
            received += n
    return buf


def _send_array(conn, array):
    """Send ``array`` flattened (C order) as raw uint8 bytes; return bytes sent."""
    payload = np.asarray(array).astype(np.uint8).tobytes()
    # sendall retries partial writes and resumes where the last one stopped.
    conn.sendall(payload)
    return len(payload)


def _handle_client(conn, net, device):
    """Serve one request on ``conn``: receive an image, encode it, reply.

    Wire protocol:
      request  -- exactly ``hig * wid * channels`` raw uint8 bytes, reshaped
                  to ``(hig, wid, channels)``;
      response -- ASCII decimal ``encode.shape[0] * encode.shape[1]`` ('0'
                  when no code was produced); if non-zero it is followed by
                  the flattened code bytes and then the flattened mask bytes.
    """
    time1 = time.time()

    count = int(wid * hig * channels)
    buf = _recv_exact(conn, count)
    image = np.frombuffer(buf, dtype='uint8').reshape(hig, wid, channels)

    mask, iris, pupil = location(net, image, device, scale_factor=scale_factor)
    _, norm_mask, encode = process_eye(image, mask, iris, pupil, scale_factor=scale_factor)

    time2 = time.time()
    print(time2 - time1)

    if encode is None:
        print('encode None')
        conn.sendall(str(0).encode())
        return

    print('encode shape', encode.shape)
    print('mask shape', norm_mask.shape)
    # NOTE: the length header has no delimiter; the client must separate it
    # from the payload that follows. Kept as-is for client compatibility.
    conn.sendall(str(encode.shape[0] * encode.shape[1]).encode())

    print("encode has been send", _send_array(conn, encode))
    print("mask has been send", _send_array(conn, norm_mask))


if __name__ == '__main__':
    import traceback

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    socket_server = open_socket_server(HOST, PORT)
    if socket_server is None:
        print('open socket server error, return!')
        sys.exit(1)
    print('socket server open ...')

    net = load_model(model_path, device=device)
    print('success load model ...')

    with socket_server:
        while True:
            # Accept a TCP connection; returns a new socket and the client address.
            conn, addr = socket_server.accept()
            print('Connected by', addr)  # log the client's address
            try:
                # ``with`` closes the connection on every path: the normal
                # reply, the early "encode None" reply, and errors.
                with conn:
                    _handle_client(conn, net, device)
            except Exception:
                # Top-level boundary: one bad request must not stop the server.
                traceback.print_exc()