"""TCP server that receives a raw eye image and replies with its iris code.

Wire protocol, per connection (one request, then the server closes):

1. The client sends ``wid * hig * channels`` raw ``uint8`` bytes, which are
   reshaped to ``(hig, wid, channels)``.
2. The server replies with the ASCII decimal element count of the code
   (``"0"`` when no code could be produced), immediately followed by the
   flattened ``uint8`` code and the flattened ``uint8`` normalised mask.
"""
import socket
import sys
import time

import cv2
import numpy as np
import torch
from torchvision.models.vgg import vgg16_bn, vgg16

from model.unet_vgg16_multitask_attention import UnetWithVGG16Attention
from iris_location import location
from iris_recognize import process_eye

HOST = '0.0.0.0'
PORT = 50007
wid = 640
hig = 480
channels = 3
scale_factor = 0.5
model_path = 'checkpoints/unet_vgg16_multitask_attention_epoch600.pth'


def open_socket_server(host, port, listen=10):
    """Create a listening TCP socket bound to ``(host, port)``.

    Args:
        host: Interface address to bind to.
        port: TCP port to bind to.
        listen: Backlog passed to ``socket.listen``.

    Returns:
        The bound, listening server socket.

    On any socket error the message is printed and the process exits with
    status 1, so this function never returns ``None``.
    """
    try:
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Bind to the arguments, not the module-level HOST/PORT constants.
        server.bind((host, port))
        server.listen(listen)
        return server
    except socket.error as msg:
        print(msg)
        sys.exit(1)


def load_model(model_dict_path, device):
    """Build the segmentation network and load its weights.

    Args:
        model_dict_path: Path of the state-dict checkpoint to load.
        device: ``torch.device`` the network is moved to and the checkpoint
            is mapped onto.

    Returns:
        The network with the checkpoint weights loaded.
    """
    encode_vgg16 = vgg16()
    net = UnetWithVGG16Attention(encoder=encode_vgg16, n_classes=1, bilinear=True)
    net.to(device=device)
    # Load from the path that was passed in rather than the global model_path.
    net.load_state_dict(torch.load(model_dict_path, map_location=device))
    # NOTE(review): net.eval() is not called here; confirm that location()
    # switches the network to inference mode.
    return net


def _recv_exact(conn, size):
    """Read exactly ``size`` bytes from ``conn``.

    Returns the bytes, or ``None`` if the peer closed the connection before
    ``size`` bytes arrived (``recv`` returning ``b''``), which would otherwise
    make the read loop spin forever.
    """
    chunks = []
    remaining = size
    while remaining:
        chunk = conn.recv(remaining)
        if not chunk:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def _handle_connection(conn, net, device):
    """Serve one request on ``conn``: receive an image, send code and mask.

    The caller owns ``conn`` and is responsible for closing it.
    """
    time1 = time.time()

    # Receive the raw eye image.
    buf = _recv_exact(conn, int(wid * hig * channels))
    if buf is None:
        print('client closed the connection before a full image was received')
        return
    image = np.frombuffer(buf, dtype='uint8').reshape(hig, wid, channels)

    mask, iris, pupil = location(net, image, device, scale_factor=scale_factor)
    norm_image, norm_mask, encode = process_eye(
        image, mask, iris, pupil, scale_factor=scale_factor)
    time2 = time.time()
    print(time2 - time1)
    # Debug aid: cv2.imshow("encode", encode); cv2.waitKey(0)

    if encode is None:
        print('encode None')
        conn.sendall(str(0).encode())
        return

    print('encode shape', encode.shape)
    print('mask shape', norm_mask.shape)
    conn.sendall(str(encode.shape[0] * encode.shape[1]).encode())

    # sendall() transmits the whole buffer; re-calling send() with the full
    # buffer after a partial send would restart from byte 0 and corrupt the
    # stream.
    code_bytes = np.array(encode).flatten().astype(np.uint8).tobytes()
    conn.sendall(code_bytes)
    print("encode has been send", len(code_bytes))

    mask_bytes = np.array(norm_mask).flatten().astype(np.uint8).tobytes()
    conn.sendall(mask_bytes)
    print("mask has been send", len(mask_bytes))


def main():
    """Start the server, load the model and serve connections forever."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    socket_server = open_socket_server(HOST, PORT)
    if socket_server is None:
        print('open socket server error, return!')
        sys.exit(1)
    print('socket server open ...')
    net = load_model(model_path, device=device)
    print('success load model ...')

    while True:
        # Accept a TCP connection; returns a new socket and the peer address.
        conn, addr = socket_server.accept()
        print('Connected by', addr)  # print the client's address
        try:
            _handle_connection(conn, net, device)
        except ConnectionError as err:
            # A dropped client must not take the whole server down.
            print('connection error:', err)
        finally:
            # Close on every path, including the "no code produced" one.
            conn.close()


if __name__ == '__main__':
    main()