diff --git a/iris_recognize.py b/iris_recognize.py index 83d9934..fe842a1 100644 --- a/iris_recognize.py +++ b/iris_recognize.py @@ -25,7 +25,7 @@ return application_points -def process_eye(image, mask, iris, pupil,out = False, out_path = None): +def process_eye(image, mask, iris, pupil,scale_factor = 0.5, out = False, out_path = None): h, w = image.shape[:2] new_w, new_h = int(scale_factor * w), int(scale_factor * h) diff --git a/iris_recognize.py b/iris_recognize.py index 83d9934..fe842a1 100644 --- a/iris_recognize.py +++ b/iris_recognize.py @@ -25,7 +25,7 @@ return application_points -def process_eye(image, mask, iris, pupil,out = False, out_path = None): +def process_eye(image, mask, iris, pupil,scale_factor = 0.5, out = False, out_path = None): h, w = image.shape[:2] new_w, new_h = int(scale_factor * w), int(scale_factor * h) diff --git a/iris_server.py b/iris_server.py new file mode 100644 index 0000000..92b1de3 --- /dev/null +++ b/iris_server.py @@ -0,0 +1,108 @@ +import socket +import sys +import time + +import cv2 +import numpy as np +import torch +from torchvision.models.vgg import vgg16_bn, vgg16 +from model.unet_vgg16_multitask_attention import UnetWithVGG16Attention +from iris_location import location +from iris_recognize import process_eye + +HOST = '0.0.0.0' +PORT = 50007 + +wid = 640 +hig = 480 +channels = 3 +scale_factor = 0.5 +model_path = 'checkpoints/unet_vgg16_multitask_attention_epoch600.pth' + + +def open_socket_server(host, port, listen = 10): + try: + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.bind((HOST, PORT)) + server.listen(listen) + return server + except socket.error as msg: + print(msg) + sys.exit(1) + + +def load_model(model_dict_path, device): + encode_vgg16 = vgg16() + net = UnetWithVGG16Attention(encoder=encode_vgg16, n_classes=1, bilinear=True) + net.to(device=device) + net.load_state_dict(torch.load(model_path, map_location=device)) + return net + + +if __name__ == '__main__': + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + socket_server = open_socket_server(HOST,PORT) + if socket_server is None: + print('open socket server error, return!') + print('socket server open ...') + + net = load_model(model_path,device=device) + print('success load model ...') + + while 1: + conn, addr = socket_server.accept() # 接受TCP连接,并返回新的套接字与IP地址 + print('Connected by',addr) #输出客户端的IP地址 + + time1 = time.time() + + # recv eye image + buf = b'' + count = int(wid * hig * channels) + while count: + new_buf = conn.recv(count) + # if not newbuf: + # return None + buf += new_buf + count -= len(new_buf) + + data = np.frombuffer(buf, dtype='uint8') + # print(len(data)) + image = data.reshape(hig, wid, channels) + + mask, iris, pupil = location(net, image, device, scale_factor=scale_factor) + norm_image, norm_mask, encode = process_eye(image, mask, iris, pupil, scale_factor=scale_factor) + + time2 = time.time() + print(time2-time1) + + # cv2.imshow("encode",encode) + # cv2.waitKey(0) + + if encode is None: + print('encode None') + conn.send(str(0).encode()) + continue + else: + print('encode shape', encode.shape) + print('mask shape', norm_mask.shape) + conn.send(str(encode.shape[0]*encode.shape[1]).encode()) + + send_code_data = np.array(encode).flatten().astype(np.uint8) + send_code_size = encode.shape[0]*encode.shape[1] + send_code_total = 0 + while send_code_total < send_code_size: + send_code_len = conn.send(send_code_data,0) + send_code_total = send_code_total + send_code_len + print("encode has been send",send_code_total) + + send_mask_data = np.array(norm_mask).flatten().astype(np.uint8) + send_mask_size = norm_mask.shape[0]*norm_mask.shape[1] + send_mask_total = 0 + while send_mask_total < send_mask_size: + send_mask_len = conn.send(send_mask_data,0) + send_mask_total = send_mask_total + send_mask_len + print("mask has been send", send_mask_total) + + conn.close() \ No newline at end of file diff --git a/iris_recognize.py b/iris_recognize.py index 83d9934..fe842a1 100644 --- a/iris_recognize.py +++ b/iris_recognize.py @@ -25,7 +25,7 @@ return application_points -def process_eye(image, mask, iris, pupil,out = False, out_path = None): +def process_eye(image, mask, iris, pupil,scale_factor = 0.5, out = False, out_path = None): h, w = image.shape[:2] new_w, new_h = int(scale_factor * w), int(scale_factor * h) diff --git a/iris_server.py b/iris_server.py new file mode 100644 index 0000000..92b1de3 --- /dev/null +++ b/iris_server.py @@ -0,0 +1,108 @@ +import socket +import sys +import time + +import cv2 +import numpy as np +import torch +from torchvision.models.vgg import vgg16_bn, vgg16 +from model.unet_vgg16_multitask_attention import UnetWithVGG16Attention +from iris_location import location +from iris_recognize import process_eye + +HOST = '0.0.0.0' +PORT = 50007 + +wid = 640 +hig = 480 +channels = 3 +scale_factor = 0.5 +model_path = 'checkpoints/unet_vgg16_multitask_attention_epoch600.pth' + + +def open_socket_server(host, port, listen = 10): + try: + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.bind((HOST, PORT)) + server.listen(listen) + return server + except socket.error as msg: + print(msg) + sys.exit(1) + + +def load_model(model_dict_path, device): + encode_vgg16 = vgg16() + net = UnetWithVGG16Attention(encoder=encode_vgg16, n_classes=1, bilinear=True) + net.to(device=device) + net.load_state_dict(torch.load(model_path, map_location=device)) + return net + + +if __name__ == '__main__': + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + socket_server = open_socket_server(HOST,PORT) + if socket_server is None: + print('open socket server error, return!') + print('socket server open ...') + + net = load_model(model_path,device=device) + print('success load model ...') + + while 1: + conn, addr = socket_server.accept() # 接受TCP连接,并返回新的套接字与IP地址 + print('Connected by',addr) #输出客户端的IP地址 + + time1 = time.time() + + # recv eye image + buf = b'' + count = int(wid * hig * channels) + while count: + new_buf = conn.recv(count) + # if not newbuf: + # return None + buf += new_buf + count -= len(new_buf) + + data = np.frombuffer(buf, dtype='uint8') + # print(len(data)) + image = data.reshape(hig, wid, channels) + + mask, iris, pupil = location(net, image, device, scale_factor=scale_factor) + norm_image, norm_mask, encode = process_eye(image, mask, iris, pupil, scale_factor=scale_factor) + + time2 = time.time() + print(time2-time1) + + # cv2.imshow("encode",encode) + # cv2.waitKey(0) + + if encode is None: + print('encode None') + conn.send(str(0).encode()) + continue + else: + print('encode shape', encode.shape) + print('mask shape', norm_mask.shape) + conn.send(str(encode.shape[0]*encode.shape[1]).encode()) + + send_code_data = np.array(encode).flatten().astype(np.uint8) + send_code_size = encode.shape[0]*encode.shape[1] + send_code_total = 0 + while send_code_total < send_code_size: + send_code_len = conn.send(send_code_data,0) + send_code_total = send_code_total + send_code_len + print("encode has been send",send_code_total) + + send_mask_data = np.array(norm_mask).flatten().astype(np.uint8) + send_mask_size = norm_mask.shape[0]*norm_mask.shape[1] + send_mask_total = 0 + while send_mask_total < send_mask_size: + send_mask_len = conn.send(send_mask_data,0) + send_mask_total = send_mask_total + send_mask_len + print("mask has been send", send_mask_total) + + conn.close() \ No newline at end of file diff --git a/model/attention.py b/model/attention.py index 9f2a0f5..f5ab9bc 100755 --- a/model/attention.py +++ b/model/attention.py @@ -19,7 +19,7 @@ self.conv_3x3_output1 = nn.Conv2d(in_channels= depth * 5, out_channels= 256, kernel_size=3, stride=1, padding=1) self.conv_3x3_output2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1) - self.sigmoid = F.sigmoid + self.sigmoid = torch.sigmoid def forward(self, x): size = x.shape[2:] diff --git a/iris_recognize.py b/iris_recognize.py index 83d9934..fe842a1 100644 --- a/iris_recognize.py +++ b/iris_recognize.py @@ -25,7 +25,7 @@ return application_points -def process_eye(image, mask, iris, pupil,out = False, out_path = None): +def process_eye(image, mask, iris, pupil,scale_factor = 0.5, out = False, out_path = None): h, w = image.shape[:2] new_w, new_h = int(scale_factor * w), int(scale_factor * h) diff --git a/iris_server.py b/iris_server.py new file mode 100644 index 0000000..92b1de3 --- /dev/null +++ b/iris_server.py @@ -0,0 +1,108 @@ +import socket +import sys +import time + +import cv2 +import numpy as np +import torch +from torchvision.models.vgg import vgg16_bn, vgg16 +from model.unet_vgg16_multitask_attention import UnetWithVGG16Attention +from iris_location import location +from iris_recognize import process_eye + +HOST = '0.0.0.0' +PORT = 50007 + +wid = 640 +hig = 480 +channels = 3 +scale_factor = 0.5 +model_path = 'checkpoints/unet_vgg16_multitask_attention_epoch600.pth' + + +def open_socket_server(host, port, listen = 10): + try: + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.bind((HOST, PORT)) + server.listen(listen) + return server + except socket.error as msg: + print(msg) + sys.exit(1) + + +def load_model(model_dict_path, device): + encode_vgg16 = vgg16() + net = UnetWithVGG16Attention(encoder=encode_vgg16, n_classes=1, bilinear=True) + net.to(device=device) + net.load_state_dict(torch.load(model_path, map_location=device)) + return net + + +if __name__ == '__main__': + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + socket_server = open_socket_server(HOST,PORT) + if socket_server is None: + print('open socket server error, return!') + print('socket server open ...') + + net = load_model(model_path,device=device) + print('success load model ...') + + while 1: + conn, addr = socket_server.accept() # 接受TCP连接,并返回新的套接字与IP地址 + print('Connected by',addr) #输出客户端的IP地址 + + time1 = time.time() + + # recv eye image + buf = b'' + count = int(wid * hig * channels) + while count: + new_buf = conn.recv(count) + # if not newbuf: + # return None + buf += new_buf + count -= len(new_buf) + + data = np.frombuffer(buf, dtype='uint8') + # print(len(data)) + image = data.reshape(hig, wid, channels) + + mask, iris, pupil = location(net, image, device, scale_factor=scale_factor) + norm_image, norm_mask, encode = process_eye(image, mask, iris, pupil, scale_factor=scale_factor) + + time2 = time.time() + print(time2-time1) + + # cv2.imshow("encode",encode) + # cv2.waitKey(0) + + if encode is None: + print('encode None') + conn.send(str(0).encode()) + continue + else: + print('encode shape', encode.shape) + print('mask shape', norm_mask.shape) + conn.send(str(encode.shape[0]*encode.shape[1]).encode()) + + send_code_data = np.array(encode).flatten().astype(np.uint8) + send_code_size = encode.shape[0]*encode.shape[1] + send_code_total = 0 + while send_code_total < send_code_size: + send_code_len = conn.send(send_code_data,0) + send_code_total = send_code_total + send_code_len + print("encode has been send",send_code_total) + + send_mask_data = np.array(norm_mask).flatten().astype(np.uint8) + send_mask_size = norm_mask.shape[0]*norm_mask.shape[1] + send_mask_total = 0 + while send_mask_total < send_mask_size: + send_mask_len = conn.send(send_mask_data,0) + send_mask_total = send_mask_total + send_mask_len + print("mask has been send", send_mask_total) + + conn.close() \ No newline at end of file diff --git a/model/attention.py b/model/attention.py index 9f2a0f5..f5ab9bc 100755 --- a/model/attention.py +++ b/model/attention.py @@ -19,7 +19,7 @@ self.conv_3x3_output1 = nn.Conv2d(in_channels= depth * 5, out_channels= 256, kernel_size=3, stride=1, padding=1) self.conv_3x3_output2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1) - self.sigmoid = F.sigmoid + self.sigmoid = torch.sigmoid def forward(self, x): size = x.shape[2:] diff --git a/post_processing.py b/post_processing.py index 44225bc..16c66f7 100644 --- a/post_processing.py +++ b/post_processing.py @@ -401,7 +401,7 @@ if pupil_x < iris_x - iris_r or pupil_x > iris_x + iris_r or pupil_y < iris_y - iris_r or pupil_y > iris_y + iris_r: continue - width = 256 + width = 512 height = 64 mask = mask[:,:,0] a = normalize_image(image,iris_x,iris_y,iris_r,pupil_x,pupil_y,pupil_r,width,height)