diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 434459f..4133388 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -106,7 +106,7 @@ logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 - self.frame = np.zeros_like(self.frame) + self.frame = None continue # 跳过当前循环的剩余部分 else: vid_n += 1 @@ -118,6 +118,7 @@ logger.error(f"{self.url} update fail", exc_info=e) if self.cap is not None: self.cap.release() + self.frame = None self.cap = self.get_connect() # 尝试重新连接 def __iter__(self): diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 434459f..4133388 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -106,7 +106,7 @@ logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 - self.frame = np.zeros_like(self.frame) + self.frame = None continue # 跳过当前循环的剩余部分 else: vid_n += 1 @@ -118,6 +118,7 @@ logger.error(f"{self.url} update fail", exc_info=e) if self.cap is not None: self.cap.release() + self.frame = None self.cap = self.get_connect() # 尝试重新连接 def __iter__(self): diff --git a/app_instance.py b/app_instance.py index 33f5e80..4e24675 100644 --- a/app_instance.py +++ b/app_instance.py @@ -15,6 +15,7 @@ from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService +from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager _app = None # 创建一个私有变量来存储 app 实例 @@ -48,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - await tcp_manager.start() + # await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -56,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -68,6 +69,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_scheduler()) + yield # 允许请求处理 logger.info("Shutting down application...") diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 434459f..4133388 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -106,7 +106,7 @@ logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 - self.frame = np.zeros_like(self.frame) + self.frame = None continue # 跳过当前循环的剩余部分 else: vid_n += 1 @@ -118,6 +118,7 @@ logger.error(f"{self.url} update fail", exc_info=e) if self.cap is not None: self.cap.release() + self.frame = None self.cap = self.get_connect() # 尝试重新连接 def __iter__(self): diff --git a/app_instance.py b/app_instance.py index 33f5e80..4e24675 100644 --- a/app_instance.py +++ b/app_instance.py @@ -15,6 +15,7 @@ from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService +from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager _app = None # 创建一个私有变量来存储 app 实例 @@ -48,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - await tcp_manager.start() + # await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -56,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -68,6 +69,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_scheduler()) + yield # 允许请求处理 logger.info("Shutting down application...") diff --git a/common/biz_exception.py b/common/biz_exception.py index 4659571..8e5089f 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -1,4 +1,5 @@ from fastapi import Request, HTTPException +from fastapi.responses import JSONResponse from apis.base import standard_error_response @@ -11,7 +12,8 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - return standard_error_response(code=exc.status_code, message=exc.detail) - # 使用 JSONResponse 返回响应 - # return JSONResponse(status_code=exc.status_code, content=response_data) + return JSONResponse( + content=standard_error_response(code=exc.status_code, message=exc.detail).dict(), + status_code=exc.status_code + ) diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 434459f..4133388 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -106,7 +106,7 @@ logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 - self.frame = np.zeros_like(self.frame) + self.frame = None continue # 跳过当前循环的剩余部分 else: vid_n += 1 @@ -118,6 +118,7 @@ logger.error(f"{self.url} update fail", exc_info=e) if self.cap is not None: self.cap.release() + self.frame = None self.cap = self.get_connect() # 尝试重新连接 def __iter__(self): diff --git a/app_instance.py b/app_instance.py index 33f5e80..4e24675 100644 --- a/app_instance.py +++ b/app_instance.py @@ -15,6 +15,7 @@ from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService +from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager _app = None # 创建一个私有变量来存储 app 实例 @@ -48,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - await tcp_manager.start() + # await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -56,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -68,6 +69,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_scheduler()) + yield # 允许请求处理 logger.info("Shutting down application...") diff --git a/common/biz_exception.py b/common/biz_exception.py index 4659571..8e5089f 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -1,4 +1,5 @@ from fastapi import Request, HTTPException +from fastapi.responses import JSONResponse from apis.base import standard_error_response @@ -11,7 +12,8 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - return standard_error_response(code=exc.status_code, message=exc.detail) - # 使用 JSONResponse 返回响应 - # return JSONResponse(status_code=exc.status_code, content=response_data) + return JSONResponse( + content=standard_error_response(code=exc.status_code, message=exc.detail).dict(), + status_code=exc.status_code + ) diff --git a/common/global_logger.py b/common/global_logger.py index 3e672ca..c975ea8 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -2,6 +2,7 @@ import logging.handlers import os import sys +from logging.handlers import TimedRotatingFileHandler # 确保日志目录存在 log_dir = 'logs' @@ -15,7 +16,7 @@ logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler -handler = logging.handlers.TimedRotatingFileHandler( +handler = TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 when='midnight', # 每天午夜滚动 interval=1 # 滚动间隔为1天 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 434459f..4133388 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -106,7 +106,7 @@ logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 - self.frame = np.zeros_like(self.frame) + self.frame = None continue # 跳过当前循环的剩余部分 else: vid_n += 1 @@ -118,6 +118,7 @@ logger.error(f"{self.url} update fail", exc_info=e) if self.cap is not None: self.cap.release() + self.frame = None self.cap = self.get_connect() # 尝试重新连接 def __iter__(self): diff --git a/app_instance.py b/app_instance.py index 33f5e80..4e24675 100644 --- a/app_instance.py +++ b/app_instance.py @@ -15,6 +15,7 @@ from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService +from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager _app = None # 创建一个私有变量来存储 app 实例 @@ -48,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - await tcp_manager.start() + # await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -56,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -68,6 +69,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_scheduler()) + yield # 允许请求处理 logger.info("Shutting down application...") diff --git a/common/biz_exception.py b/common/biz_exception.py index 4659571..8e5089f 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -1,4 +1,5 @@ from fastapi import Request, HTTPException +from fastapi.responses import JSONResponse from apis.base import standard_error_response @@ -11,7 +12,8 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - return standard_error_response(code=exc.status_code, message=exc.detail) - # 使用 JSONResponse 返回响应 - # return JSONResponse(status_code=exc.status_code, content=response_data) + return JSONResponse( + content=standard_error_response(code=exc.status_code, message=exc.detail).dict(), + status_code=exc.status_code + ) diff --git a/common/global_logger.py b/common/global_logger.py index 3e672ca..c975ea8 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -2,6 +2,7 @@ import logging.handlers import os import sys +from logging.handlers import TimedRotatingFileHandler # 确保日志目录存在 log_dir = 'logs' @@ -15,7 +16,7 @@ logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler -handler = logging.handlers.TimedRotatingFileHandler( +handler = TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 when='midnight', # 每天午夜滚动 interval=1 # 滚动间隔为1天 diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index 68df7df..d1c5b9a 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -1,17 +1,147 @@ import asyncio +import base64 +import traceback from asyncio import Event +from copy import deepcopy, copy +from datetime import datetime + +import cv2 from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from common.image_plotting import Annotator from entity.device import Device from ultralytics import YOLO from scene_handler.base_scene_handler import BaseSceneHandler +from services.global_config import GlobalConfig from tcp.tcp_manager import TcpManager +COLOR_RED = (0, 0, 255) +COLOR_GREEN = (255, 0, 0) + +ALARM_DICT = { + 'hat_and_mask': { + 'alarmType': '11', + 'alarmContent': '未佩戴呼吸防护设备', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴呼吸防护设备' + }, + 'no_jiandu': { + 'alarmType': '12', + 'alarmContent': '没有监护人员', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x13\x00\xA7', + 'label': '没有监护人员' + }, + 'break': { + 'alarmType': '3', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x00\x00\x94', + 'label': '非法闯入' + }, + 'smoke': { + 'alarmType': '6', + 'alarmContent': '吸烟', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x03\x00\x97', + 'label': '吸烟' + }, + 'no_blower': { + 'alarmType': '13', + 'alarmContent': '没有检测到通风设备', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x1A\x00\xAE', + 'label': '没有检测到通风设备' + }, + 'no_extinguisher': { + 'alarmType': '14', + 'alarmContent': '没有检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x1B\x00\xAF', + 'label': '没有检测到灭火器' + } +} + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, heads): + best_head = None + max_overlap = 0 + + for head in heads: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head + + +def is_overlapping(bbox1, bbox2): + # 检查两个坐标框是否重叠 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + return not (x2 < x3 or x4 < x1 or y2 < y3 or y4 < y1) + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + + +def handle_alarm_info(type, frame, frame_alarm, person_box, conf=None, person_id=None): + if type not in frame_alarm or frame_alarm[type] is None: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + frame_alarm[type] = {'count': 0, 'annotator': annotator} + frame_alarm[type]['count'] = frame_alarm[type]['count'] + 1 + alarm_annotator = frame_alarm[type]['annotator'] + if person_box is not None: + alarm_annotator.box_label(person_box, + ALARM_DICT[type]['label'] + (f'{conf:.2f}' if conf is not None else '') + ( + f' id={person_id}' if person_id is not None else ''), + color=COLOR_RED, + rotated=False) + return alarm_annotator + class LimitSpaceSceneHandler(BaseSceneHandler): @@ -22,8 +152,33 @@ self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) self.device_status_manager = DeviceStatusManager() + self.thread_pool = GlobalThreadPool() - self.person_model = YOLO('weights/yolov8s.pt') + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 58: '鼓风机', + } + + self.alarm_interval_dict = {} + self.alarm_interval = device.alarm_interval + + self.socket_interval_dict = {} + self.socket_interval = device.alarm_interval + self.socket_retry = 3 self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态 @@ -39,27 +194,119 @@ have_response=have_response), self.main_loop) + def send_alarm_message(self, type): + if self.tcp_manager: + if self.socket_interval_dict.get(type) is None \ + or (datetime.now() - self.socket_interval_dict.get(type)).total_seconds() > int(self.socket_interval): + logger.debug("send alarm message %s %s", ALARM_DICT[type]['alarmContent'], + ALARM_DICT[type]['alarmSoundMessage']) + self.send_tcp_message(ALARM_DICT[type]['alarmSoundMessage'], have_response=True) + self.socket_interval_dict[type] = datetime.now() + + def send_alarm_record(self, type, frame_alarm): + if self.alarm_interval < 0: + return + + global_config = GlobalConfig() + push_config = global_config.get_alarm_push_config() + if push_config and push_config.push_url: + if self.alarm_interval_dict.get(type) is None \ + or (datetime.now() - self.alarm_interval_dict.get(type)).total_seconds() > int(self.alarm_interval): + logger.debug("send alarm record") + + annotator_result = frame_alarm[type]['annotator'].result() + alarm_image = deepcopy(annotator_result) + + data = {} + data["device_id"] = self.device.id + data["alarm_type"] = ALARM_DICT[type]['alarmType'] + data["alarm_content"] = ALARM_DICT[type]['alarmContent'] + data["alarm_value"] = frame_alarm[type]['count'] + data["alarm_image"] = image_to_base64(alarm_image) + data["alarm_time"] = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + url = push_config.push_url + send_request(url, data) + + data_copy = copy(data) + # 从拷贝的字典中移除"alarm_image"键 + data_copy.pop("alarm_image", None) + logger.debug(f"send to {url}: {data_copy}") + + self.alarm_interval_dict[type] = datetime.now() + + def model_predict(self, frame): + results_gen = self.model(frame, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + results = list(results_gen) # 确保生成器转换为列表 + result = results[0] + result_boxes = [box for box in result.boxes] + pred_ids = [int(box.cls) for box in result_boxes] + pred_names = [self.model_classes[int(box.cls)] for box in result_boxes] + return result_boxes, pred_ids, pred_names + + def process_alarm(self, frame, result_boxes, pred_ids, pred_names, frame_alarm): + persons = [box for box in result_boxes if int(box.cls) == 3] + helmets = [box for box in result_boxes if int(box.cls) == 18] + heads = [box for box in result_boxes if int(box.cls) == 15] + + has_jianduyuan = False + has_others = False + + for person in persons: + person_bbox = person.xyxy.cpu().squeeze() + + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, heads) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) for helmet in + helmets) + + if not has_helmet: + has_others = True + alarm_annotator = handle_alarm_info('break', frame, frame_alarm, person_bbox, + float(person.conf)) + else: + has_jianduyuan = True + + if not has_jianduyuan: + alarm_annotator = handle_alarm_info('no_jiandu', frame, frame_alarm, None) + self.send_alarm_message('no_jiandu') + if has_others: + self.send_alarm_message('break') + + def process_labor(self, frame, result_boxes, pred_ids, pred_names): + pass + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() for frame in self.stream_loader: - if self.__stop_event.is_set(): - break # 如果触发了停止事件,则退出循环 - # print('frame') - if frame is None: - continue + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue - self.device_status_manager.set_status(device_id=self.device.id) - results = self.person_model.predict(source=frame, imgsz=640, - save_txt=False, - save=False, - verbose=False, stream=True) - result = (list(results)) - if len(result[0]) > 0: - asyncio.run_coroutine_threadsafe( - self.tcp_manager.send_message_to_device(device_id=self.device.id, - message=b'\xaa\x01\x00\x93\x07\x00\x9B', - have_response=False), - self.main_loop) + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frame) + + frame_alarm = {} + self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + self.process_labor(frame, result_boxes, pred_ids, pred_names) + + if len(frame_alarm.keys()) > 0: + for key in frame_alarm.keys(): + if frame_alarm[key]['count'] > 0: + self.thread_pool.submit_task(self.send_alarm_record, key, frame_alarm, ) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 434459f..4133388 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -106,7 +106,7 @@ logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 - self.frame = np.zeros_like(self.frame) + self.frame = None continue # 跳过当前循环的剩余部分 else: vid_n += 1 @@ -118,6 +118,7 @@ logger.error(f"{self.url} update fail", exc_info=e) if self.cap is not None: self.cap.release() + self.frame = None self.cap = self.get_connect() # 尝试重新连接 def __iter__(self): diff --git a/app_instance.py b/app_instance.py index 33f5e80..4e24675 100644 --- a/app_instance.py +++ b/app_instance.py @@ -15,6 +15,7 @@ from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService +from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager _app = None # 创建一个私有变量来存储 app 实例 @@ -48,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - await tcp_manager.start() + # await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -56,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -68,6 +69,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_scheduler()) + yield # 允许请求处理 logger.info("Shutting down application...") diff --git a/common/biz_exception.py b/common/biz_exception.py index 4659571..8e5089f 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -1,4 +1,5 @@ from fastapi import Request, HTTPException +from fastapi.responses import JSONResponse from apis.base import standard_error_response @@ -11,7 +12,8 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - return standard_error_response(code=exc.status_code, message=exc.detail) - # 使用 JSONResponse 返回响应 - # return JSONResponse(status_code=exc.status_code, content=response_data) + return JSONResponse( + content=standard_error_response(code=exc.status_code, message=exc.detail).dict(), + status_code=exc.status_code + ) diff --git a/common/global_logger.py b/common/global_logger.py index 3e672ca..c975ea8 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -2,6 +2,7 @@ import logging.handlers import os import sys +from logging.handlers import TimedRotatingFileHandler # 确保日志目录存在 log_dir = 'logs' @@ -15,7 +16,7 @@ logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler -handler = logging.handlers.TimedRotatingFileHandler( +handler = TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 when='midnight', # 每天午夜滚动 interval=1 # 滚动间隔为1天 diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index 68df7df..d1c5b9a 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -1,17 +1,147 @@ import asyncio +import base64 +import traceback from asyncio import Event +from copy import deepcopy, copy +from datetime import datetime + +import cv2 from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from common.image_plotting import Annotator from entity.device import Device from ultralytics import YOLO from scene_handler.base_scene_handler import BaseSceneHandler +from services.global_config import GlobalConfig from tcp.tcp_manager import TcpManager +COLOR_RED = (0, 0, 255) +COLOR_GREEN = (255, 0, 0) + +ALARM_DICT = { + 'hat_and_mask': { + 'alarmType': '11', + 'alarmContent': '未佩戴呼吸防护设备', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴呼吸防护设备' + }, + 'no_jiandu': { + 'alarmType': '12', + 'alarmContent': '没有监护人员', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x13\x00\xA7', + 'label': '没有监护人员' + }, + 'break': { + 'alarmType': '3', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x00\x00\x94', + 'label': '非法闯入' + }, + 'smoke': { + 'alarmType': '6', + 'alarmContent': '吸烟', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x03\x00\x97', + 'label': '吸烟' + }, + 'no_blower': { + 'alarmType': '13', + 'alarmContent': '没有检测到通风设备', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x1A\x00\xAE', + 'label': '没有检测到通风设备' + }, + 'no_extinguisher': { + 'alarmType': '14', + 'alarmContent': '没有检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x1B\x00\xAF', + 'label': '没有检测到灭火器' + } +} + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, heads): + best_head = None + max_overlap = 0 + + for head in heads: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head + + +def is_overlapping(bbox1, bbox2): + # 检查两个坐标框是否重叠 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + return not (x2 < x3 or x4 < x1 or y2 < y3 or y4 < y1) + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + + +def handle_alarm_info(type, frame, frame_alarm, person_box, conf=None, person_id=None): + if type not in frame_alarm or frame_alarm[type] is None: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + frame_alarm[type] = {'count': 0, 'annotator': annotator} + frame_alarm[type]['count'] = frame_alarm[type]['count'] + 1 + alarm_annotator = frame_alarm[type]['annotator'] + if person_box is not None: + alarm_annotator.box_label(person_box, + ALARM_DICT[type]['label'] + (f'{conf:.2f}' if conf is not None else '') + ( + f' id={person_id}' if person_id is not None else ''), + color=COLOR_RED, + rotated=False) + return alarm_annotator + class LimitSpaceSceneHandler(BaseSceneHandler): @@ -22,8 +152,33 @@ self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) self.device_status_manager = DeviceStatusManager() + self.thread_pool = GlobalThreadPool() - self.person_model = YOLO('weights/yolov8s.pt') + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 58: '鼓风机', + } + + self.alarm_interval_dict = {} + self.alarm_interval = device.alarm_interval + + self.socket_interval_dict = {} + self.socket_interval = device.alarm_interval + self.socket_retry = 3 self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态 @@ -39,27 +194,119 @@ have_response=have_response), self.main_loop) + def send_alarm_message(self, type): + if self.tcp_manager: + if self.socket_interval_dict.get(type) is None \ + or (datetime.now() - self.socket_interval_dict.get(type)).total_seconds() > int(self.socket_interval): + logger.debug("send alarm message %s %s", ALARM_DICT[type]['alarmContent'], + ALARM_DICT[type]['alarmSoundMessage']) + self.send_tcp_message(ALARM_DICT[type]['alarmSoundMessage'], have_response=True) + self.socket_interval_dict[type] = datetime.now() + + def send_alarm_record(self, type, frame_alarm): + if self.alarm_interval < 0: + return + + global_config = GlobalConfig() + push_config = global_config.get_alarm_push_config() + if push_config and push_config.push_url: + if self.alarm_interval_dict.get(type) is None \ + or (datetime.now() - self.alarm_interval_dict.get(type)).total_seconds() > int(self.alarm_interval): + logger.debug("send alarm record") + + annotator_result = frame_alarm[type]['annotator'].result() + alarm_image = deepcopy(annotator_result) + + data = {} + data["device_id"] = self.device.id + data["alarm_type"] = ALARM_DICT[type]['alarmType'] + data["alarm_content"] = ALARM_DICT[type]['alarmContent'] + data["alarm_value"] = frame_alarm[type]['count'] + data["alarm_image"] = image_to_base64(alarm_image) + data["alarm_time"] = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + url = push_config.push_url + send_request(url, data) + + data_copy = copy(data) + # 从拷贝的字典中移除"alarm_image"键 + data_copy.pop("alarm_image", None) + logger.debug(f"send to {url}: {data_copy}") + + self.alarm_interval_dict[type] = datetime.now() + + def model_predict(self, frame): + results_gen = self.model(frame, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + results = list(results_gen) # 确保生成器转换为列表 + result = results[0] + result_boxes = [box for box in result.boxes] + pred_ids = [int(box.cls) for box in result_boxes] + pred_names = [self.model_classes[int(box.cls)] for box in result_boxes] + return result_boxes, pred_ids, pred_names + + def process_alarm(self, frame, result_boxes, pred_ids, pred_names, frame_alarm): + persons = [box for box in result_boxes if int(box.cls) == 3] + helmets = [box for box in result_boxes if int(box.cls) == 18] + heads = [box for box in result_boxes if int(box.cls) == 15] + + has_jianduyuan = False + has_others = False + + for person in persons: + person_bbox = person.xyxy.cpu().squeeze() + + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, heads) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) for helmet in + helmets) + + if not has_helmet: + has_others = True + alarm_annotator = handle_alarm_info('break', frame, frame_alarm, person_bbox, + float(person.conf)) + else: + has_jianduyuan = True + + if not has_jianduyuan: + alarm_annotator = handle_alarm_info('no_jiandu', frame, frame_alarm, None) + self.send_alarm_message('no_jiandu') + if has_others: + self.send_alarm_message('break') + + def process_labor(self, frame, result_boxes, pred_ids, pred_names): + pass + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() for frame in self.stream_loader: - if self.__stop_event.is_set(): - break # 如果触发了停止事件,则退出循环 - # print('frame') - if frame is None: - continue + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue - self.device_status_manager.set_status(device_id=self.device.id) - results = self.person_model.predict(source=frame, imgsz=640, - save_txt=False, - save=False, - verbose=False, stream=True) - result = (list(results)) - if len(result[0]) > 0: - asyncio.run_coroutine_threadsafe( - self.tcp_manager.send_message_to_device(device_id=self.device.id, - message=b'\xaa\x01\x00\x93\x07\x00\x9B', - have_response=False), - self.main_loop) + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frame) + + frame_alarm = {} + self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + self.process_labor(frame, result_boxes, pred_ids, pred_names) + + if len(frame_alarm.keys()) > 0: + for key in frame_alarm.keys(): + if frame_alarm[key]['count'] > 0: + self.thread_pool.submit_task(self.send_alarm_record, key, frame_alarm, ) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index 4e6c145..ca925cf 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,6 +1,6 @@ from typing import List -from sqlmodel import select +from sqlmodel import select,delete from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult @@ -23,3 +23,9 @@ statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) results = await self.db.execute(statement) return results.scalars().all() + + async def delete_by_frame_id(self, max_frame_id): + statement = delete(FrameAnalysisResult).where(FrameAnalysisResult.frame_id <= max_frame_id) + await self.db.execute(statement) + await self.db.commit() + return max_frame_id diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 434459f..4133388 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -106,7 +106,7 @@ logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 - self.frame = np.zeros_like(self.frame) + self.frame = None continue # 跳过当前循环的剩余部分 else: vid_n += 1 @@ -118,6 +118,7 @@ logger.error(f"{self.url} update fail", exc_info=e) if self.cap is not None: self.cap.release() + self.frame = None self.cap = self.get_connect() # 尝试重新连接 def __iter__(self): diff --git a/app_instance.py b/app_instance.py index 33f5e80..4e24675 100644 --- a/app_instance.py +++ b/app_instance.py @@ -15,6 +15,7 @@ from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService +from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager _app = None # 创建一个私有变量来存储 app 实例 @@ -48,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - await tcp_manager.start() + # await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -56,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -68,6 +69,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_scheduler()) + yield # 允许请求处理 logger.info("Shutting down application...") diff --git a/common/biz_exception.py b/common/biz_exception.py index 4659571..8e5089f 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -1,4 +1,5 @@ from fastapi import Request, HTTPException +from fastapi.responses import JSONResponse from apis.base import standard_error_response @@ -11,7 +12,8 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - return standard_error_response(code=exc.status_code, message=exc.detail) - # 使用 JSONResponse 返回响应 - # return JSONResponse(status_code=exc.status_code, content=response_data) + return JSONResponse( + content=standard_error_response(code=exc.status_code, message=exc.detail).dict(), + status_code=exc.status_code + ) diff --git a/common/global_logger.py b/common/global_logger.py index 3e672ca..c975ea8 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -2,6 +2,7 @@ import logging.handlers import os import sys +from logging.handlers import TimedRotatingFileHandler # 确保日志目录存在 log_dir = 'logs' @@ -15,7 +16,7 @@ logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler -handler = logging.handlers.TimedRotatingFileHandler( +handler = TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 when='midnight', # 每天午夜滚动 interval=1 # 滚动间隔为1天 diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index 68df7df..d1c5b9a 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -1,17 +1,147 @@ import asyncio +import base64 +import traceback from asyncio import Event +from copy import deepcopy, copy +from datetime import datetime + +import cv2 from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from common.image_plotting import Annotator from entity.device import Device from ultralytics import YOLO from scene_handler.base_scene_handler import BaseSceneHandler +from services.global_config import GlobalConfig from tcp.tcp_manager import TcpManager +COLOR_RED = (0, 0, 255) +COLOR_GREEN = (255, 0, 0) + +ALARM_DICT = { + 'hat_and_mask': { + 'alarmType': '11', + 'alarmContent': '未佩戴呼吸防护设备', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴呼吸防护设备' + }, + 'no_jiandu': { + 'alarmType': '12', + 'alarmContent': '没有监护人员', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x13\x00\xA7', + 'label': '没有监护人员' + }, + 'break': { + 'alarmType': '3', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x00\x00\x94', + 'label': '非法闯入' + }, + 'smoke': { + 'alarmType': '6', + 'alarmContent': '吸烟', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x03\x00\x97', + 'label': '吸烟' + }, + 'no_blower': { + 'alarmType': '13', + 'alarmContent': '没有检测到通风设备', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x1A\x00\xAE', + 'label': '没有检测到通风设备' + }, + 'no_extinguisher': { + 'alarmType': '14', + 'alarmContent': '没有检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x1B\x00\xAF', + 'label': '没有检测到灭火器' + } +} + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, heads): + best_head = None + max_overlap = 0 + + for head in heads: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head + + +def is_overlapping(bbox1, bbox2): + # 检查两个坐标框是否重叠 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + return not (x2 < x3 or x4 < x1 or y2 < y3 or y4 < y1) + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + + +def handle_alarm_info(type, frame, frame_alarm, person_box, conf=None, person_id=None): + if type not in frame_alarm or frame_alarm[type] is None: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + frame_alarm[type] = {'count': 0, 'annotator': annotator} + frame_alarm[type]['count'] = frame_alarm[type]['count'] + 1 + alarm_annotator = frame_alarm[type]['annotator'] + if person_box is not None: + alarm_annotator.box_label(person_box, + ALARM_DICT[type]['label'] + (f'{conf:.2f}' if conf is not None else '') + ( + f' id={person_id}' if person_id is not None else ''), + color=COLOR_RED, + rotated=False) + return alarm_annotator + class LimitSpaceSceneHandler(BaseSceneHandler): @@ -22,8 +152,33 @@ self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) self.device_status_manager = DeviceStatusManager() + self.thread_pool = GlobalThreadPool() - self.person_model = YOLO('weights/yolov8s.pt') + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 58: '鼓风机', + } + + self.alarm_interval_dict = {} + self.alarm_interval = device.alarm_interval + + self.socket_interval_dict = {} + self.socket_interval = device.alarm_interval + self.socket_retry = 3 self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态 @@ -39,27 +194,119 @@ have_response=have_response), self.main_loop) + def send_alarm_message(self, type): + if self.tcp_manager: + if self.socket_interval_dict.get(type) is None \ + or (datetime.now() - self.socket_interval_dict.get(type)).total_seconds() > int(self.socket_interval): + logger.debug("send alarm message %s %s", ALARM_DICT[type]['alarmContent'], + ALARM_DICT[type]['alarmSoundMessage']) + self.send_tcp_message(ALARM_DICT[type]['alarmSoundMessage'], have_response=True) + self.socket_interval_dict[type] = datetime.now() + + def send_alarm_record(self, type, frame_alarm): + if self.alarm_interval < 0: + return + + global_config = GlobalConfig() + push_config = global_config.get_alarm_push_config() + if push_config and push_config.push_url: + if self.alarm_interval_dict.get(type) is None \ + or (datetime.now() - self.alarm_interval_dict.get(type)).total_seconds() > int(self.alarm_interval): + logger.debug("send alarm record") + + annotator_result = frame_alarm[type]['annotator'].result() + alarm_image = deepcopy(annotator_result) + + data = {} + data["device_id"] = self.device.id + data["alarm_type"] = ALARM_DICT[type]['alarmType'] + data["alarm_content"] = ALARM_DICT[type]['alarmContent'] + data["alarm_value"] = frame_alarm[type]['count'] + data["alarm_image"] = image_to_base64(alarm_image) + data["alarm_time"] = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + url = push_config.push_url + send_request(url, data) + + data_copy = copy(data) + # 从拷贝的字典中移除"alarm_image"键 + data_copy.pop("alarm_image", None) + logger.debug(f"send to {url}: {data_copy}") + + self.alarm_interval_dict[type] = datetime.now() + + def model_predict(self, frame): + results_gen = self.model(frame, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + results = list(results_gen) # 确保生成器转换为列表 + result = results[0] + result_boxes = [box for box in result.boxes] + pred_ids = [int(box.cls) for box in result_boxes] + pred_names = [self.model_classes[int(box.cls)] for box in result_boxes] + return result_boxes, pred_ids, pred_names + + def process_alarm(self, frame, result_boxes, pred_ids, pred_names, frame_alarm): + persons = [box for box in result_boxes if int(box.cls) == 3] + helmets = [box for box in result_boxes if int(box.cls) == 18] + heads = [box for box in result_boxes if int(box.cls) == 15] + + has_jianduyuan = False + has_others = False + + for person in persons: + person_bbox = person.xyxy.cpu().squeeze() + + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, heads) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) for helmet in + helmets) + + if not has_helmet: + has_others = True + alarm_annotator = handle_alarm_info('break', frame, frame_alarm, person_bbox, + float(person.conf)) + else: + has_jianduyuan = True + + if not has_jianduyuan: + alarm_annotator = handle_alarm_info('no_jiandu', frame, frame_alarm, None) + self.send_alarm_message('no_jiandu') + if has_others: + self.send_alarm_message('break') + + def process_labor(self, frame, result_boxes, pred_ids, pred_names): + pass + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() for frame in self.stream_loader: - if self.__stop_event.is_set(): - break # 如果触发了停止事件,则退出循环 - # print('frame') - if frame is None: - continue + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue - self.device_status_manager.set_status(device_id=self.device.id) - results = self.person_model.predict(source=frame, imgsz=640, - save_txt=False, - save=False, - verbose=False, stream=True) - result = (list(results)) - if len(result[0]) > 0: - asyncio.run_coroutine_threadsafe( - self.tcp_manager.send_message_to_device(device_id=self.device.id, - message=b'\xaa\x01\x00\x93\x07\x00\x9B', - have_response=False), - self.main_loop) + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frame) + + frame_alarm = {} + self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + self.process_labor(frame, result_boxes, pred_ids, pred_names) + + if len(frame_alarm.keys()) > 0: + for key in frame_alarm.keys(): + if frame_alarm[key]['count'] > 0: + self.thread_pool.submit_task(self.send_alarm_record, key, frame_alarm, ) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index 4e6c145..ca925cf 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,6 +1,6 @@ from typing import List -from sqlmodel import select +from sqlmodel import select,delete from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult @@ -23,3 +23,9 @@ statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) results = await self.db.execute(statement) return results.scalars().all() + + async def delete_by_frame_id(self, max_frame_id): + statement = delete(FrameAnalysisResult).where(FrameAnalysisResult.frame_id <= max_frame_id) + await self.db.execute(statement) + await self.db.commit() + return max_frame_id diff --git a/services/schedule_job.py b/services/schedule_job.py new file mode 100644 index 0000000..def2d1d --- /dev/null +++ b/services/schedule_job.py @@ -0,0 +1,142 @@ +import asyncio +import os +import shutil +from datetime import datetime, timedelta + +from dateutil.relativedelta import relativedelta + +from common.global_logger import logger +from db.database import get_db +from services.device_frame_service import DeviceFrameService + +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from pytz import timezone + +import re +import gzip + + +async def start_scheduler(): + await delete_frames() + await compress_day_logs() + await organize_and_compress_month_logs() + + +async def delete_frames(save_month=3): + ''' + 删除3个月前的帧数据 + :param save_month: 默认保留3个月的帧数据 + :return: + ''' + base_folder = './storage/frames' + + async for db in get_db(): + frame_service = DeviceFrameService(db) + + now = datetime.now() + latest_time = now - relativedelta(months=save_month) + latest_time = latest_time.replace(hour=0, minute=0, second=0, microsecond=0) + + max_frame_id = await frame_service.select_max_frame_id(latest_time) + if max_frame_id > 0: + await frame_service.delete_frame(max_frame_id) + logger.info(f"Delete frames before {latest_time}, max_frame_id = {max_frame_id}") + + # 遍历子文件夹 + for folder_name in os.listdir(base_folder): + folder_path = os.path.join(base_folder, folder_name) + if os.path.isdir(folder_path): + try: + # 尝试将文件夹名称解析为日期 + folder_date = datetime.strptime(folder_name, "%Y-%m-%d") + folder_date = folder_date.replace(hour=0, minute=0, second=0, microsecond=0) + # 比较日期并删除早于3个月前的文件夹 + if folder_date < latest_time: + logger.info(f"Deleting folder: {folder_path}") + shutil.rmtree(folder_path) # 递归删除文件夹及其内容 + except ValueError: + # 忽略无法解析为日期的文件夹名称 + logger.warning(f"Skipping non-date folder: {folder_name}") + + +async def compress_day_logs(): + """ + 异步压缩日志目录中符合特定日期格式 (年-月-日) 的日志文件,并删除原始文件。 + + :param log_dir: 日志目录路径 + :param base_filename: 日志的基础文件名,例如 'app.log' + """ + log_dir = 'logs' + base_filename = 'app.log' + loop = asyncio.get_running_loop() + base_filename_pattern = re.escape(base_filename) + r"\.\d{4}-\d{2}-\d{2}" # 匹配 app.log.年-月-日 格式 + + for filename in os.listdir(log_dir): + file_path = os.path.join(log_dir, filename) + if os.path.isfile(file_path) and re.fullmatch(base_filename_pattern, filename): + try: + compressed_file_path = f"{file_path}.gz" + # 使用线程池在后台执行文件操作以避免阻塞事件循环 + await loop.run_in_executor(None, compress_and_remove, file_path, compressed_file_path) + print(f"Compressed and removed: {file_path}") + except Exception as e: + print(f"Error compressing {file_path}: {e}") + + +def compress_and_remove(file_path, compressed_file_path): + """同步函数,用于压缩并删除文件""" + with open(file_path, 'rb') as f_in: + with gzip.open(compressed_file_path, 'wb') as f_out: + f_out.writelines(f_in) + os.remove(file_path) # 删除未压缩的旧文件 + + +async def organize_and_compress_month_logs(): + loop = asyncio.get_running_loop() + + base_log_dir = 'logs' # 日志目录路径 + archive_dir = 'logs/archive' # 日志归档目录路径 + os.makedirs(archive_dir, exist_ok=True) + + # 获取当前日期以确定要处理的月份 + today = datetime.now() + year_month_pattern = re.compile(r'\d{4}-\d{2}') # 匹配 YYYY-MM 格式 + + # 遍历日志目录并识别所有符合 YYYY-MM 格式的日志文件 + processed_months = set() + for filename in os.listdir(base_log_dir): + file_path = os.path.join(base_log_dir, filename) + if os.path.isfile(file_path): + match = year_month_pattern.search(filename) + if match: + month_str = match.group(0) + # 忽略当前月份 + if month_str < today.strftime('%Y-%m'): + processed_months.add(month_str) + + # 逐月处理日志文件 + for month_str in sorted(processed_months): + archive_path = os.path.join(archive_dir, month_str) + os.makedirs(archive_path, exist_ok=True) + + # 移动符合该月份的日志文件 + for filename in os.listdir(base_log_dir): + file_path = os.path.join(base_log_dir, filename) + if os.path.isfile(file_path) and month_str in filename: + try: + # 使用线程池执行同步的移动操作 + await loop.run_in_executor(None, shutil.move, file_path, os.path.join(archive_path, filename)) + except Exception as e: + logger.error(f"Error moving file {filename}: {e}") + + # 压缩归档文件夹 + archive_tar_path = os.path.join(archive_dir, f"{month_str}.tar.gz") + try: + # 使用线程池执行同步的压缩操作 + await loop.run_in_executor(None, shutil.make_archive, archive_tar_path.replace('.tar.gz', ''), 'gztar', + archive_dir, month_str) + # 使用线程池执行同步的删除操作 + await loop.run_in_executor(None, shutil.rmtree, archive_path) + logger.info(f"Archived and compressed logs for {month_str} to {archive_tar_path}") + except Exception as e: + logger.error(f"Error compressing logs for {month_str}: {e}") diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 434459f..4133388 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -106,7 +106,7 @@ logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 - self.frame = np.zeros_like(self.frame) + self.frame = None continue # 跳过当前循环的剩余部分 else: vid_n += 1 @@ -118,6 +118,7 @@ logger.error(f"{self.url} update fail", exc_info=e) if self.cap is not None: self.cap.release() + self.frame = None self.cap = self.get_connect() # 尝试重新连接 def __iter__(self): diff --git a/app_instance.py b/app_instance.py index 33f5e80..4e24675 100644 --- a/app_instance.py +++ b/app_instance.py @@ -15,6 +15,7 @@ from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService +from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager _app = None # 创建一个私有变量来存储 app 实例 @@ -48,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - await tcp_manager.start() + # await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -56,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -68,6 +69,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_scheduler()) + yield # 允许请求处理 logger.info("Shutting down application...") diff --git a/common/biz_exception.py b/common/biz_exception.py index 4659571..8e5089f 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -1,4 +1,5 @@ from fastapi import Request, HTTPException +from fastapi.responses import JSONResponse from apis.base import standard_error_response @@ -11,7 +12,8 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - return standard_error_response(code=exc.status_code, message=exc.detail) - # 使用 JSONResponse 返回响应 - # return JSONResponse(status_code=exc.status_code, content=response_data) + return JSONResponse( + content=standard_error_response(code=exc.status_code, message=exc.detail).dict(), + status_code=exc.status_code + ) diff --git a/common/global_logger.py b/common/global_logger.py index 3e672ca..c975ea8 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -2,6 +2,7 @@ import logging.handlers import os import sys +from logging.handlers import TimedRotatingFileHandler # 确保日志目录存在 log_dir = 'logs' @@ -15,7 +16,7 @@ logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler -handler = logging.handlers.TimedRotatingFileHandler( +handler = TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 when='midnight', # 每天午夜滚动 interval=1 # 滚动间隔为1天 diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index 68df7df..d1c5b9a 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -1,17 +1,147 @@ import asyncio +import base64 +import traceback from asyncio import Event +from copy import deepcopy, copy +from datetime import datetime + +import cv2 from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from common.image_plotting import Annotator from entity.device import Device from ultralytics import YOLO from scene_handler.base_scene_handler import BaseSceneHandler +from services.global_config import GlobalConfig from tcp.tcp_manager import TcpManager +COLOR_RED = (0, 0, 255) +COLOR_GREEN = (255, 0, 0) + +ALARM_DICT = { + 'hat_and_mask': { + 'alarmType': '11', + 'alarmContent': '未佩戴呼吸防护设备', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴呼吸防护设备' + }, + 'no_jiandu': { + 'alarmType': '12', + 'alarmContent': '没有监护人员', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x13\x00\xA7', + 'label': '没有监护人员' + }, + 'break': { + 'alarmType': '3', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x00\x00\x94', + 'label': '非法闯入' + }, + 'smoke': { + 'alarmType': '6', + 'alarmContent': '吸烟', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x03\x00\x97', + 'label': '吸烟' + }, + 'no_blower': { + 'alarmType': '13', + 'alarmContent': '没有检测到通风设备', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x1A\x00\xAE', + 'label': '没有检测到通风设备' + }, + 'no_extinguisher': { + 'alarmType': '14', + 'alarmContent': '没有检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x1B\x00\xAF', + 'label': '没有检测到灭火器' + } +} + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, heads): + best_head = None + max_overlap = 0 + + for head in heads: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head + + +def is_overlapping(bbox1, bbox2): + # 检查两个坐标框是否重叠 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + return not (x2 < x3 or x4 < x1 or y2 < y3 or y4 < y1) + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + + +def handle_alarm_info(type, frame, frame_alarm, person_box, conf=None, person_id=None): + if type not in frame_alarm or frame_alarm[type] is None: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + frame_alarm[type] = {'count': 0, 'annotator': annotator} + frame_alarm[type]['count'] = frame_alarm[type]['count'] + 1 + alarm_annotator = frame_alarm[type]['annotator'] + if person_box is not None: + alarm_annotator.box_label(person_box, + ALARM_DICT[type]['label'] + (f'{conf:.2f}' if conf is not None else '') + ( + f' id={person_id}' if person_id is not None else ''), + color=COLOR_RED, + rotated=False) + return alarm_annotator + class LimitSpaceSceneHandler(BaseSceneHandler): @@ -22,8 +152,33 @@ self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) self.device_status_manager = DeviceStatusManager() + self.thread_pool = GlobalThreadPool() - self.person_model = YOLO('weights/yolov8s.pt') + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 58: '鼓风机', + } + + self.alarm_interval_dict = {} + self.alarm_interval = device.alarm_interval + + self.socket_interval_dict = {} + self.socket_interval = device.alarm_interval + self.socket_retry = 3 self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态 @@ -39,27 +194,119 @@ have_response=have_response), self.main_loop) + def send_alarm_message(self, type): + if self.tcp_manager: + if self.socket_interval_dict.get(type) is None \ + or (datetime.now() - self.socket_interval_dict.get(type)).total_seconds() > int(self.socket_interval): + logger.debug("send alarm message %s %s", ALARM_DICT[type]['alarmContent'], + ALARM_DICT[type]['alarmSoundMessage']) + self.send_tcp_message(ALARM_DICT[type]['alarmSoundMessage'], have_response=True) + self.socket_interval_dict[type] = datetime.now() + + def send_alarm_record(self, type, frame_alarm): + if self.alarm_interval < 0: + return + + global_config = GlobalConfig() + push_config = global_config.get_alarm_push_config() + if push_config and push_config.push_url: + if self.alarm_interval_dict.get(type) is None \ + or (datetime.now() - self.alarm_interval_dict.get(type)).total_seconds() > int(self.alarm_interval): + logger.debug("send alarm record") + + annotator_result = frame_alarm[type]['annotator'].result() + alarm_image = deepcopy(annotator_result) + + data = {} + data["device_id"] = self.device.id + data["alarm_type"] = ALARM_DICT[type]['alarmType'] + data["alarm_content"] = ALARM_DICT[type]['alarmContent'] + data["alarm_value"] = frame_alarm[type]['count'] + data["alarm_image"] = image_to_base64(alarm_image) + data["alarm_time"] = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + url = push_config.push_url + send_request(url, data) + + data_copy = copy(data) + # 从拷贝的字典中移除"alarm_image"键 + data_copy.pop("alarm_image", None) + logger.debug(f"send to {url}: {data_copy}") + + self.alarm_interval_dict[type] = datetime.now() + + def model_predict(self, frame): + results_gen = self.model(frame, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + results = list(results_gen) # 确保生成器转换为列表 + result = results[0] + result_boxes = [box for box in result.boxes] + pred_ids = [int(box.cls) for box in result_boxes] + pred_names = [self.model_classes[int(box.cls)] for box in result_boxes] + return result_boxes, pred_ids, pred_names + + def process_alarm(self, frame, result_boxes, pred_ids, pred_names, frame_alarm): + persons = [box for box in result_boxes if int(box.cls) == 3] + helmets = [box for box in result_boxes if int(box.cls) == 18] + heads = [box for box in result_boxes if int(box.cls) == 15] + + has_jianduyuan = False + has_others = False + + for person in persons: + person_bbox = person.xyxy.cpu().squeeze() + + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, heads) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) for helmet in + helmets) + + if not has_helmet: + has_others = True + alarm_annotator = handle_alarm_info('break', frame, frame_alarm, person_bbox, + float(person.conf)) + else: + has_jianduyuan = True + + if not has_jianduyuan: + alarm_annotator = handle_alarm_info('no_jiandu', frame, frame_alarm, None) + self.send_alarm_message('no_jiandu') + if has_others: + self.send_alarm_message('break') + + def process_labor(self, frame, result_boxes, pred_ids, pred_names): + pass + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() for frame in self.stream_loader: - if self.__stop_event.is_set(): - break # 如果触发了停止事件,则退出循环 - # print('frame') - if frame is None: - continue + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue - self.device_status_manager.set_status(device_id=self.device.id) - results = self.person_model.predict(source=frame, imgsz=640, - save_txt=False, - save=False, - verbose=False, stream=True) - result = (list(results)) - if len(result[0]) > 0: - asyncio.run_coroutine_threadsafe( - self.tcp_manager.send_message_to_device(device_id=self.device.id, - message=b'\xaa\x01\x00\x93\x07\x00\x9B', - have_response=False), - self.main_loop) + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frame) + + frame_alarm = {} + self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + self.process_labor(frame, result_boxes, pred_ids, pred_names) + + if len(frame_alarm.keys()) > 0: + for key in frame_alarm.keys(): + if frame_alarm[key]['count'] > 0: + self.thread_pool.submit_task(self.send_alarm_record, key, frame_alarm, ) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index 4e6c145..ca925cf 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,6 +1,6 @@ from typing import List -from sqlmodel import select +from sqlmodel import select,delete from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult @@ -23,3 +23,9 @@ statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) results = await self.db.execute(statement) return results.scalars().all() + + async def delete_by_frame_id(self, max_frame_id): + statement = delete(FrameAnalysisResult).where(FrameAnalysisResult.frame_id <= max_frame_id) + await self.db.execute(statement) + await self.db.commit() + return max_frame_id diff --git a/services/schedule_job.py b/services/schedule_job.py new file mode 100644 index 0000000..def2d1d --- /dev/null +++ b/services/schedule_job.py @@ -0,0 +1,142 @@ +import asyncio +import os +import shutil +from datetime import datetime, timedelta + +from dateutil.relativedelta import relativedelta + +from common.global_logger import logger +from db.database import get_db +from services.device_frame_service import DeviceFrameService + +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from pytz import timezone + +import re +import gzip + + +async def start_scheduler(): + await delete_frames() + await compress_day_logs() + await organize_and_compress_month_logs() + + +async def delete_frames(save_month=3): + ''' + 删除3个月前的帧数据 + :param save_month: 默认保留3个月的帧数据 + :return: + ''' + base_folder = './storage/frames' + + async for db in get_db(): + frame_service = DeviceFrameService(db) + + now = datetime.now() + latest_time = now - relativedelta(months=save_month) + latest_time = latest_time.replace(hour=0, minute=0, second=0, microsecond=0) + + max_frame_id = await frame_service.select_max_frame_id(latest_time) + if max_frame_id > 0: + await frame_service.delete_frame(max_frame_id) + logger.info(f"Delete frames before {latest_time}, max_frame_id = {max_frame_id}") + + # 遍历子文件夹 + for folder_name in os.listdir(base_folder): + folder_path = os.path.join(base_folder, folder_name) + if os.path.isdir(folder_path): + try: + # 尝试将文件夹名称解析为日期 + folder_date = datetime.strptime(folder_name, "%Y-%m-%d") + folder_date = folder_date.replace(hour=0, minute=0, second=0, microsecond=0) + # 比较日期并删除早于3个月前的文件夹 + if folder_date < latest_time: + logger.info(f"Deleting folder: {folder_path}") + shutil.rmtree(folder_path) # 递归删除文件夹及其内容 + except ValueError: + # 忽略无法解析为日期的文件夹名称 + logger.warning(f"Skipping non-date folder: {folder_name}") + + +async def compress_day_logs(): + """ + 异步压缩日志目录中符合特定日期格式 (年-月-日) 的日志文件,并删除原始文件。 + + :param log_dir: 日志目录路径 + :param base_filename: 日志的基础文件名,例如 'app.log' + """ + log_dir = 'logs' + base_filename = 'app.log' + loop = asyncio.get_running_loop() + base_filename_pattern = re.escape(base_filename) + r"\.\d{4}-\d{2}-\d{2}" # 匹配 app.log.年-月-日 格式 + + for filename in os.listdir(log_dir): + file_path = os.path.join(log_dir, filename) + if os.path.isfile(file_path) and re.fullmatch(base_filename_pattern, filename): + try: + compressed_file_path = f"{file_path}.gz" + # 使用线程池在后台执行文件操作以避免阻塞事件循环 + await loop.run_in_executor(None, compress_and_remove, file_path, compressed_file_path) + print(f"Compressed and removed: {file_path}") + except Exception as e: + print(f"Error compressing {file_path}: {e}") + + +def compress_and_remove(file_path, compressed_file_path): + """同步函数,用于压缩并删除文件""" + with open(file_path, 'rb') as f_in: + with gzip.open(compressed_file_path, 'wb') as f_out: + f_out.writelines(f_in) + os.remove(file_path) # 删除未压缩的旧文件 + + +async def organize_and_compress_month_logs(): + loop = asyncio.get_running_loop() + + base_log_dir = 'logs' # 日志目录路径 + archive_dir = 'logs/archive' # 日志归档目录路径 + os.makedirs(archive_dir, exist_ok=True) + + # 获取当前日期以确定要处理的月份 + today = datetime.now() + year_month_pattern = re.compile(r'\d{4}-\d{2}') # 匹配 YYYY-MM 格式 + + # 遍历日志目录并识别所有符合 YYYY-MM 格式的日志文件 + processed_months = set() + for filename in os.listdir(base_log_dir): + file_path = os.path.join(base_log_dir, filename) + if os.path.isfile(file_path): + match = year_month_pattern.search(filename) + if match: + month_str = match.group(0) + # 忽略当前月份 + if month_str < today.strftime('%Y-%m'): + processed_months.add(month_str) + + # 逐月处理日志文件 + for month_str in sorted(processed_months): + archive_path = os.path.join(archive_dir, month_str) + os.makedirs(archive_path, exist_ok=True) + + # 移动符合该月份的日志文件 + for filename in os.listdir(base_log_dir): + file_path = os.path.join(base_log_dir, filename) + if os.path.isfile(file_path) and month_str in filename: + try: + # 使用线程池执行同步的移动操作 + await loop.run_in_executor(None, shutil.move, file_path, os.path.join(archive_path, filename)) + except Exception as e: + logger.error(f"Error moving file {filename}: {e}") + + # 压缩归档文件夹 + archive_tar_path = os.path.join(archive_dir, f"{month_str}.tar.gz") + try: + # 使用线程池执行同步的压缩操作 + await loop.run_in_executor(None, shutil.make_archive, archive_tar_path.replace('.tar.gz', ''), 'gztar', + archive_dir, month_str) + # 使用线程池执行同步的删除操作 + await loop.run_in_executor(None, shutil.rmtree, archive_path) + logger.info(f"Archived and compressed logs for {month_str} to {archive_tar_path}") + except Exception as e: + logger.error(f"Error compressing logs for {month_str}: {e}") diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py index d645340..399cd66 100644 --- a/tcp/tcp_client_connector.py +++ b/tcp/tcp_client_connector.py @@ -1,4 +1,5 @@ import asyncio +from collections import deque from datetime import datetime from common.byte_utils import format_bytes @@ -14,14 +15,19 @@ def parse_gas_data(data): # 数据长度检查,确保最小长度符合协议要求 if len(data) < 13: - raise ValueError("数据长度不足,无法解析") + return None + # raise ValueError("数据长度不足,无法解析") + + # 检查帧头(AA 01) + if data[6:8] != b'\xAA\x01': + return None + # raise ValueError("帧头不匹配") + # 解析设备编号(UU VV WW XX YY ZZ) device_id = ''.join(f'{byte:02X}' for byte in data[:6]) - # 检查帧头(AA 01) - if data[6:8] != b'\xAA\x01': - raise ValueError("帧头不匹配") + # 解析 GG, HH, II 字节,计算燃气浓度值 (ppm.m) GG = data[8] @@ -61,24 +67,29 @@ self.reconnect_interval = reconnect_interval # 重连间隔 self.timeout = timeout # 连接/发送超时时间 self.is_connected = False # 连接状态标志 - + self.message_queue = deque() + self.read_lock = asyncio.Lock() # 添加锁 self.push_ts_dict = {} async def connect(self): """连接到设备""" - try: - logger.info(f"Connecting to {self.ip}:{self.port}...") - # 使用 asyncio.wait_for() 为连接设置超时时间 - self.reader, self.writer = await asyncio.wait_for( - asyncio.open_connection(self.ip, self.port), timeout=self.timeout - ) - self.is_connected = True - logger.info(f"Connected to {self.ip}:{self.port}") - # 一旦连接成功,开始发送查询指令 - await self.start_gas_query() - except (asyncio.TimeoutError, ConnectionRefusedError, OSError) as e: - logger.error(f"Failed to connect to {self.ip}:{self.port}, error: {e}") - await self.reconnect() + while not self.is_connected: + try: + logger.info(f"正在连接到 {self.ip}:{self.port}...") + # 设置连接超时 + self.reader, self.writer = await asyncio.wait_for( + asyncio.open_connection(self.ip, self.port), timeout=self.timeout + ) + self.is_connected = True + logger.info(f"已连接到 {self.ip}:{self.port}") + asyncio.create_task(self.process_message_queue()) # Start processing message queue + + # 一旦连接成功,开始发送查询指令 + await self.start_gas_query() + except (asyncio.TimeoutError, ConnectionRefusedError, OSError) as e: + logger.error(f"连接到 {self.ip}:{self.port} 失败,错误: {e}") + logger.info(f"{self.reconnect_interval} 秒后将重连到 {self.ip}:{self.port}") + await asyncio.sleep(self.reconnect_interval) async def reconnect(self): """处理断线重连""" @@ -118,16 +129,17 @@ logger.info(f"Received data from {self.ip}:{self.port}: {format_bytes(data)}") try: res = parse_gas_data(data) - logger.info(res) - async for db in get_db(): - data_gas_service = DataGasService(db) - data_gas = DataGas( - device_code=res['device_code'], - gas_value=res['gas_value'] - ) + if res: + logger.info(res) + async for db in get_db(): + data_gas_service = DataGasService(db) + data_gas = DataGas( + device_code=res['device_code'], + gas_value=res['gas_value'] + ) - await data_gas_service.add_data_gas(data_gas) - await self.gas_push(data_gas) + await data_gas_service.add_data_gas(data_gas) + await self.gas_push(data_gas) except Exception as e: logger.error(f"Parse and save gas data failed: {e}") @@ -145,32 +157,67 @@ self.push_ts_dict[data_gas.device_code] = current_time # 更新推送时间戳 async def send_message(self, message: bytes, have_response=True): - """发送自定义消息的接口,供其他类调用""" - try: - # 检查连接状态 - if self.writer is None: - raise ConnectionResetError("No active connection") + """Add a message to the queue for sending""" + self.message_queue.append((message, have_response)) + logger.info(f"Message enqueued for {self.ip}:{self.port} {format_bytes(message)}") - # 发送自定义消息 - self.writer.write(message) - await self.writer.drain() # 确保数据已发送 - logger.info(f"Sent message to {self.ip}:{self.port}: {message}") - - # 可以根据需求选择是否接收响应 - if have_response: - data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout) - # if not data: - # raise ConnectionResetError("Connection lost or no data received") - await self.parse_response(data) - return data # 返回响应数据 + async def process_message_queue(self): + """Process messages in the queue, retrying on failures""" + while self.is_connected: + if self.message_queue: + message, have_response = self.message_queue.popleft() + await self._send_message_with_retry(message, have_response) else: - return None - except asyncio.TimeoutError: - logger.error(f"TimeoutError: No response from {self.ip}:{self.port} after {self.timeout} seconds") - await self.reconnect() # 如果超时则重新连接 - except (ConnectionResetError, asyncio.IncompleteReadError) as e: - logger.error(f"Failed to send message: {e}") - await self.reconnect() # 重新连接设备 + await asyncio.sleep(1) # Small delay to prevent busy-waiting + + async def _send_message_with_retry(self, message: bytes, have_response): + """Send a message with retries on failure""" + while self.is_connected: + try: + if self.writer is None: + raise ConnectionResetError("No active connection") + + self.writer.write(message) + await self.writer.drain() + logger.info(f"Sent message to {self.ip}:{self.port}: {message}") + + if have_response: + async with self.read_lock: # 使用锁确保只有一个协程读取 + data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout) + await self.parse_response(data) + return # Exit loop on success + except (asyncio.TimeoutError, ConnectionResetError, asyncio.IncompleteReadError) as e: + logger.error(f"Failed to send message: {e}, retrying...") + await self.reconnect() + break + + # async def send_message(self, message: bytes, have_response=True): + # """发送自定义消息的接口,供其他类调用""" + # try: + # # 检查连接状态 + # if self.writer is None: + # raise ConnectionResetError("No active connection") + # + # # 发送自定义消息 + # self.writer.write(message) + # await self.writer.drain() # 确保数据已发送 + # logger.info(f"Sent message to {self.ip}:{self.port}: {message}") + # + # # 可以根据需求选择是否接收响应 + # if have_response: + # data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout) + # # if not data: + # # raise ConnectionResetError("Connection lost or no data received") + # await self.parse_response(data) + # return data # 返回响应数据 + # else: + # return None + # except asyncio.TimeoutError: + # logger.error(f"TimeoutError: No response from {self.ip}:{self.port} after {self.timeout} seconds") + # await self.reconnect() # 如果超时则重新连接 + # except (ConnectionResetError, asyncio.IncompleteReadError) as e: + # logger.error(f"Failed to send message: {e}") + # await self.reconnect() # 重新连接设备 if __name__ == '__main__':