diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - 
logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device 
{self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + 
+ self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in 
results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if 
self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # 
self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 
使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = 
FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, 
annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - 
algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def 
save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + 
frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= 
frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} 
success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py 
deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - 
confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or 
len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, 
model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool 
import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), 
signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - 
future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) 
self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 
结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, 
thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! 
appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down 
gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 
6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = 
asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), 
None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): 
algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = 
queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), 
signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) 
yield # 允许请求处理 diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - 
object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break 
# 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ 
b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + 
+import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), 
signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) 
yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - 
device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time 
= datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - 
self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = 
handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! 
appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down 
gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from 
tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, 
frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + 
asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + 
result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py 
b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 
100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py 
b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None 
self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + 
frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = 
[DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + 
algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = 
cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + 
return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await 
service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return 
in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] + + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = 
threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = 
result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ 
-124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) 
for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = 
handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! 
appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down 
gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from 
tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 
+79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + 
location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + 
if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model 
{model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 
100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py 
b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None 
self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + 
frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = 
[DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + 
algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = 
cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + 
return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await 
service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return 
in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] + + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = 
threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop 
- ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + 
current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - 
self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = 
handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! 
appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down 
gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from 
tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not 
self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + 
asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + 
result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py 
b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 
100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py 
b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ 
-1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - 
asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - 
results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: 
ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, 
camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), 
signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) 
yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, 
results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if 
self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # 
self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 
4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), 
signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) 
yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = 
future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = 
datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - 
self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = 
handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! 
appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down 
gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from 
tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/run.sh b/run.sh index bb36274..815883a 100644 --- a/run.sh +++ b/run.sh @@ -8,4 +8,5 @@ export LC_ALL=C.UTF-8 cd /code/safe-algo-pro -python3 main.py +exec python3 main.py +#tail -f /dev/null \ No newline at end of file diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, 
result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def 
push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + 
frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = 
get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 
100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py 
b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/run.sh b/run.sh index bb36274..815883a 100644 --- a/run.sh +++ b/run.sh @@ -8,4 +8,5 @@ export LC_ALL=C.UTF-8 cd /code/safe-algo-pro -python3 main.py +exec python3 main.py +#tail -f /dev/null \ No newline at end of file diff --git a/scene_handler/alarm_message_center.py b/scene_handler/alarm_message_center.py new file mode 100644 index 0000000..fba1397 --- /dev/null +++ b/scene_handler/alarm_message_center.py @@ -0,0 +1,121 @@ +import asyncio +import copy +import time +from collections import deque +from threading import Thread, Lock + +''' +队列消息取出规则: +- 按 alarmCategory(优先级:2 > 1 > 3 > 0)和 category_order 从小到大排序。 +- 确保每个 alarmCategory 在 10 秒内只能发送一次。 + +消息定时清理: +- 每次取消息时会检查队列中长期堆积、超过一定时间(比如 30 秒)的消息,并清理这些消息。 +''' + + +class AlarmMessageCenter: + def __init__(self, device_id, main_loop, tcp_manager=None, message_send_interval=5, category_interval=10, retention_time=30, + category_priority=None): + self.device_id = device_id + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.queue = deque() + self.last_sent_time = {} # 记录每个 alarmCategory 的最后发送时间 + self.lock = Lock() + self.message_send_interval = message_send_interval # 消息发送间隔(秒) + self.category_interval = category_interval # 类别发送间隔(秒) + self.retention_time = retention_time # 消息最长保留时间(秒) + if category_priority: + self.category_priority = category_priority + self.auto_update_alarm_priority = False + else: + self.category_priority = category_priority + self.auto_update_alarm_priority = 
True + + def add_message(self, message_ori): + message = copy.deepcopy(message_ori) + message['timestamp'] = int(time.time()) # 添加消息放入队列的时间 + with self.lock: + self.queue.append(message) + + # 动态更新优先级映射 + if self.auto_update_alarm_priority: + alarm_category = message['alarmCategory'] + if alarm_category not in self.category_priority: + unique_categories = sorted({msg['alarmCategory'] for msg in self.queue}) + self.category_priority = {cat: idx for idx, cat in enumerate(unique_categories)} + print(f"更新优先级映射: {self.category_priority}") + + def _clean_old_messages(self): + """清理长期堆积的消息""" + print(f'清理长期堆积的消息 (队列长度: {len(self.queue)})') + now = time.time() + with self.lock: + self.queue = deque([msg for msg in self.queue if now - msg['timestamp'] <= self.retention_time]) + print(f'清理后的队列长度: {len(self.queue)}') + + # def _get_next_message(self): + # """按优先级和时间规则取出下一条消息""" + # now = time.time() + # with self.lock: + # # 按照规则排序队列 + # sorted_queue = sorted( + # self.queue, + # key=lambda x: ( + # self.alarm_priority.get(x['alarmCategory'], 4), # 按 alarmCategory 优先级排序 + # x['category_order'], # category_order 小的排前面 + # x['timestamp'] # 时间靠前的排前面 + # ) + # ) + # # 遍历排序后的队列,找到符合规则的消息 + # for msg in sorted_queue: + # alarm_category = msg['alarmCategory'] + # if alarm_category not in self.last_sent_time or now - self.last_sent_time[alarm_category] > 10: + # self.queue.remove(msg) + # self.last_sent_time[alarm_category] = now + # return msg + # return None + + def _get_next_message(self): + """按优先级和时间规则取出下一条消息""" + now = time.time() + with self.lock: + # 按优先级依次检查 + for priority in sorted(self.category_priority.values()): + found_valid_message = False # 用于标记当前优先级是否有消息被处理 + for msg in sorted( + (m for m in self.queue if self.category_priority.get(m['alarmCategory'], 99) == priority), + key=lambda x: (-x['timestamp'], x['category_order']) + ): + alarm_category = msg['alarmCategory'] + # 检查是否符合发送条件 + if alarm_category not in self.last_sent_time or now - self.last_sent_time[ + 
alarm_category] > self.category_interval: + self.queue.remove(msg) + self.last_sent_time[alarm_category] = now + return msg # 找到符合条件的消息立即返回 + found_valid_message = True # 当前优先级存在消息但不符合条件 + # 如果当前优先级的所有消息都被检查过且不符合条件,跳到下一个优先级 + if not found_valid_message: + continue + return None # 如果所有优先级都没有消息符合条件,则返回 None + + def process_messages(self): + while True: + time.sleep(self.message_send_interval) + self._clean_old_messages() # 清理长期堆积的消息 + next_message = self._get_next_message() + if next_message: + self.send_message(next_message) + + def send_message(self, message): + """发送报警消息""" + print(f"发送报警消息: {message['alarmContent']} (类别: {message['alarmCategory']}, 时间: {message['timestamp']})") + if self.tcp_manager: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device_id, + message=message['alarmSoundMessage'], + have_response=False), + self.main_loop) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return 
- # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + 
self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + 
frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py 
index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 
100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py 
b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/run.sh b/run.sh index bb36274..815883a 100644 --- a/run.sh +++ b/run.sh @@ -8,4 +8,5 @@ export LC_ALL=C.UTF-8 cd /code/safe-algo-pro -python3 main.py +exec python3 main.py +#tail -f /dev/null \ No newline at end of file diff --git a/scene_handler/alarm_message_center.py b/scene_handler/alarm_message_center.py new file mode 100644 index 0000000..fba1397 --- /dev/null +++ b/scene_handler/alarm_message_center.py @@ -0,0 +1,121 @@ +import asyncio +import copy +import time +from collections import deque +from threading import Thread, Lock + +''' +队列消息取出规则: +- 按 alarmCategory(优先级:2 > 1 > 3 > 0)和 category_order 从小到大排序。 +- 确保每个 alarmCategory 在 10 秒内只能发送一次。 + +消息定时清理: +- 每次取消息时会检查队列中长期堆积、超过一定时间(比如 30 秒)的消息,并清理这些消息。 +''' + + +class AlarmMessageCenter: + def __init__(self, device_id, main_loop, tcp_manager=None, message_send_interval=5, category_interval=10, retention_time=30, + category_priority=None): + self.device_id = device_id + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.queue = deque() + self.last_sent_time = {} # 记录每个 alarmCategory 的最后发送时间 + self.lock = Lock() + self.message_send_interval = message_send_interval # 消息发送间隔(秒) + self.category_interval = category_interval # 类别发送间隔(秒) + self.retention_time = retention_time # 消息最长保留时间(秒) + if category_priority: + self.category_priority = category_priority + self.auto_update_alarm_priority = False + else: + self.category_priority = category_priority + self.auto_update_alarm_priority = 
True + + def add_message(self, message_ori): + message = copy.deepcopy(message_ori) + message['timestamp'] = int(time.time()) # 添加消息放入队列的时间 + with self.lock: + self.queue.append(message) + + # 动态更新优先级映射 + if self.auto_update_alarm_priority: + alarm_category = message['alarmCategory'] + if alarm_category not in self.category_priority: + unique_categories = sorted({msg['alarmCategory'] for msg in self.queue}) + self.category_priority = {cat: idx for idx, cat in enumerate(unique_categories)} + print(f"更新优先级映射: {self.category_priority}") + + def _clean_old_messages(self): + """清理长期堆积的消息""" + print(f'清理长期堆积的消息 (队列长度: {len(self.queue)})') + now = time.time() + with self.lock: + self.queue = deque([msg for msg in self.queue if now - msg['timestamp'] <= self.retention_time]) + print(f'清理后的队列长度: {len(self.queue)}') + + # def _get_next_message(self): + # """按优先级和时间规则取出下一条消息""" + # now = time.time() + # with self.lock: + # # 按照规则排序队列 + # sorted_queue = sorted( + # self.queue, + # key=lambda x: ( + # self.alarm_priority.get(x['alarmCategory'], 4), # 按 alarmCategory 优先级排序 + # x['category_order'], # category_order 小的排前面 + # x['timestamp'] # 时间靠前的排前面 + # ) + # ) + # # 遍历排序后的队列,找到符合规则的消息 + # for msg in sorted_queue: + # alarm_category = msg['alarmCategory'] + # if alarm_category not in self.last_sent_time or now - self.last_sent_time[alarm_category] > 10: + # self.queue.remove(msg) + # self.last_sent_time[alarm_category] = now + # return msg + # return None + + def _get_next_message(self): + """按优先级和时间规则取出下一条消息""" + now = time.time() + with self.lock: + # 按优先级依次检查 + for priority in sorted(self.category_priority.values()): + found_valid_message = False # 用于标记当前优先级是否有消息被处理 + for msg in sorted( + (m for m in self.queue if self.category_priority.get(m['alarmCategory'], 99) == priority), + key=lambda x: (-x['timestamp'], x['category_order']) + ): + alarm_category = msg['alarmCategory'] + # 检查是否符合发送条件 + if alarm_category not in self.last_sent_time or now - self.last_sent_time[ + 
alarm_category] > self.category_interval: + self.queue.remove(msg) + self.last_sent_time[alarm_category] = now + return msg # 找到符合条件的消息立即返回 + found_valid_message = True # 当前优先级存在消息但不符合条件 + # 如果当前优先级的所有消息都被检查过且不符合条件,跳到下一个优先级 + if not found_valid_message: + continue + return None # 如果所有优先级都没有消息符合条件,则返回 None + + def process_messages(self): + while True: + time.sleep(self.message_send_interval) + self._clean_old_messages() # 清理长期堆积的消息 + next_message = self._get_next_message() + if next_message: + self.send_message(next_message) + + def send_message(self, message): + """发送报警消息""" + print(f"发送报警消息: {message['alarmContent']} (类别: {message['alarmCategory']}, 时间: {message['timestamp']})") + if self.tcp_manager: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device_id, + message=message['alarmSoundMessage'], + have_response=False), + self.main_loop) diff --git a/scene_handler/alarm_record_center.py b/scene_handler/alarm_record_center.py new file mode 100644 index 0000000..5373cd5 --- /dev/null +++ b/scene_handler/alarm_record_center.py @@ -0,0 +1,128 @@ +import asyncio +import base64 +import time +from datetime import datetime + +import cv2 + +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from db.database import get_db +from entity.alarm_record import AlarmRecordCreate +from services.alarm_record_service import AlarmRecordService +from services.global_config import GlobalConfig + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + # return f'data:image/{format};base64,{base64_message}' + + +class 
AlarmRecordCenter: + def __init__(self, save_interval=-1, main_loop=None): + """ + 初始化报警记录中心 + :param upload_interval: 报警上传间隔(秒),如果 <= 0,则不报警 + """ + self.main_loop = main_loop + self.save_interval = save_interval + self.thread_pool = GlobalThreadPool() + self.global_config = GlobalConfig() + # self.upload_interval = upload_interval + self.device_alarm_upload_time = {} # key: device_code, value: {alarm_type: last upload time} + self.device_alarm_save_time = {} + + def need_alarm(self, device_code, alarm_dict): + """ + 是否需要报警 + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :return: 是否需要报警 + """ + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + return self.need_save_alarm_record(device_code, current_time, alarm_type) \ + or self.need_upload_alarm_record(device_code, current_time, alarm_type) + + def need_save_alarm_record(self, device_code, current_time, alarm_type): + if self.save_interval <= 0: + return False + + if device_code not in self.device_alarm_save_time: + self.device_alarm_save_time[device_code] = {} + last_save_time = self.device_alarm_save_time[device_code].get(alarm_type) + if last_save_time is None or (current_time - last_save_time) > self.save_interval: + return True + return False + + def need_upload_alarm_record(self, device_code, current_time, alarm_type): + push_config = self.global_config.get_alarm_push_config() + if not push_config or not push_config.push_url: + return False + + if device_code not in self.device_alarm_upload_time: + self.device_alarm_upload_time[device_code] = {} + last_upload_time = self.device_alarm_upload_time[device_code].get(alarm_type) + if last_upload_time is None or (current_time - last_upload_time) > push_config.upload_interval: + return True + return False + + async def save_record(self, alarm_data, alarm_np_img): + async for db in get_db(): + alarm_record_service = AlarmRecordService(db) + await alarm_record_service.add_alarm(alarm_data=alarm_data, alarm_np_img=alarm_np_img) + + def 
upload_alarm_record(self, device_code, alarm_dict, alarm_np_img=None, alarm_value=None): + """ + 上传报警记录 + :param alarm_value: + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :param alarm_np_img: 报警图片(np array) + """ + # 获取当前时间 + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + if self.need_save_alarm_record(device_code, current_time, alarm_type): + alarm_record_data = AlarmRecordCreate( + device_code=device_code, + device_id=100, + alarm_type=alarm_type, + alarm_content=alarm_dict['alarmContent'], + alarm_time=datetime.now(), + alarm_value=alarm_value if alarm_value else None, + ) + asyncio.run_coroutine_threadsafe( + self.save_record(alarm_record_data, alarm_np_img), + self.main_loop) + self.device_alarm_save_time[device_code][alarm_type] = current_time + + if self.need_upload_alarm_record(device_code, current_time, alarm_type): + alarm_record = { + 'devcode': device_code, + 'alarmTime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'alarmType': alarm_dict['alarmType'], + 'alarmContent': alarm_dict['alarmContent'], + } + if alarm_value: + alarm_record['alarmValue'] = alarm_value + if alarm_np_img: + alarm_record['alarmImage'] = image_to_base64(alarm_np_img) + + push_config = self.global_config.get_alarm_push_config() + self.thread_pool.submit_task(send_request, push_config.upload_url, alarm_record) + self.device_alarm_upload_time[device_code][alarm_type] = current_time diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import 
importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save 
frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = 
get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + 
dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = 
GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = 
self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, 
scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is 
None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] + + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + 
cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/run.sh b/run.sh index bb36274..815883a 100644 --- a/run.sh +++ b/run.sh @@ -8,4 +8,5 @@ export LC_ALL=C.UTF-8 cd /code/safe-algo-pro -python3 main.py +exec python3 main.py +#tail -f /dev/null \ No newline at end of file diff --git a/scene_handler/alarm_message_center.py b/scene_handler/alarm_message_center.py new file mode 100644 index 0000000..fba1397 --- /dev/null +++ b/scene_handler/alarm_message_center.py @@ -0,0 +1,121 @@ +import asyncio +import copy +import time +from collections import deque +from threading import Thread, Lock + +''' +队列消息取出规则: +- 按 alarmCategory(优先级:2 > 1 > 3 > 0)和 category_order 从小到大排序。 +- 确保每个 alarmCategory 在 10 秒内只能发送一次。 + +消息定时清理: +- 每次取消息时会检查队列中长期堆积、超过一定时间(比如 30 秒)的消息,并清理这些消息。 +''' + + +class AlarmMessageCenter: + def __init__(self, device_id, main_loop, tcp_manager=None, message_send_interval=5, category_interval=10, retention_time=30, + category_priority=None): + self.device_id = device_id + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.queue = deque() + self.last_sent_time = {} # 记录每个 alarmCategory 的最后发送时间 + self.lock = Lock() + self.message_send_interval = message_send_interval # 消息发送间隔(秒) + self.category_interval = category_interval # 类别发送间隔(秒) + self.retention_time = retention_time # 消息最长保留时间(秒) + if category_priority: + self.category_priority = category_priority + self.auto_update_alarm_priority = False + else: + self.category_priority = category_priority + self.auto_update_alarm_priority = 
True + + def add_message(self, message_ori): + message = copy.deepcopy(message_ori) + message['timestamp'] = int(time.time()) # 添加消息放入队列的时间 + with self.lock: + self.queue.append(message) + + # 动态更新优先级映射 + if self.auto_update_alarm_priority: + alarm_category = message['alarmCategory'] + if alarm_category not in self.category_priority: + unique_categories = sorted({msg['alarmCategory'] for msg in self.queue}) + self.category_priority = {cat: idx for idx, cat in enumerate(unique_categories)} + print(f"更新优先级映射: {self.category_priority}") + + def _clean_old_messages(self): + """清理长期堆积的消息""" + print(f'清理长期堆积的消息 (队列长度: {len(self.queue)})') + now = time.time() + with self.lock: + self.queue = deque([msg for msg in self.queue if now - msg['timestamp'] <= self.retention_time]) + print(f'清理后的队列长度: {len(self.queue)}') + + # def _get_next_message(self): + # """按优先级和时间规则取出下一条消息""" + # now = time.time() + # with self.lock: + # # 按照规则排序队列 + # sorted_queue = sorted( + # self.queue, + # key=lambda x: ( + # self.alarm_priority.get(x['alarmCategory'], 4), # 按 alarmCategory 优先级排序 + # x['category_order'], # category_order 小的排前面 + # x['timestamp'] # 时间靠前的排前面 + # ) + # ) + # # 遍历排序后的队列,找到符合规则的消息 + # for msg in sorted_queue: + # alarm_category = msg['alarmCategory'] + # if alarm_category not in self.last_sent_time or now - self.last_sent_time[alarm_category] > 10: + # self.queue.remove(msg) + # self.last_sent_time[alarm_category] = now + # return msg + # return None + + def _get_next_message(self): + """按优先级和时间规则取出下一条消息""" + now = time.time() + with self.lock: + # 按优先级依次检查 + for priority in sorted(self.category_priority.values()): + found_valid_message = False # 用于标记当前优先级是否有消息被处理 + for msg in sorted( + (m for m in self.queue if self.category_priority.get(m['alarmCategory'], 99) == priority), + key=lambda x: (-x['timestamp'], x['category_order']) + ): + alarm_category = msg['alarmCategory'] + # 检查是否符合发送条件 + if alarm_category not in self.last_sent_time or now - self.last_sent_time[ + 
alarm_category] > self.category_interval: + self.queue.remove(msg) + self.last_sent_time[alarm_category] = now + return msg # 找到符合条件的消息立即返回 + found_valid_message = True # 当前优先级存在消息但不符合条件 + # 如果当前优先级的所有消息都被检查过且不符合条件,跳到下一个优先级 + if not found_valid_message: + continue + return None # 如果所有优先级都没有消息符合条件,则返回 None + + def process_messages(self): + while True: + time.sleep(self.message_send_interval) + self._clean_old_messages() # 清理长期堆积的消息 + next_message = self._get_next_message() + if next_message: + self.send_message(next_message) + + def send_message(self, message): + """发送报警消息""" + print(f"发送报警消息: {message['alarmContent']} (类别: {message['alarmCategory']}, 时间: {message['timestamp']})") + if self.tcp_manager: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device_id, + message=message['alarmSoundMessage'], + have_response=False), + self.main_loop) diff --git a/scene_handler/alarm_record_center.py b/scene_handler/alarm_record_center.py new file mode 100644 index 0000000..5373cd5 --- /dev/null +++ b/scene_handler/alarm_record_center.py @@ -0,0 +1,128 @@ +import asyncio +import base64 +import time +from datetime import datetime + +import cv2 + +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from db.database import get_db +from entity.alarm_record import AlarmRecordCreate +from services.alarm_record_service import AlarmRecordService +from services.global_config import GlobalConfig + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + # return f'data:image/{format};base64,{base64_message}' + + +class 
AlarmRecordCenter: + def __init__(self, save_interval=-1, main_loop=None): + """ + 初始化报警记录中心 + :param upload_interval: 报警上传间隔(秒),如果 <= 0,则不报警 + """ + self.main_loop = main_loop + self.save_interval = save_interval + self.thread_pool = GlobalThreadPool() + self.global_config = GlobalConfig() + # self.upload_interval = upload_interval + self.device_alarm_upload_time = {} # key: device_code, value: {alarm_type: last upload time} + self.device_alarm_save_time = {} + + def need_alarm(self, device_code, alarm_dict): + """ + 是否需要报警 + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :return: 是否需要报警 + """ + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + return self.need_save_alarm_record(device_code, current_time, alarm_type) \ + or self.need_upload_alarm_record(device_code, current_time, alarm_type) + + def need_save_alarm_record(self, device_code, current_time, alarm_type): + if self.save_interval <= 0: + return False + + if device_code not in self.device_alarm_save_time: + self.device_alarm_save_time[device_code] = {} + last_save_time = self.device_alarm_save_time[device_code].get(alarm_type) + if last_save_time is None or (current_time - last_save_time) > self.save_interval: + return True + return False + + def need_upload_alarm_record(self, device_code, current_time, alarm_type): + push_config = self.global_config.get_alarm_push_config() + if not push_config or not push_config.push_url: + return False + + if device_code not in self.device_alarm_upload_time: + self.device_alarm_upload_time[device_code] = {} + last_upload_time = self.device_alarm_upload_time[device_code].get(alarm_type) + if last_upload_time is None or (current_time - last_upload_time) > push_config.upload_interval: + return True + return False + + async def save_record(self, alarm_data, alarm_np_img): + async for db in get_db(): + alarm_record_service = AlarmRecordService(db) + await alarm_record_service.add_alarm(alarm_data=alarm_data, alarm_np_img=alarm_np_img) + + def 
upload_alarm_record(self, device_code, alarm_dict, alarm_np_img=None, alarm_value=None): + """ + 上传报警记录 + :param alarm_value: + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :param alarm_np_img: 报警图片(np array) + """ + # 获取当前时间 + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + if self.need_save_alarm_record(device_code, current_time, alarm_type): + alarm_record_data = AlarmRecordCreate( + device_code=device_code, + device_id=100, + alarm_type=alarm_type, + alarm_content=alarm_dict['alarmContent'], + alarm_time=datetime.now(), + alarm_value=alarm_value if alarm_value else None, + ) + asyncio.run_coroutine_threadsafe( + self.save_record(alarm_record_data, alarm_np_img), + self.main_loop) + self.device_alarm_save_time[device_code][alarm_type] = current_time + + if self.need_upload_alarm_record(device_code, current_time, alarm_type): + alarm_record = { + 'devcode': device_code, + 'alarmTime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'alarmType': alarm_dict['alarmType'], + 'alarmContent': alarm_dict['alarmContent'], + } + if alarm_value: + alarm_record['alarmValue'] = alarm_value + if alarm_np_img: + alarm_record['alarmImage'] = image_to_base64(alarm_np_img) + + push_config = self.global_config.get_alarm_push_config() + self.thread_pool.submit_task(send_request, push_config.upload_url, alarm_record) + self.device_alarm_upload_time[device_code][alarm_type] = current_time diff --git a/scene_handler/block_scene_handler.py b/scene_handler/block_scene_handler.py new file mode 100644 index 0000000..3dd0b31 --- /dev/null +++ b/scene_handler/block_scene_handler.py @@ -0,0 +1,426 @@ +import time +import traceback +from asyncio import Event +from copy import deepcopy +from datetime import datetime + + +from flatbuffers.builder import np +from scipy.spatial import ConvexHull + +from algo.stream_loader import OpenCVStreamLoad +from common.detect_utils import is_within_alert_range, get_person_head, intersection_area, bbox_area +from 
common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.harmful_gas_manager import HarmfulGasManager +from common.image_plotting import Annotator +from entity.device import Device +from scene_handler.alarm_message_center import AlarmMessageCenter +from scene_handler.alarm_record_center import AlarmRecordCenter +from scene_handler.base_scene_handler import BaseSceneHandler +from scene_handler.limit_space_scene_handler import is_overlapping +from services.global_config import GlobalConfig +from tcp.tcp_manager import TcpManager + +from entity.device import Device +from common.http_utils import get_request +from ultralytics import YOLO + +''' +alarmCategory: +0 行为监管 +1 环境监管 +2 人员监管 +3 围栏监管 + +handelType: +0 检测到报警 +1 未检测到报警 +2 人未穿戴报警 +3 其他 +''' +ALARM_DICT = [ + { + 'alarmCategory': 0, + 'alarmType': '1', + 'handelType': 1, + 'category_order': 1, + 'class_idx': [34], + 'alarm_name': 'no_fire_extinguisher', + 'alarmContent': '未检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '2', + 'handelType': 1, + 'category_order': 2, + 'class_idx': [43], + 'alarm_name': 'no_barrier_tape', + 'alarmContent': '未检测到警戒线', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '3', + 'handelType': 1, + 'category_order': 3, + 'class_idx': [48], + 'alarm_name': 'no_cone', + 'alarmContent': '未检测到锥桶', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '4', + 'handelType': 1, + 'category_order': 4, + 'class_idx': [4, 5, 16], + 'alarm_name': 'no_board', + 'alarmContent': '未检测到指示牌', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '5', + 'handelType': 2, + 'category_order': -1, + 'class_idx': [18], + 'alarm_name': 
'no_helmet', + 'alarmContent': '未佩戴安全帽', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴安全帽', + }, + # todo 明火 + { + 'alarmCategory': 1, + 'alarmType': '7', + 'handelType': 3, + 'category_order': 1, + 'class_idx': [], + 'alarm_name': 'gas_alarm', + 'alarmContent': '甲烷浓度超限', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 1, + 'alarmType': '8', + 'handelType': 3, + 'category_order': 2, + 'class_idx': [], + 'alarm_name': 'harmful_alarm', + 'alarmContent': '有害气体浓度超标', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 2, + 'alarmType': '9', + 'handelType': 3, + 'category_order': -1, + 'class_idx': [], + 'alarm_name': 'health_alarm', + 'alarmContent': '心率血氧异常', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 3, + 'alarmType': '10', + 'handelType': 2, + 'category_order': 4, + 'class_idx': [24], + 'alarm_name': 'break_in_alarm', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '非法闯入', + }, + +] + +COLOR_RED = (0, 0, 255) +COLOR_BLUE = (255, 0, 0) + + +class BlockSceneHandler(BaseSceneHandler): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + self.__stop_event = Event(loop=main_loop) + self.health_ts_dict = {} + self.harmful_ts_dict = {} + self.object_ts_dict = {} + self.thread_pool = GlobalThreadPool() + + self.alarm_message_center = AlarmMessageCenter(device.id,main_loop=main_loop, tcp_manager=tcp_manager, + category_priority={2: 0, 1: 1, 3: 2, 0: 3}) + self.alarm_record_center = AlarmRecordCenter(save_interval=device.alarm_interval,main_loop=main_loop) + self.harmful_data_manager = HarmfulGasManager() + self.device_status_manager = DeviceStatusManager() + + + self.health_device_codes = 
['HWIH061000056395'] # todo + self.harmful_device_codes = [] # todo + + for helmet_code in self.health_device_codes: + self.thread_pool.submit_task(self.health_data_task, helmet_code) + for harmful_device_code in self.harmful_device_codes: + self.thread_pool.submit_task(self.harmful_data_task, harmful_device_code) + + self.thread_pool.submit_task(self.alarm_message_center.process_messages) + + # todo 明火 + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 48: '路锥', + 58: '鼓风机', + } + self.PERSON_CLASS_IDX = 3 + self.HEAD_CLASS_IDX = 15 + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) + self.range_points = range_points + self.abs_range_points = self.get_absolute_range() + + self.tracking_status = {} # 跟踪每个行人的状态 + self.max_missing_frames = 25 # 报警的阈值 + self.disappear_threshold = 25 * 3 # 移除行人的阈值 + + def get_absolute_range(self): + fence_info = eval(self.range_points) + if fence_info and len(fence_info) > 1: + abs_points = [] + for p in fence_info: + abs_points.append( + [int(p[0] * int(self.stream_loader.frame_width)), int(p[1] * int(self.stream_loader.frame_height))]) + + abs_points = np.array(abs_points, dtype=np.int32) + hull = ConvexHull(abs_points) + sorted_coordinates = abs_points[hull.vertices] + # abs_points = abs_points.reshape((-1, 1, 2)) + return sorted_coordinates + else: + return None + + def harmful_data_task(self, harmful_device_code): + while not self.__stop_event.is_set(): + harmful_gas_data = self.harmful_data_manager.get_device_all_data(harmful_device_code) + for gas_type, gas_data in harmful_gas_data.items(): + ts_key = f'{harmful_device_code}_{gas_type}' + last_ts = self.harmful_ts_dict.get(ts_key) + gas_ts = gas_data.get('gas_ts') + if last_ts is None or 
def health_data_task(self, helmet_code):
    """Poll the helmet vital-signs endpoint every 10s until the task is stopped.

    NOTE(review): the ak/sk credentials are hard-coded here; move them to
    configuration — confirm with the service owner.
    """
    while not self.__stop_event.is_set():
        header = {
            'ak': 'fe80b2f021644b1b8c77fda743a83670',
            'sk': '8771ea6e931d4db646a26f67bcb89909',
        }
        url = f'https://jls.huaweisoft.com//api/ih-log/v1.0/ih-api/helmetInfo/{helmet_code}'
        response = get_request(url, headers=header)
        if response and response.get('data'):
            last_ts = self.health_ts_dict.get(helmet_code)
            vitalsigns_data = response.get('data').get('vitalSignsData')
            if vitalsigns_data:
                upload_timestamp = datetime.strptime(vitalsigns_data.get('uploadTimestamp'), "%Y-%m-%d %H:%M:%S")
                # Only act on data newer than the last processed timestamp.
                if last_ts is None or (upload_timestamp.timestamp() - last_ts) > 0:
                    self.health_ts_dict[helmet_code] = upload_timestamp.timestamp()
                    if time.time() - upload_timestamp.timestamp() < 10 * 60:  # ignore data older than 10 minutes
                        self.handle_health_alarm(helmet_code, vitalsigns_data.get('bloodOxygen'),
                                                 vitalsigns_data.get('heartRate'), upload_timestamp)
        time.sleep(10)


def handle_health_alarm(self, helmet_code, blood_oxygen, heartrate, upload_timestamp):
    """Raise a health alarm when heart rate or blood oxygen is outside the safe range."""
    logger.debug(f'health_data: {helmet_code}, blood_oxygen = {blood_oxygen}, heartrate = {heartrate}, '
                 f'upload_timestamp = {upload_timestamp}')
    # Fix: the upstream API may omit either vital sign; comparing None to an int
    # used to raise TypeError and kill the polling loop iteration.
    if heartrate is None or blood_oxygen is None:
        return
    if heartrate < 60 or heartrate > 120 or blood_oxygen < 85:
        alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 2]
        if alarm_dict:
            self.alarm_message_center.add_message(alarm_dict[0])
        # TODO: should an alarm record be generated / raw data forwarded to the backend?


def handle_harmful_gas_alarm(self, device_code, gas_type, gas_data):
    """Raise a gas alarm when the reading crosses its per-gas threshold.

    NOTE(review): measurement units are not visible here; confirm the threshold
    values against the sensor specification.
    """
    alarm = False
    gas_value = gas_data['gas_value']
    if gas_type == 3:  # H2S
        alarm = gas_value > 120.0
    elif gas_type == 4:  # CO
        alarm = gas_value > 10.0
    elif gas_type == 5:  # O2 — alarm when too LOW
        alarm = gas_value < 15
    elif gas_type == 50:  # combustible gas (EX)
        alarm = gas_value > 10

    if alarm:
        alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 1]
        if alarm_dict:
            self.alarm_message_center.add_message(alarm_dict[0])
        # TODO: should an alarm record be generated?


def model_predict(self, frames):
    """Run the detector on a frame batch.

    :returns: (result_boxes, pred_ids, pred_names) — all lists aligned per frame.
    """
    results_generator = self.model(frames, save_txt=False, save=False, verbose=False, conf=0.5,
                                   classes=list(self.model_classes.keys()),
                                   imgsz=640,
                                   stream=True)
    result_boxes = []
    for r in results_generator:
        result_boxes.append(r.boxes)

    pred_ids = [[int(box.cls) for box in sublist] for sublist in result_boxes]
    pred_names = [[self.model_classes[int(box.cls)] for box in sublist] for sublist in result_boxes]
    return result_boxes, pred_ids, pred_names


def handle_behave_alarm(self, frames, result_boxes, pred_ids, pred_names):
    """Dispatch behaviour alarms (alarmCategory == 0) for every frame in the batch.

    handelType semantics (from the alarm dict):
      0 — alarm whenever the target class is detected;
      1 — alarm when the target class has NOT been seen for > 5s;
      2 — alarm when a person is not wearing the required equipment.
    """
    behave_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 0]
    for alarm_dict in behave_alarm_dicts:
        for idx, frame_boxes in enumerate(result_boxes):
            frame = frames[idx]
            object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']]
            if alarm_dict['handelType'] == 0:  # alarm on detection
                if object_boxes:
                    self.alarm_message_center.add_message(alarm_dict)
                    if self.alarm_record_center.need_alarm(self.device.code, alarm_dict):
                        annotator = Annotator(deepcopy(frame))
                        for box in frame_boxes:
                            if int(box.cls) != self.HEAD_CLASS_IDX:
                                box_color = COLOR_RED if int(box.cls) in alarm_dict['class_idx'] else COLOR_BLUE
                                annotator.box_label(box.xyxy.cpu().squeeze(),
                                                    f"{self.model_classes[int(box.cls)]}",
                                                    color=box_color,
                                                    rotated=False)
                        self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict,
                                                                     annotator.result())

            elif alarm_dict['handelType'] == 1:  # alarm when the class disappears
                if object_boxes:
                    self.object_ts_dict[alarm_dict['alarm_name']] = time.time()
                else:
                    last_ts = self.object_ts_dict.get(alarm_dict['alarm_name'], 0)
                    if time.time() - last_ts > 5:
                        self.object_ts_dict[alarm_dict['alarm_name']] = time.time()
                        self.alarm_message_center.add_message(alarm_dict)
                        if self.alarm_record_center.need_alarm(self.device.code, alarm_dict):
                            annotator = Annotator(deepcopy(frame))
                            for box in frame_boxes:
                                if int(box.cls) != self.HEAD_CLASS_IDX:
                                    annotator.box_label(box.xyxy.cpu().squeeze(),
                                                        f"{self.model_classes[int(box.cls)]}",
                                                        color=COLOR_BLUE,
                                                        rotated=False)
                            self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict,
                                                                         annotator.result())

            elif alarm_dict['handelType'] == 2:  # person missing required equipment
                person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX]
                head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX]
                has_alarm = False
                annotator = None
                for person_box in person_boxes:
                    person_bbox = person_box.xyxy.cpu().squeeze()
                    # Default True: without a matched head we cannot judge helmet wear.
                    has_helmet = True
                    person_head = get_person_head(person_bbox, head_boxes)
                    if person_head is not None:
                        has_helmet = any(
                            is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze())
                            for helmet in object_boxes)
                    if not has_helmet:
                        has_alarm = True
                        if self.alarm_record_center.need_alarm(self.device.code, alarm_dict):
                            annotator = Annotator(deepcopy(frame)) if annotator is None else annotator
                            annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False)

                if has_alarm:
                    self.alarm_message_center.add_message(alarm_dict)
                    # Fix: need_alarm() is time based and can flip to True only here, in
                    # which case annotator was never created — guard against None to
                    # avoid AttributeError; also use int(box.cls) consistently.
                    if annotator is not None and self.alarm_record_center.need_alarm(self.device.code, alarm_dict):
                        for box in frame_boxes:
                            box_cls = int(box.cls)
                            if box_cls != self.PERSON_CLASS_IDX and box_cls != self.HEAD_CLASS_IDX:
                                annotator.box_label(box.xyxy.cpu().squeeze(),
                                                    f"{self.model_classes[box_cls]}",
                                                    color=COLOR_BLUE,
                                                    rotated=False)
                        self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict,
                                                                     annotator.result())


def handle_break_in_alarm(self, frames, result_boxes, pred_ids, pred_names):
    """Alarm (alarmCategory == 3) when a person inside the alert range lacks work clothes."""
    break_in_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 3]
    for alarm_dict in break_in_alarm_dicts:
        for idx, frame_boxes in enumerate(result_boxes):
            frame = frames[idx]
            person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX]
            head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX]
            object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']]
            has_alarm = False
            annotator = None
            for person_box in person_boxes:
                person_bbox = person_box.xyxy.cpu().squeeze()
                if is_within_alert_range(person_bbox, self.abs_range_points):
                    has_object = True
                    person_head = get_person_head(person_bbox, head_boxes)
                    if person_head is not None:
                        overlap_ratio = intersection_area(person_bbox,
                                                          person_head.xyxy.cpu().squeeze()) / bbox_area(person_bbox)
                        if overlap_ratio < 0.5:  # head/person < 0.5 → body visible; check clothes (heuristic)
                            has_object = any(
                                is_overlapping(person_head.xyxy.cpu().squeeze(), object_box.xyxy.cpu().squeeze())
                                for object_box in object_boxes)
                    if not has_object:
                        self.tracking_status[person_box.id] = self.tracking_status.get(person_box.id, 0) + 1
                        has_alarm = True
                        if self.alarm_record_center.need_alarm(self.device.code, alarm_dict):
                            annotator = Annotator(deepcopy(frame)) if annotator is None else annotator
                            annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False)

            if has_alarm:
                self.alarm_message_center.add_message(alarm_dict)
                # Fix: guard against annotator being None when need_alarm() only turned
                # True after the person loop (it is time based).
                if annotator is not None and self.alarm_record_center.need_alarm(self.device.code, alarm_dict):
                    self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict,
                                                                 annotator.result())


def run(self):
    """Main detection loop: wait for the stream, then process frame batches until stopped."""
    while not self.stream_loader.init:
        if self.__stop_event.is_set():
            break  # stop requested before the stream ever came up
        self.stream_loader.init_cap()
    for frames in self.stream_loader:
        try:
            if self.__stop_event.is_set():
                break
            if not frames:
                continue

            self.device_status_manager.set_status(device_id=self.device.id)
            # All results are per-frame lists aligned with the input batch.
            result_boxes, pred_ids, pred_names = self.model_predict(frames)
            self.handle_behave_alarm(frames, result_boxes, pred_ids, pred_names)
            self.handle_break_in_alarm(frames, result_boxes, pred_ids, pred_names)
        except Exception as ex:
            traceback.print_exc()
            logger.error(ex)
def save_frame_results(self, frames, result_maps):
    """Persist each frame of a batch plus its per-model detection results, then push them.

    :param frames: batch of frames from the stream loader
    :param result_maps: list aligned with frames; each item maps model_exec_id -> [DetectionResult]
    """
    for idx, frame in enumerate(frames):
        results_map = result_maps[idx]

        # Rate limit: once the interval gate closes it stays closed for the rest of
        # this batch, so returning (rather than continuing) is intentional.
        if not self.check_frame_interval():
            return

        # Persist the raw frame on the main loop (the DB session lives there).
        future = asyncio.run_coroutine_threadsafe(
            self.device_frame_service.add_frame(self.device.id, frame), self.main_loop
        )
        device_frame = future.result()
        frame_id = device_frame.id

        logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}')
        frame_results = []
        for model_exec_id, results in results_map.items():
            for r in results:
                frame_results.append(FrameAnalysisResultCreate(
                    device_id=self.device.id,
                    frame_id=frame_id,
                    algo_model_id=model_exec_id,
                    object_class_id=r.object_class_id,
                    object_class_name=r.object_class_name,
                    confidence=round(r.confidence, 4),
                    location=r.location,
                ))
        asyncio.run_coroutine_threadsafe(
            self.frame_analysis_result_service.add_frame_analysis_results(frame_results),
            self.main_loop
        )
        self.thread_pool.submit_task(self.push_frame_results, frame_results)


def log_fps(self, frame_count):
    """Accumulate the detected-frame count and log detection FPS roughly every 10s.

    Fix: FPS is now computed over the actual elapsed interval instead of a
    hard-coded 10.0s denominator — the first interval in particular was not
    10s long, so the first logged value was meaningless.
    """
    self.frames_detected += frame_count
    current_time = datetime.now()
    if self.fps_ts is None:
        # First sample: just start the interval, nothing to report yet.
        self.fps_ts = current_time
        return
    elapsed = (current_time - self.fps_ts).total_seconds()
    if elapsed >= 10:
        fps = self.frames_detected / elapsed
        self.frames_detected = 0
        logger.info(f"FPS (detect) for device {self.device.code}: {fps}")
        self.fps_ts = current_time


def run(self):
    """Main per-device loop: read frame batches, run every model, store and display results."""
    while not self.stream_loader.init:
        if self.__stop_event.is_set():
            break  # stop requested before the stream ever came up
        self.stream_loader.init_cap()
    for frames in self.stream_loader:
        if self.__stop_event.is_set():
            break
        if not frames:
            continue

        self.device_status_manager.set_status(device_id=self.device.id)

        result_maps = []  # one {model_exec_id: [DetectionResult]} per frame
        annotators = [Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人")
                      for frame in frames]

        for model_exec in self.model_exec_list:
            handle_task_name = model_exec.algo_model_info.handle_task
            handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
            handler_instance = handler_cls(model_exec)
            frames, results = handler_instance.run(frames, annotators)

            # Collect this model's detections per frame.
            for frame_idx, frame_result in enumerate(results):
                if len(result_maps) <= frame_idx:
                    result_maps.append({})
                result_maps[frame_idx].setdefault(model_exec.algo_model_id, []).extend(
                    DetectionResult.from_dict(box) for box in frame_result
                )

        # Result handling: persistence happens off-thread, display per annotator.
        self.thread_pool.submit_task(self.save_frame_results, frames, result_maps)
        for annotator in annotators:
            self.display_frame_manager.add_frame(self.device.id, annotator.result())

        self.log_fps(len(frames))
def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4):
    """Model manager: loads and warms up the models that are in use.

    :param model_service: service used to query models in use
    :param model_warm_up: number of warm-up inference rounds per model
    :param batch_size: batch size used to build the warm-up dummy inputs
    """
    self.model_service = model_service
    self.models: Dict[int, AlgoModelExec] = {}
    self.model_warm_up = model_warm_up
    # Fix: the batch_size parameter was silently ignored (hard-coded `self.batch_size=4`).
    self.batch_size = batch_size
def log_fps(self):
    """Log the stream-read FPS roughly every 10 seconds.

    Fix: FPS is computed over the real elapsed interval instead of a hard-coded
    10.0s denominator, and nothing is logged on the very first call (there is
    no interval yet).
    """
    current_time = datetime.now()
    if self.fps_ts is None:
        self.fps_ts = current_time
        return
    elapsed = (current_time - self.fps_ts).total_seconds()
    if elapsed >= 10:
        fps = self.frames_read / elapsed
        self.frames_read = 0
        logger.info(f"FPS (read) for device {self.camera_code}: {fps}")
        self.fps_ts = current_time


def __iter__(self):
    """The loader is its own (endless) iterator."""
    return self


def __next__(self):
    """Return the next batch of frames, or [] when fewer than batch_size are buffered.

    Never raises StopIteration: an empty list means "no full batch yet" and the
    caller is expected to skip it.
    """
    if self.frame_queue.qsize() < self.batch_size:
        return []
    batch_frames = []
    while not self.frame_queue.empty() and len(batch_frames) < self.batch_size:
        batch_frames.append(self.frame_queue.get())
    return batch_frames
def handle_sigterm(signum, frame):
    """SIGTERM handler: exit at once so the container supervisor can restart us."""
    print(f"Received signal: {signum}, shutting down gracefully...")
    os._exit(0)


# Register the handler as a module-level side effect, matching the original layout.
signal.signal(signal.SIGTERM, handle_sigterm)


def shutdown():
    """Terminate the process after a short delay.

    The one-second sleep gives the in-flight HTTP response time to flush.
    os._exit skips interpreter cleanup on purpose; the supervisor handles restart.
    """
    time.sleep(1)
    logger.info("Shutting down the application...")
    os._exit(0)
def format_bytes(data: bytes):
    """Render bytes as a printable '\\xAA'-style hex string."""
    return ''.join(f'\\x{byte:02X}' for byte in data)


def is_within_alert_range(bbox, range):
    """True when either bottom corner of *bbox* lies inside the polygon *range*.

    A range of None means "no fence configured" and always passes.
    (The parameter name shadows the builtin `range`; kept for caller compatibility.)
    """
    if range is None:
        return True
    bottom_left = [bbox[0], bbox[3]]
    bottom_right = [bbox[2], bbox[3]]
    return is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range)


def is_point_in_polygon(point, polygon):
    """Return True if *point* (x, y) lies inside *polygon* [(x1, y1), (x2, y2), ...]."""
    poly_path = mpath.Path(polygon)
    return poly_path.contains_point(point)


def intersection_area(bbox1, bbox2):
    """Overlap area of two (x1, y1, x2, y2) boxes; 0 when they do not intersect."""
    left = max(bbox1[0], bbox2[0])
    top = max(bbox1[1], bbox2[1])
    right = min(bbox1[2], bbox2[2])
    bottom = min(bbox1[3], bbox2[3])
    return max(0, right - left) * max(0, bottom - top)


def bbox_area(bbox):
    """Area of an (x1, y1, x2, y2) box."""
    x1, y1, x2, y2 = bbox
    return (x2 - x1) * (y2 - y1)


def get_person_head(person_bbox, head_bboxes):
    """Pick the head box that best belongs to *person_bbox*.

    A head qualifies when at least 80% of its own area overlaps the person box;
    among qualifying heads the largest overlap wins, ties broken by confidence.
    Returns None when no head qualifies.
    """
    best_head = None
    max_overlap = 0
    for candidate in head_bboxes:
        cand_bbox = candidate.xyxy.cpu().squeeze()
        overlap = intersection_area(person_bbox, cand_bbox)
        if overlap / bbox_area(cand_bbox) < 0.8:
            continue
        better = (
            best_head is None
            or overlap > max_overlap
            or (overlap == max_overlap and float(candidate.conf) > float(best_head.conf))
        )
        if better:
            best_head = candidate
            max_overlap = overlap
    return best_head


class HarmfulGasManager:
    """Thread-safe singleton cache of the latest gas reading per device and gas type."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Double-checked locking keeps instantiation thread safe.
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super(HarmfulGasManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard against re-running __init__ on the shared singleton instance.
        if not hasattr(self, "device_data"):
            self.device_data = {}
            self.lock = threading.Lock()

    def get_device_all_data(self, device_code):
        """All cached readings for a device, or None if unseen."""
        with self.lock:
            return self.device_data.get(device_code, None)

    def get_device_data(self, device_code, gas_type):
        """Latest reading for (device, gas type), or None."""
        with self.lock:
            per_device = self.device_data.get(device_code, None)
            return per_device.get(gas_type, None) if per_device else None

    def set_device_data(self, device_code, gas_type, gas_data):
        """Store the latest reading for (device, gas type)."""
        with self.lock:
            self.device_data.setdefault(device_code, {})[gas_type] = gas_data


def get_request(url, headers=None):
    """GET *url* and return the parsed JSON body; None on any request failure."""
    try:
        response = requests.get(url, headers=headers)
        return response.json()
    except requests.RequestException as e:
        logger.error(f"Failed to get data: {e}")
def __init__(self, model: AlgoModelExec):
    """Bind the executable model and cache its class-name lookup tables."""
    self.model = model
    self.model_names = model.algo_model_exec.names
    self.model_ids = list(self.model_names.keys())


def pre_process(self, frame):
    """Hook for subclasses; the base handler passes frames through untouched."""
    return frame


def model_inference(self, frames):
    """Run batched streaming prediction; return one Boxes collection per input frame."""
    predictions = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size,
                                                     save_txt=False, save=False,
                                                     verbose=True, stream=True)
    return [prediction.boxes for prediction in predictions]


def post_process(self, frames, model_results, annotators):
    """Convert raw boxes to result dicts per frame and draw labels on each annotator.

    :returns: list (per frame) of lists of result dicts with class id/name,
              confidence and normalized xyxy location string.
    """
    results = []
    for frame_idx, _frame in enumerate(frames):
        boxes = model_results[frame_idx]
        annotator = annotators[frame_idx]
        per_frame = [
            {
                'object_class_id': int(b.cls),
                'object_class_name': self.model_names[int(b.cls)],
                'confidence': float(b.conf),
                'location': ", ".join(f"{x:.6f}" for x in b.xyxyn.cpu().squeeze().tolist()),
            }
            for b in boxes
        ]
        results.append(per_frame)
        if annotator is not None:
            for b in boxes:
                annotator.box_label(b.xyxy.cpu().squeeze(),
                                    f"{self.model_names[int(b.cls)]} {float(b.conf):.2f}",
                                    color=colors(int(b.cls)),
                                    rotated=False)
    return results


def run(self, frames, annotators):
    """Template method: pre-process, batched inference, then post-process."""
    prepared = [self.pre_process(frame=frame) for frame in frames]
    boxes_per_frame = self.model_inference(frames=prepared)
    results = self.post_process(frames=frames, model_results=boxes_per_frame, annotators=annotators)
    return frames, results
class AlarmMessageCenter:
    """Prioritised alarm-message queue with per-category rate limiting.

    Dispatch rules:
    - messages are picked by alarmCategory priority, then newest-first / category_order;
    - each alarmCategory is sent at most once per `category_interval` seconds;
    - messages older than `retention_time` seconds are purged on every poll.
    """

    def __init__(self, device_id, main_loop, tcp_manager=None, message_send_interval=5,
                 category_interval=10, retention_time=30, category_priority=None):
        self.device_id = device_id
        self.tcp_manager = tcp_manager
        self.main_loop = main_loop

        self.queue = deque()
        self.last_sent_time = {}  # alarmCategory -> last send time (epoch seconds)
        self.lock = Lock()
        self.message_send_interval = message_send_interval  # poll/send period (s)
        self.category_interval = category_interval          # per-category cool-down (s)
        self.retention_time = retention_time                # max queue age (s)
        if category_priority:
            self.category_priority = category_priority
            self.auto_update_alarm_priority = False
        else:
            # Fix: initialise to an empty dict instead of None — _get_next_message
            # iterates self.category_priority.values() and crashed with
            # AttributeError before any message had ever been added.
            self.category_priority = {}
            self.auto_update_alarm_priority = True

    def add_message(self, message_ori):
        """Deep-copy and enqueue a message, stamping its enqueue time."""
        message = copy.deepcopy(message_ori)
        message['timestamp'] = int(time.time())  # enqueue time
        with self.lock:
            self.queue.append(message)

        # Dynamically rebuild the priority map from the categories seen so far.
        if self.auto_update_alarm_priority:
            alarm_category = message['alarmCategory']
            if alarm_category not in self.category_priority:
                unique_categories = sorted({msg['alarmCategory'] for msg in self.queue})
                self.category_priority = {cat: idx for idx, cat in enumerate(unique_categories)}
                print(f"更新优先级映射: {self.category_priority}")

    def _clean_old_messages(self):
        """Drop messages that have been queued longer than retention_time."""
        print(f'清理长期堆积的消息 (队列长度: {len(self.queue)})')
        now = time.time()
        with self.lock:
            self.queue = deque([msg for msg in self.queue if now - msg['timestamp'] <= self.retention_time])
        print(f'清理后的队列长度: {len(self.queue)}')

    def _get_next_message(self):
        """Pop the highest-priority message whose category is outside its cool-down.

        Returns None when no message is currently eligible.
        """
        now = time.time()
        with self.lock:
            for priority in sorted(self.category_priority.values()):
                found_valid_message = False  # this priority had messages, just none eligible
                for msg in sorted(
                        (m for m in self.queue if self.category_priority.get(m['alarmCategory'], 99) == priority),
                        key=lambda x: (-x['timestamp'], x['category_order'])
                ):
                    alarm_category = msg['alarmCategory']
                    if alarm_category not in self.last_sent_time or \
                            now - self.last_sent_time[alarm_category] > self.category_interval:
                        self.queue.remove(msg)
                        self.last_sent_time[alarm_category] = now
                        return msg  # first eligible message wins
                    found_valid_message = True
                if not found_valid_message:
                    continue
            return None

    def process_messages(self):
        """Blocking poll loop: clean stale messages, then send the next eligible one."""
        while True:
            time.sleep(self.message_send_interval)
            self._clean_old_messages()
            next_message = self._get_next_message()
            if next_message:
                self.send_message(next_message)

    def send_message(self, message):
        """Forward one alarm message to the device over TCP (fire and forget)."""
        print(f"发送报警消息: {message['alarmContent']} (类别: {message['alarmCategory']}, 时间: {message['timestamp']})")
        if self.tcp_manager:
            asyncio.run_coroutine_threadsafe(
                self.tcp_manager.send_message_to_device(device_id=self.device_id,
                                                        message=message['alarmSoundMessage'],
                                                        have_response=False),
                self.main_loop)


def image_to_base64(numpy_image, format='jpg'):
    """Encode a BGR numpy image to a base64 string (no data-URI prefix).

    :raises ValueError: when cv2 cannot encode the image.
    """
    success, encoded_image = cv2.imencode(f'.{format}', numpy_image)
    if not success:
        raise ValueError("Could not encode image")
    return base64.b64encode(encoded_image).decode('utf-8')
AlarmRecordCenter: + def __init__(self, save_interval=-1, main_loop=None): + """ + 初始化报警记录中心 + :param save_interval: 报警记录保存间隔(秒),如果 <= 0,则不保存报警记录 + """ + self.main_loop = main_loop + self.save_interval = save_interval + self.thread_pool = GlobalThreadPool() + self.global_config = GlobalConfig() + # self.upload_interval = upload_interval + self.device_alarm_upload_time = {} # key: device_code, value: {alarm_type: last upload time} + self.device_alarm_save_time = {} + + def need_alarm(self, device_code, alarm_dict): + """ + 是否需要报警 + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :return: 是否需要报警 + """ + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + return self.need_save_alarm_record(device_code, current_time, alarm_type) \ + or self.need_upload_alarm_record(device_code, current_time, alarm_type) + + def need_save_alarm_record(self, device_code, current_time, alarm_type): + if self.save_interval <= 0: + return False + + if device_code not in self.device_alarm_save_time: + self.device_alarm_save_time[device_code] = {} + last_save_time = self.device_alarm_save_time[device_code].get(alarm_type) + if last_save_time is None or (current_time - last_save_time) > self.save_interval: + return True + return False + + def need_upload_alarm_record(self, device_code, current_time, alarm_type): + push_config = self.global_config.get_alarm_push_config() + if not push_config or not push_config.push_url: + return False + + if device_code not in self.device_alarm_upload_time: + self.device_alarm_upload_time[device_code] = {} + last_upload_time = self.device_alarm_upload_time[device_code].get(alarm_type) + if last_upload_time is None or (current_time - last_upload_time) > push_config.upload_interval: + return True + return False + + async def save_record(self, alarm_data, alarm_np_img): + async for db in get_db(): + alarm_record_service = AlarmRecordService(db) + await alarm_record_service.add_alarm(alarm_data=alarm_data, alarm_np_img=alarm_np_img) + + def 
upload_alarm_record(self, device_code, alarm_dict, alarm_np_img=None, alarm_value=None): + """ + 上传报警记录 + :param alarm_value: + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :param alarm_np_img: 报警图片(np array) + """ + # 获取当前时间 + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + if self.need_save_alarm_record(device_code, current_time, alarm_type): + alarm_record_data = AlarmRecordCreate( + device_code=device_code, + device_id=100, + alarm_type=alarm_type, + alarm_content=alarm_dict['alarmContent'], + alarm_time=datetime.now(), + alarm_value=alarm_value if alarm_value else None, + ) + asyncio.run_coroutine_threadsafe( + self.save_record(alarm_record_data, alarm_np_img), + self.main_loop) + self.device_alarm_save_time[device_code][alarm_type] = current_time + + if self.need_upload_alarm_record(device_code, current_time, alarm_type): + alarm_record = { + 'devcode': device_code, + 'alarmTime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'alarmType': alarm_dict['alarmType'], + 'alarmContent': alarm_dict['alarmContent'], + } + if alarm_value: + alarm_record['alarmValue'] = alarm_value + if alarm_np_img: + alarm_record['alarmImage'] = image_to_base64(alarm_np_img) + + push_config = self.global_config.get_alarm_push_config() + self.thread_pool.submit_task(send_request, push_config.upload_url, alarm_record) + self.device_alarm_upload_time[device_code][alarm_type] = current_time diff --git a/scene_handler/block_scene_handler.py b/scene_handler/block_scene_handler.py new file mode 100644 index 0000000..3dd0b31 --- /dev/null +++ b/scene_handler/block_scene_handler.py @@ -0,0 +1,426 @@ +import time +import traceback +from asyncio import Event +from copy import deepcopy +from datetime import datetime + + +from flatbuffers.builder import np +from scipy.spatial import ConvexHull + +from algo.stream_loader import OpenCVStreamLoad +from common.detect_utils import is_within_alert_range, get_person_head, intersection_area, bbox_area +from 
common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.harmful_gas_manager import HarmfulGasManager +from common.image_plotting import Annotator +from entity.device import Device +from scene_handler.alarm_message_center import AlarmMessageCenter +from scene_handler.alarm_record_center import AlarmRecordCenter +from scene_handler.base_scene_handler import BaseSceneHandler +from scene_handler.limit_space_scene_handler import is_overlapping +from services.global_config import GlobalConfig +from tcp.tcp_manager import TcpManager + +from entity.device import Device +from common.http_utils import get_request +from ultralytics import YOLO + +''' +alarmCategory: +0 行为监管 +1 环境监管 +2 人员监管 +3 围栏监管 + +handelType: +0 检测到报警 +1 未检测到报警 +2 人未穿戴报警 +3 其他 +''' +ALARM_DICT = [ + { + 'alarmCategory': 0, + 'alarmType': '1', + 'handelType': 1, + 'category_order': 1, + 'class_idx': [34], + 'alarm_name': 'no_fire_extinguisher', + 'alarmContent': '未检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '2', + 'handelType': 1, + 'category_order': 2, + 'class_idx': [43], + 'alarm_name': 'no_barrier_tape', + 'alarmContent': '未检测到警戒线', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '3', + 'handelType': 1, + 'category_order': 3, + 'class_idx': [48], + 'alarm_name': 'no_cone', + 'alarmContent': '未检测到锥桶', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '4', + 'handelType': 1, + 'category_order': 4, + 'class_idx': [4, 5, 16], + 'alarm_name': 'no_board', + 'alarmContent': '未检测到指示牌', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '5', + 'handelType': 2, + 'category_order': -1, + 'class_idx': [18], + 'alarm_name': 
'no_helmet', + 'alarmContent': '未佩戴安全帽', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴安全帽', + }, + # todo 明火 + { + 'alarmCategory': 1, + 'alarmType': '7', + 'handelType': 3, + 'category_order': 1, + 'class_idx': [], + 'alarm_name': 'gas_alarm', + 'alarmContent': '甲烷浓度超限', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 1, + 'alarmType': '8', + 'handelType': 3, + 'category_order': 2, + 'class_idx': [], + 'alarm_name': 'harmful_alarm', + 'alarmContent': '有害气体浓度超标', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 2, + 'alarmType': '9', + 'handelType': 3, + 'category_order': -1, + 'class_idx': [], + 'alarm_name': 'health_alarm', + 'alarmContent': '心率血氧异常', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 3, + 'alarmType': '10', + 'handelType': 2, + 'category_order': 4, + 'class_idx': [24], + 'alarm_name': 'break_in_alarm', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '非法闯入', + }, + +] + +COLOR_RED = (0, 0, 255) +COLOR_BLUE = (255, 0, 0) + + +class BlockSceneHandler(BaseSceneHandler): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + self.__stop_event = Event(loop=main_loop) + self.health_ts_dict = {} + self.harmful_ts_dict = {} + self.object_ts_dict = {} + self.thread_pool = GlobalThreadPool() + + self.alarm_message_center = AlarmMessageCenter(device.id,main_loop=main_loop, tcp_manager=tcp_manager, + category_priority={2: 0, 1: 1, 3: 2, 0: 3}) + self.alarm_record_center = AlarmRecordCenter(save_interval=device.alarm_interval,main_loop=main_loop) + self.harmful_data_manager = HarmfulGasManager() + self.device_status_manager = DeviceStatusManager() + + + self.health_device_codes = 
['HWIH061000056395'] # todo + self.harmful_device_codes = [] # todo + + for helmet_code in self.health_device_codes: + self.thread_pool.submit_task(self.health_data_task, helmet_code) + for harmful_device_code in self.harmful_device_codes: + self.thread_pool.submit_task(self.harmful_data_task, harmful_device_code) + + self.thread_pool.submit_task(self.alarm_message_center.process_messages) + + # todo 明火 + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 48: '路锥', + 58: '鼓风机', + } + self.PERSON_CLASS_IDX = 3 + self.HEAD_CLASS_IDX = 15 + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) + self.range_points = range_points + self.abs_range_points = self.get_absolute_range() + + self.tracking_status = {} # 跟踪每个行人的状态 + self.max_missing_frames = 25 # 报警的阈值 + self.disappear_threshold = 25 * 3 # 移除行人的阈值 + + def get_absolute_range(self): + fence_info = eval(self.range_points) + if fence_info and len(fence_info) > 1: + abs_points = [] + for p in fence_info: + abs_points.append( + [int(p[0] * int(self.stream_loader.frame_width)), int(p[1] * int(self.stream_loader.frame_height))]) + + abs_points = np.array(abs_points, dtype=np.int32) + hull = ConvexHull(abs_points) + sorted_coordinates = abs_points[hull.vertices] + # abs_points = abs_points.reshape((-1, 1, 2)) + return sorted_coordinates + else: + return None + + def harmful_data_task(self, harmful_device_code): + while not self.__stop_event.is_set(): + harmful_gas_data = self.harmful_data_manager.get_device_all_data(harmful_device_code) + for gas_type, gas_data in harmful_gas_data.items(): + ts_key = f'{harmful_device_code}_{gas_type}' + last_ts = self.harmful_ts_dict.get(ts_key) + gas_ts = gas_data.get('gas_ts') + if last_ts is None or 
(gas_ts - last_ts).total_seconds() > 0: + self.harmful_ts_dict[ts_key] = gas_ts + self.handle_harmful_gas_alarm(harmful_device_code, gas_type, gas_data) + + def health_data_task(self, helmet_code): + while not self.__stop_event.is_set(): + header = { + 'ak': 'fe80b2f021644b1b8c77fda743a83670', + 'sk': '8771ea6e931d4db646a26f67bcb89909', + } + url = f'https://jls.huaweisoft.com//api/ih-log/v1.0/ih-api/helmetInfo/{helmet_code}' + response = get_request(url, headers=header) + if response and response.get('data'): + last_ts = self.health_ts_dict.get(helmet_code) + vitalsigns_data = response.get('data').get('vitalSignsData') + if vitalsigns_data: + upload_timestamp = datetime.strptime(vitalsigns_data.get('uploadTimestamp'), "%Y-%m-%d %H:%M:%S") + if last_ts is None or (upload_timestamp.timestamp() - last_ts) > 0: + self.health_ts_dict[helmet_code] = upload_timestamp.timestamp() + if time.time() - upload_timestamp.timestamp() < 10 * 60: # 10分钟以前的数据不做处理 + self.handle_health_alarm(helmet_code, vitalsigns_data.get('bloodOxygen'), + vitalsigns_data.get('heartRate'),upload_timestamp) + time.sleep(10) + + def handle_health_alarm(self, helmet_code, blood_oxygen, heartrate, upload_timestamp): + logger.debug(f'health_data: {helmet_code}, blood_oxygen = {blood_oxygen}, heartrate = {heartrate}, ' + f'upload_timestamp = {upload_timestamp}') + if heartrate < 60 or heartrate > 120 or blood_oxygen < 85: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 2] + if alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 需要往后台发原始数据吗 + + def handle_harmful_gas_alarm(self, device_code, gas_type, gas_data): + alarm = False + gas_value = gas_data['gas_value'] + if gas_type == 3: # h2s + alarm = gas_value > 120.0 + elif gas_type == 4: # co + alarm = gas_value > 10.0 + elif gas_type == 5: # o2 + alarm = gas_value < 15 + elif gas_type == 50: # ex + alarm = gas_value > 10 + + if alarm: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 1] + if 
alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 + + def model_predict(self, frames): + results_generator = self.model(frames, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) + + pred_ids = [[int(box.cls) for box in sublist] for sublist in result_boxes] + pred_names = [[self.model_classes[int(box.cls)] for box in sublist] for sublist in result_boxes] + return result_boxes, pred_ids, pred_names + + def handle_behave_alarm(self, frames, result_boxes, pred_ids, pred_names): + behave_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 0] + for alarm_dict in behave_alarm_dicts: + for idx, frame_boxes in enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + if alarm_dict['handelType'] == 0: # 检测到就报警 + if object_boxes: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + box_color = COLOR_RED if int(box.cls) in alarm_dict['class_idx'] else COLOR_BLUE + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=box_color, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + elif alarm_dict['handelType'] == 1: # 检测不到报警 + if object_boxes: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + else: + last_ts = self.object_ts_dict.get(alarm_dict['alarm_name'], 0) + if time.time() - last_ts > 5: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + self.alarm_message_center.add_message(alarm_dict) + if 
self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + elif alarm_dict['handelType'] == 2: # 人未穿戴报警 + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) + for helmet in object_boxes) + if not has_helmet: + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + for box in frame_boxes: + box_cls = box.cls + if box_cls != self.PERSON_CLASS_IDX and box_cls != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def handle_break_in_alarm(self, frames, result_boxes, pred_ids, pred_names): + break_in_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 3] + for alarm_dict in break_in_alarm_dicts: + for idx, frame_boxes in 
enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + person_id = person_box.id + if is_within_alert_range(person_bbox, self.abs_range_points): + has_object = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + overlap_ratio = intersection_area(person_bbox, person_head.xyxy.cpu().squeeze()) / bbox_area(person_bbox) + if overlap_ratio < 0.5: # 头占人<0.5,判断是否穿工服。不太准确 + has_object = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), object_boxe.xyxy.cpu().squeeze()) + for object_boxe in object_boxes) + if not has_object: + self.tracking_status[person_box.id] = self.tracking_status.get(person_box.id, 0) + 1 + + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() + for frames in self.stream_loader: + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + if not frames: + continue + + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frames) 
# 结果都是二维数组,对应batch中的每个frame + # print(pred_names) + self.handle_behave_alarm(frames, result_boxes, pred_ids, pred_names) + self.handle_break_in_alarm(frames, result_boxes, pred_ids, pred_names) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index d1c5b9a..a7f1674 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -145,7 +145,7 @@ class LimitSpaceSceneHandler(BaseSceneHandler): - def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) # self.device = device # self.thread_id = thread_id @@ -287,20 +287,20 @@ if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: try: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None: continue self.device_status_manager.set_status(device_id=self.device.id) - result_boxes, pred_ids, pred_names = self.model_predict(frame) + # result_boxes, pred_ids, pred_names = self.model_predict(frames) frame_alarm = {} - self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) - self.process_labor(frame, result_boxes, pred_ids, pred_names) + # self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + # self.process_labor(frame, result_boxes, pred_ids, pred_names) if len(frame_alarm.keys()) > 0: for key in frame_alarm.keys(): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner 
- -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - 
self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = 
Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, 
batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + 
batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), 
signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) 
yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/run.sh b/run.sh index bb36274..815883a 100644 --- a/run.sh +++ b/run.sh @@ -8,4 +8,5 @@ export LC_ALL=C.UTF-8 cd /code/safe-algo-pro -python3 main.py +exec python3 main.py +#tail -f /dev/null \ No newline at end of file diff --git a/scene_handler/alarm_message_center.py b/scene_handler/alarm_message_center.py new file mode 100644 index 0000000..fba1397 --- /dev/null +++ b/scene_handler/alarm_message_center.py @@ -0,0 +1,121 @@ +import asyncio +import copy +import time +from collections import deque +from threading import Thread, Lock + +''' +队列消息取出规则: +- 按 alarmCategory(优先级:2 > 1 > 3 > 0)和 category_order 从小到大排序。 +- 确保每个 alarmCategory 在 10 秒内只能发送一次。 + +消息定时清理: +- 每次取消息时会检查队列中长期堆积、超过一定时间(比如 30 秒)的消息,并清理这些消息。 +''' + + +class AlarmMessageCenter: + def __init__(self, device_id, main_loop, tcp_manager=None, message_send_interval=5, category_interval=10, retention_time=30, + category_priority=None): + self.device_id = device_id + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.queue = deque() + self.last_sent_time = {} # 记录每个 alarmCategory 的最后发送时间 + self.lock = Lock() + self.message_send_interval = message_send_interval # 消息发送间隔(秒) + self.category_interval = category_interval # 类别发送间隔(秒) + self.retention_time = retention_time # 消息最长保留时间(秒) + if category_priority: + self.category_priority = category_priority + self.auto_update_alarm_priority = False + else: + self.category_priority = category_priority + self.auto_update_alarm_priority = 
True + + def add_message(self, message_ori): + message = copy.deepcopy(message_ori) + message['timestamp'] = int(time.time()) # 添加消息放入队列的时间 + with self.lock: + self.queue.append(message) + + # 动态更新优先级映射 + if self.auto_update_alarm_priority: + alarm_category = message['alarmCategory'] + if alarm_category not in self.category_priority: + unique_categories = sorted({msg['alarmCategory'] for msg in self.queue}) + self.category_priority = {cat: idx for idx, cat in enumerate(unique_categories)} + print(f"更新优先级映射: {self.category_priority}") + + def _clean_old_messages(self): + """清理长期堆积的消息""" + print(f'清理长期堆积的消息 (队列长度: {len(self.queue)})') + now = time.time() + with self.lock: + self.queue = deque([msg for msg in self.queue if now - msg['timestamp'] <= self.retention_time]) + print(f'清理后的队列长度: {len(self.queue)}') + + # def _get_next_message(self): + # """按优先级和时间规则取出下一条消息""" + # now = time.time() + # with self.lock: + # # 按照规则排序队列 + # sorted_queue = sorted( + # self.queue, + # key=lambda x: ( + # self.alarm_priority.get(x['alarmCategory'], 4), # 按 alarmCategory 优先级排序 + # x['category_order'], # category_order 小的排前面 + # x['timestamp'] # 时间靠前的排前面 + # ) + # ) + # # 遍历排序后的队列,找到符合规则的消息 + # for msg in sorted_queue: + # alarm_category = msg['alarmCategory'] + # if alarm_category not in self.last_sent_time or now - self.last_sent_time[alarm_category] > 10: + # self.queue.remove(msg) + # self.last_sent_time[alarm_category] = now + # return msg + # return None + + def _get_next_message(self): + """按优先级和时间规则取出下一条消息""" + now = time.time() + with self.lock: + # 按优先级依次检查 + for priority in sorted(self.category_priority.values()): + found_valid_message = False # 用于标记当前优先级是否有消息被处理 + for msg in sorted( + (m for m in self.queue if self.category_priority.get(m['alarmCategory'], 99) == priority), + key=lambda x: (-x['timestamp'], x['category_order']) + ): + alarm_category = msg['alarmCategory'] + # 检查是否符合发送条件 + if alarm_category not in self.last_sent_time or now - self.last_sent_time[ + 
alarm_category] > self.category_interval: + self.queue.remove(msg) + self.last_sent_time[alarm_category] = now + return msg # 找到符合条件的消息立即返回 + found_valid_message = True # 当前优先级存在消息但不符合条件 + # 如果当前优先级的所有消息都被检查过且不符合条件,跳到下一个优先级 + if not found_valid_message: + continue + return None # 如果所有优先级都没有消息符合条件,则返回 None + + def process_messages(self): + while True: + time.sleep(self.message_send_interval) + self._clean_old_messages() # 清理长期堆积的消息 + next_message = self._get_next_message() + if next_message: + self.send_message(next_message) + + def send_message(self, message): + """发送报警消息""" + print(f"发送报警消息: {message['alarmContent']} (类别: {message['alarmCategory']}, 时间: {message['timestamp']})") + if self.tcp_manager: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device_id, + message=message['alarmSoundMessage'], + have_response=False), + self.main_loop) diff --git a/scene_handler/alarm_record_center.py b/scene_handler/alarm_record_center.py new file mode 100644 index 0000000..5373cd5 --- /dev/null +++ b/scene_handler/alarm_record_center.py @@ -0,0 +1,128 @@ +import asyncio +import base64 +import time +from datetime import datetime + +import cv2 + +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from db.database import get_db +from entity.alarm_record import AlarmRecordCreate +from services.alarm_record_service import AlarmRecordService +from services.global_config import GlobalConfig + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + # return f'data:image/{format};base64,{base64_message}' + + +class 
AlarmRecordCenter: + def __init__(self, save_interval=-1, main_loop=None): + """ + 初始化报警记录中心 + :param upload_interval: 报警上传间隔(秒),如果 <= 0,则不报警 + """ + self.main_loop = main_loop + self.save_interval = save_interval + self.thread_pool = GlobalThreadPool() + self.global_config = GlobalConfig() + # self.upload_interval = upload_interval + self.device_alarm_upload_time = {} # key: device_code, value: {alarm_type: last upload time} + self.device_alarm_save_time = {} + + def need_alarm(self, device_code, alarm_dict): + """ + 是否需要报警 + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :return: 是否需要报警 + """ + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + return self.need_save_alarm_record(device_code, current_time, alarm_type) \ + or self.need_upload_alarm_record(device_code, current_time, alarm_type) + + def need_save_alarm_record(self, device_code, current_time, alarm_type): + if self.save_interval <= 0: + return False + + if device_code not in self.device_alarm_save_time: + self.device_alarm_save_time[device_code] = {} + last_save_time = self.device_alarm_save_time[device_code].get(alarm_type) + if last_save_time is None or (current_time - last_save_time) > self.save_interval: + return True + return False + + def need_upload_alarm_record(self, device_code, current_time, alarm_type): + push_config = self.global_config.get_alarm_push_config() + if not push_config or not push_config.push_url: + return False + + if device_code not in self.device_alarm_upload_time: + self.device_alarm_upload_time[device_code] = {} + last_upload_time = self.device_alarm_upload_time[device_code].get(alarm_type) + if last_upload_time is None or (current_time - last_upload_time) > push_config.upload_interval: + return True + return False + + async def save_record(self, alarm_data, alarm_np_img): + async for db in get_db(): + alarm_record_service = AlarmRecordService(db) + await alarm_record_service.add_alarm(alarm_data=alarm_data, alarm_np_img=alarm_np_img) + + def 
upload_alarm_record(self, device_code, alarm_dict, alarm_np_img=None, alarm_value=None): + """ + 上传报警记录 + :param alarm_value: + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :param alarm_np_img: 报警图片(np array) + """ + # 获取当前时间 + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + if self.need_save_alarm_record(device_code, current_time, alarm_type): + alarm_record_data = AlarmRecordCreate( + device_code=device_code, + device_id=100, + alarm_type=alarm_type, + alarm_content=alarm_dict['alarmContent'], + alarm_time=datetime.now(), + alarm_value=alarm_value if alarm_value else None, + ) + asyncio.run_coroutine_threadsafe( + self.save_record(alarm_record_data, alarm_np_img), + self.main_loop) + self.device_alarm_save_time[device_code][alarm_type] = current_time + + if self.need_upload_alarm_record(device_code, current_time, alarm_type): + alarm_record = { + 'devcode': device_code, + 'alarmTime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'alarmType': alarm_dict['alarmType'], + 'alarmContent': alarm_dict['alarmContent'], + } + if alarm_value: + alarm_record['alarmValue'] = alarm_value + if alarm_np_img: + alarm_record['alarmImage'] = image_to_base64(alarm_np_img) + + push_config = self.global_config.get_alarm_push_config() + self.thread_pool.submit_task(send_request, push_config.upload_url, alarm_record) + self.device_alarm_upload_time[device_code][alarm_type] = current_time diff --git a/scene_handler/block_scene_handler.py b/scene_handler/block_scene_handler.py new file mode 100644 index 0000000..3dd0b31 --- /dev/null +++ b/scene_handler/block_scene_handler.py @@ -0,0 +1,426 @@ +import time +import traceback +from asyncio import Event +from copy import deepcopy +from datetime import datetime + + +from flatbuffers.builder import np +from scipy.spatial import ConvexHull + +from algo.stream_loader import OpenCVStreamLoad +from common.detect_utils import is_within_alert_range, get_person_head, intersection_area, bbox_area +from 
common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.harmful_gas_manager import HarmfulGasManager +from common.image_plotting import Annotator +from entity.device import Device +from scene_handler.alarm_message_center import AlarmMessageCenter +from scene_handler.alarm_record_center import AlarmRecordCenter +from scene_handler.base_scene_handler import BaseSceneHandler +from scene_handler.limit_space_scene_handler import is_overlapping +from services.global_config import GlobalConfig +from tcp.tcp_manager import TcpManager + +from entity.device import Device +from common.http_utils import get_request +from ultralytics import YOLO + +''' +alarmCategory: +0 行为监管 +1 环境监管 +2 人员监管 +3 围栏监管 + +handelType: +0 检测到报警 +1 未检测到报警 +2 人未穿戴报警 +3 其他 +''' +ALARM_DICT = [ + { + 'alarmCategory': 0, + 'alarmType': '1', + 'handelType': 1, + 'category_order': 1, + 'class_idx': [34], + 'alarm_name': 'no_fire_extinguisher', + 'alarmContent': '未检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '2', + 'handelType': 1, + 'category_order': 2, + 'class_idx': [43], + 'alarm_name': 'no_barrier_tape', + 'alarmContent': '未检测到警戒线', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '3', + 'handelType': 1, + 'category_order': 3, + 'class_idx': [48], + 'alarm_name': 'no_cone', + 'alarmContent': '未检测到锥桶', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '4', + 'handelType': 1, + 'category_order': 4, + 'class_idx': [4, 5, 16], + 'alarm_name': 'no_board', + 'alarmContent': '未检测到指示牌', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '5', + 'handelType': 2, + 'category_order': -1, + 'class_idx': [18], + 'alarm_name': 
'no_helmet', + 'alarmContent': '未佩戴安全帽', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴安全帽', + }, + # todo 明火 + { + 'alarmCategory': 1, + 'alarmType': '7', + 'handelType': 3, + 'category_order': 1, + 'class_idx': [], + 'alarm_name': 'gas_alarm', + 'alarmContent': '甲烷浓度超限', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 1, + 'alarmType': '8', + 'handelType': 3, + 'category_order': 2, + 'class_idx': [], + 'alarm_name': 'harmful_alarm', + 'alarmContent': '有害气体浓度超标', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 2, + 'alarmType': '9', + 'handelType': 3, + 'category_order': -1, + 'class_idx': [], + 'alarm_name': 'health_alarm', + 'alarmContent': '心率血氧异常', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 3, + 'alarmType': '10', + 'handelType': 2, + 'category_order': 4, + 'class_idx': [24], + 'alarm_name': 'break_in_alarm', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '非法闯入', + }, + +] + +COLOR_RED = (0, 0, 255) +COLOR_BLUE = (255, 0, 0) + + +class BlockSceneHandler(BaseSceneHandler): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + self.__stop_event = Event(loop=main_loop) + self.health_ts_dict = {} + self.harmful_ts_dict = {} + self.object_ts_dict = {} + self.thread_pool = GlobalThreadPool() + + self.alarm_message_center = AlarmMessageCenter(device.id,main_loop=main_loop, tcp_manager=tcp_manager, + category_priority={2: 0, 1: 1, 3: 2, 0: 3}) + self.alarm_record_center = AlarmRecordCenter(save_interval=device.alarm_interval,main_loop=main_loop) + self.harmful_data_manager = HarmfulGasManager() + self.device_status_manager = DeviceStatusManager() + + + self.health_device_codes = 
['HWIH061000056395'] # todo + self.harmful_device_codes = [] # todo + + for helmet_code in self.health_device_codes: + self.thread_pool.submit_task(self.health_data_task, helmet_code) + for harmful_device_code in self.harmful_device_codes: + self.thread_pool.submit_task(self.harmful_data_task, harmful_device_code) + + self.thread_pool.submit_task(self.alarm_message_center.process_messages) + + # todo 明火 + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 48: '路锥', + 58: '鼓风机', + } + self.PERSON_CLASS_IDX = 3 + self.HEAD_CLASS_IDX = 15 + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) + self.range_points = range_points + self.abs_range_points = self.get_absolute_range() + + self.tracking_status = {} # 跟踪每个行人的状态 + self.max_missing_frames = 25 # 报警的阈值 + self.disappear_threshold = 25 * 3 # 移除行人的阈值 + + def get_absolute_range(self): + fence_info = eval(self.range_points) + if fence_info and len(fence_info) > 1: + abs_points = [] + for p in fence_info: + abs_points.append( + [int(p[0] * int(self.stream_loader.frame_width)), int(p[1] * int(self.stream_loader.frame_height))]) + + abs_points = np.array(abs_points, dtype=np.int32) + hull = ConvexHull(abs_points) + sorted_coordinates = abs_points[hull.vertices] + # abs_points = abs_points.reshape((-1, 1, 2)) + return sorted_coordinates + else: + return None + + def harmful_data_task(self, harmful_device_code): + while not self.__stop_event.is_set(): + harmful_gas_data = self.harmful_data_manager.get_device_all_data(harmful_device_code) + for gas_type, gas_data in harmful_gas_data.items(): + ts_key = f'{harmful_device_code}_{gas_type}' + last_ts = self.harmful_ts_dict.get(ts_key) + gas_ts = gas_data.get('gas_ts') + if last_ts is None or 
(gas_ts - last_ts).total_seconds() > 0: + self.harmful_ts_dict[ts_key] = gas_ts + self.handle_harmful_gas_alarm(harmful_device_code, gas_type, gas_data) + + def health_data_task(self, helmet_code): + while not self.__stop_event.is_set(): + header = { + 'ak': 'fe80b2f021644b1b8c77fda743a83670', + 'sk': '8771ea6e931d4db646a26f67bcb89909', + } + url = f'https://jls.huaweisoft.com//api/ih-log/v1.0/ih-api/helmetInfo/{helmet_code}' + response = get_request(url, headers=header) + if response and response.get('data'): + last_ts = self.health_ts_dict.get(helmet_code) + vitalsigns_data = response.get('data').get('vitalSignsData') + if vitalsigns_data: + upload_timestamp = datetime.strptime(vitalsigns_data.get('uploadTimestamp'), "%Y-%m-%d %H:%M:%S") + if last_ts is None or (upload_timestamp.timestamp() - last_ts) > 0: + self.health_ts_dict[helmet_code] = upload_timestamp.timestamp() + if time.time() - upload_timestamp.timestamp() < 10 * 60: # 10分钟以前的数据不做处理 + self.handle_health_alarm(helmet_code, vitalsigns_data.get('bloodOxygen'), + vitalsigns_data.get('heartRate'),upload_timestamp) + time.sleep(10) + + def handle_health_alarm(self, helmet_code, blood_oxygen, heartrate, upload_timestamp): + logger.debug(f'health_data: {helmet_code}, blood_oxygen = {blood_oxygen}, heartrate = {heartrate}, ' + f'upload_timestamp = {upload_timestamp}') + if heartrate < 60 or heartrate > 120 or blood_oxygen < 85: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 2] + if alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 需要往后台发原始数据吗 + + def handle_harmful_gas_alarm(self, device_code, gas_type, gas_data): + alarm = False + gas_value = gas_data['gas_value'] + if gas_type == 3: # h2s + alarm = gas_value > 120.0 + elif gas_type == 4: # co + alarm = gas_value > 10.0 + elif gas_type == 5: # o2 + alarm = gas_value < 15 + elif gas_type == 50: # ex + alarm = gas_value > 10 + + if alarm: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 1] + if 
alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 + + def model_predict(self, frames): + results_generator = self.model(frames, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) + + pred_ids = [[int(box.cls) for box in sublist] for sublist in result_boxes] + pred_names = [[self.model_classes[int(box.cls)] for box in sublist] for sublist in result_boxes] + return result_boxes, pred_ids, pred_names + + def handle_behave_alarm(self, frames, result_boxes, pred_ids, pred_names): + behave_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 0] + for alarm_dict in behave_alarm_dicts: + for idx, frame_boxes in enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + if alarm_dict['handelType'] == 0: # 检测到就报警 + if object_boxes: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + box_color = COLOR_RED if int(box.cls) in alarm_dict['class_idx'] else COLOR_BLUE + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=box_color, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + elif alarm_dict['handelType'] == 1: # 检测不到报警 + if object_boxes: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + else: + last_ts = self.object_ts_dict.get(alarm_dict['alarm_name'], 0) + if time.time() - last_ts > 5: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + self.alarm_message_center.add_message(alarm_dict) + if 
self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + elif alarm_dict['handelType'] == 2: # 人未穿戴报警 + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) + for helmet in object_boxes) + if not has_helmet: + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + for box in frame_boxes: + box_cls = box.cls + if box_cls != self.PERSON_CLASS_IDX and box_cls != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def handle_break_in_alarm(self, frames, result_boxes, pred_ids, pred_names): + break_in_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 3] + for alarm_dict in break_in_alarm_dicts: + for idx, frame_boxes in 
enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + person_id = person_box.id + if is_within_alert_range(person_bbox, self.abs_range_points): + has_object = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + overlap_ratio = intersection_area(person_bbox, person_head.xyxy.cpu().squeeze()) / bbox_area(person_bbox) + if overlap_ratio < 0.5: # 头占人<0.5,判断是否穿工服。不太准确 + has_object = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), object_boxe.xyxy.cpu().squeeze()) + for object_boxe in object_boxes) + if not has_object: + self.tracking_status[person_box.id] = self.tracking_status.get(person_box.id, 0) + 1 + + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() + for frames in self.stream_loader: + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + if not frames: + continue + + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frames) 
# 结果都是二维数组,对应batch中的每个frame + # print(pred_names) + self.handle_behave_alarm(frames, result_boxes, pred_ids, pred_names) + self.handle_break_in_alarm(frames, result_boxes, pred_ids, pred_names) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index d1c5b9a..a7f1674 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -145,7 +145,7 @@ class LimitSpaceSceneHandler(BaseSceneHandler): - def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) # self.device = device # self.thread_id = thread_id @@ -287,20 +287,20 @@ if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: try: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None: continue self.device_status_manager.set_status(device_id=self.device.id) - result_boxes, pred_ids, pred_names = self.model_predict(frame) + # result_boxes, pred_ids, pred_names = self.model_predict(frames) frame_alarm = {} - self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) - self.process_labor(frame, result_boxes, pred_ids, pred_names) + # self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + # self.process_labor(frame, result_boxes, pred_ids, pred_names) if len(frame_alarm.keys()) > 0: for key in frame_alarm.keys(): diff --git a/services/alarm_record_service.py b/services/alarm_record_service.py new file mode 100644 index 0000000..a0a86f9 --- /dev/null +++ b/services/alarm_record_service.py @@ -0,0 +1,46 @@ +import os +import uuid +from 
datetime import datetime + +import aiofiles +import cv2 +from sqlalchemy.ext.asyncio import AsyncSession + +from entity.alarm_record import AlarmRecordCreate, AlarmRecord + + +class AlarmRecordService: + def __init__(self, db: AsyncSession): + self.db = db + + async def add_alarm(self, alarm_data:AlarmRecordCreate, alarm_np_img): + async def save_alarm_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/alarms', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', alarm_np_img) + image_data = encoded_image.tobytes() + + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + alarm_img_path = await save_alarm_file() + + # 创建并保存到数据库中 + alarm_record = AlarmRecord.model_validate(alarm_data) + alarm_record.alarm_image = alarm_img_path + self.db.add(alarm_record) + await self.db.commit() + await self.db.refresh(alarm_record) + return alarm_record diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 
self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + 
device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - 
results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, 
imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = 
cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() 
+ batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = 
await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + 
return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] + + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + 
self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/run.sh b/run.sh index bb36274..815883a 100644 --- a/run.sh +++ b/run.sh @@ -8,4 +8,5 @@ export LC_ALL=C.UTF-8 cd /code/safe-algo-pro -python3 main.py +exec python3 main.py +#tail -f /dev/null \ No newline at end of file diff --git a/scene_handler/alarm_message_center.py b/scene_handler/alarm_message_center.py new file mode 100644 index 0000000..fba1397 --- /dev/null +++ b/scene_handler/alarm_message_center.py @@ -0,0 +1,121 @@ +import asyncio +import copy +import time +from collections import deque +from threading import Thread, Lock + +''' +队列消息取出规则: +- 按 alarmCategory(优先级:2 > 1 > 3 > 0)和 category_order 从小到大排序。 +- 确保每个 alarmCategory 在 10 秒内只能发送一次。 + +消息定时清理: +- 每次取消息时会检查队列中长期堆积、超过一定时间(比如 30 秒)的消息,并清理这些消息。 +''' + + +class AlarmMessageCenter: + def __init__(self, device_id, main_loop, tcp_manager=None, message_send_interval=5, category_interval=10, retention_time=30, + category_priority=None): + self.device_id = device_id + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.queue = deque() + self.last_sent_time = {} # 记录每个 alarmCategory 的最后发送时间 + self.lock = Lock() + self.message_send_interval = message_send_interval # 消息发送间隔(秒) + self.category_interval = category_interval # 类别发送间隔(秒) + self.retention_time = retention_time # 消息最长保留时间(秒) + if category_priority: + self.category_priority = category_priority + self.auto_update_alarm_priority = False + else: + self.category_priority = category_priority + self.auto_update_alarm_priority = 
True + + def add_message(self, message_ori): + message = copy.deepcopy(message_ori) + message['timestamp'] = int(time.time()) # 添加消息放入队列的时间 + with self.lock: + self.queue.append(message) + + # 动态更新优先级映射 + if self.auto_update_alarm_priority: + alarm_category = message['alarmCategory'] + if alarm_category not in self.category_priority: + unique_categories = sorted({msg['alarmCategory'] for msg in self.queue}) + self.category_priority = {cat: idx for idx, cat in enumerate(unique_categories)} + print(f"更新优先级映射: {self.category_priority}") + + def _clean_old_messages(self): + """清理长期堆积的消息""" + print(f'清理长期堆积的消息 (队列长度: {len(self.queue)})') + now = time.time() + with self.lock: + self.queue = deque([msg for msg in self.queue if now - msg['timestamp'] <= self.retention_time]) + print(f'清理后的队列长度: {len(self.queue)}') + + # def _get_next_message(self): + # """按优先级和时间规则取出下一条消息""" + # now = time.time() + # with self.lock: + # # 按照规则排序队列 + # sorted_queue = sorted( + # self.queue, + # key=lambda x: ( + # self.alarm_priority.get(x['alarmCategory'], 4), # 按 alarmCategory 优先级排序 + # x['category_order'], # category_order 小的排前面 + # x['timestamp'] # 时间靠前的排前面 + # ) + # ) + # # 遍历排序后的队列,找到符合规则的消息 + # for msg in sorted_queue: + # alarm_category = msg['alarmCategory'] + # if alarm_category not in self.last_sent_time or now - self.last_sent_time[alarm_category] > 10: + # self.queue.remove(msg) + # self.last_sent_time[alarm_category] = now + # return msg + # return None + + def _get_next_message(self): + """按优先级和时间规则取出下一条消息""" + now = time.time() + with self.lock: + # 按优先级依次检查 + for priority in sorted(self.category_priority.values()): + found_valid_message = False # 用于标记当前优先级是否有消息被处理 + for msg in sorted( + (m for m in self.queue if self.category_priority.get(m['alarmCategory'], 99) == priority), + key=lambda x: (-x['timestamp'], x['category_order']) + ): + alarm_category = msg['alarmCategory'] + # 检查是否符合发送条件 + if alarm_category not in self.last_sent_time or now - self.last_sent_time[ + 
alarm_category] > self.category_interval: + self.queue.remove(msg) + self.last_sent_time[alarm_category] = now + return msg # 找到符合条件的消息立即返回 + found_valid_message = True # 当前优先级存在消息但不符合条件 + # 如果当前优先级的所有消息都被检查过且不符合条件,跳到下一个优先级 + if not found_valid_message: + continue + return None # 如果所有优先级都没有消息符合条件,则返回 None + + def process_messages(self): + while True: + time.sleep(self.message_send_interval) + self._clean_old_messages() # 清理长期堆积的消息 + next_message = self._get_next_message() + if next_message: + self.send_message(next_message) + + def send_message(self, message): + """发送报警消息""" + print(f"发送报警消息: {message['alarmContent']} (类别: {message['alarmCategory']}, 时间: {message['timestamp']})") + if self.tcp_manager: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device_id, + message=message['alarmSoundMessage'], + have_response=False), + self.main_loop) diff --git a/scene_handler/alarm_record_center.py b/scene_handler/alarm_record_center.py new file mode 100644 index 0000000..5373cd5 --- /dev/null +++ b/scene_handler/alarm_record_center.py @@ -0,0 +1,128 @@ +import asyncio +import base64 +import time +from datetime import datetime + +import cv2 + +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from db.database import get_db +from entity.alarm_record import AlarmRecordCreate +from services.alarm_record_service import AlarmRecordService +from services.global_config import GlobalConfig + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + # return f'data:image/{format};base64,{base64_message}' + + +class 
class AlarmRecordCenter:
    """Central dispatcher that persists and uploads alarm records.

    Two independent rate limits are kept per (device, alarm type): one for
    saving records to the local database (``save_interval``) and one for
    uploading them to the configured push URL (interval taken from the
    global alarm-push configuration).
    """

    def __init__(self, save_interval=-1, main_loop=None):
        """
        Initialize the alarm record center.

        :param save_interval: minimum seconds between two saved records of
            the same alarm type for one device; if <= 0, nothing is saved
        :param main_loop: asyncio event loop used to schedule async DB writes
        """
        self.main_loop = main_loop
        self.save_interval = save_interval
        self.thread_pool = GlobalThreadPool()
        self.global_config = GlobalConfig()
        # self.upload_interval = upload_interval
        # key: device_code, value: {alarm_type: last upload timestamp}
        self.device_alarm_upload_time = {}
        # key: device_code, value: {alarm_type: last save timestamp}
        self.device_alarm_save_time = {}

    def need_alarm(self, device_code, alarm_dict):
        """
        Whether this alarm should be recorded (saved and/or uploaded).

        :param device_code: device code
        :param alarm_dict: alarm descriptor dict (must contain 'alarmType')
        :return: True if either the save or the upload rate limit has expired
        """
        current_time = time.time()
        alarm_type = alarm_dict['alarmType']

        return self.need_save_alarm_record(device_code, current_time, alarm_type) \
            or self.need_upload_alarm_record(device_code, current_time, alarm_type)

    def need_save_alarm_record(self, device_code, current_time, alarm_type):
        # Local persistence is disabled entirely for a non-positive interval.
        if self.save_interval <= 0:
            return False

        if device_code not in self.device_alarm_save_time:
            self.device_alarm_save_time[device_code] = {}
        last_save_time = self.device_alarm_save_time[device_code].get(alarm_type)
        if last_save_time is None or (current_time - last_save_time) > self.save_interval:
            return True
        return False

    def need_upload_alarm_record(self, device_code, current_time, alarm_type):
        # Uploading requires a configured push URL.
        push_config = self.global_config.get_alarm_push_config()
        if not push_config or not push_config.push_url:
            return False

        if device_code not in self.device_alarm_upload_time:
            self.device_alarm_upload_time[device_code] = {}
        last_upload_time = self.device_alarm_upload_time[device_code].get(alarm_type)
        if last_upload_time is None or (current_time - last_upload_time) > push_config.upload_interval:
            return True
        return False

    async def save_record(self, alarm_data, alarm_np_img):
        # Open a short-lived DB session and delegate persistence to the service.
        async for db in get_db():
            alarm_record_service = AlarmRecordService(db)
            await alarm_record_service.add_alarm(alarm_data=alarm_data, alarm_np_img=alarm_np_img)

    def upload_alarm_record(self, device_code, alarm_dict, alarm_np_img=None, alarm_value=None):
        """
        Save and/or upload an alarm record, honouring both rate limits.

        :param alarm_value: optional measured value that triggered the alarm
        :param device_code: device code
        :param alarm_dict: alarm descriptor dict
        :param alarm_np_img: alarm snapshot image (NumPy array) or None
        """
        current_time = time.time()
        alarm_type = alarm_dict['alarmType']

        if self.need_save_alarm_record(device_code, current_time, alarm_type):
            alarm_record_data = AlarmRecordCreate(
                device_code=device_code,
                device_id=100,  # TODO: hard-coded device id — replace with the real device id
                alarm_type=alarm_type,
                alarm_content=alarm_dict['alarmContent'],
                alarm_time=datetime.now(),
                alarm_value=alarm_value if alarm_value else None,
            )
            asyncio.run_coroutine_threadsafe(
                self.save_record(alarm_record_data, alarm_np_img),
                self.main_loop)
            self.device_alarm_save_time[device_code][alarm_type] = current_time

        if self.need_upload_alarm_record(device_code, current_time, alarm_type):
            alarm_record = {
                'devcode': device_code,
                'alarmTime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                'alarmType': alarm_dict['alarmType'],
                'alarmContent': alarm_dict['alarmContent'],
            }
            if alarm_value:
                alarm_record['alarmValue'] = alarm_value
            # BUGFIX: `alarm_np_img` is a NumPy array; a bare truth test on a
            # multi-element array raises ValueError, so compare against None.
            if alarm_np_img is not None:
                alarm_record['alarmImage'] = image_to_base64(alarm_np_img)

            push_config = self.global_config.get_alarm_push_config()
            self.thread_pool.submit_task(send_request, push_config.upload_url, alarm_record)
            self.device_alarm_upload_time[device_code][alarm_type] = current_time
common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.harmful_gas_manager import HarmfulGasManager +from common.image_plotting import Annotator +from entity.device import Device +from scene_handler.alarm_message_center import AlarmMessageCenter +from scene_handler.alarm_record_center import AlarmRecordCenter +from scene_handler.base_scene_handler import BaseSceneHandler +from scene_handler.limit_space_scene_handler import is_overlapping +from services.global_config import GlobalConfig +from tcp.tcp_manager import TcpManager + +from entity.device import Device +from common.http_utils import get_request +from ultralytics import YOLO + +''' +alarmCategory: +0 行为监管 +1 环境监管 +2 人员监管 +3 围栏监管 + +handelType: +0 检测到报警 +1 未检测到报警 +2 人未穿戴报警 +3 其他 +''' +ALARM_DICT = [ + { + 'alarmCategory': 0, + 'alarmType': '1', + 'handelType': 1, + 'category_order': 1, + 'class_idx': [34], + 'alarm_name': 'no_fire_extinguisher', + 'alarmContent': '未检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '2', + 'handelType': 1, + 'category_order': 2, + 'class_idx': [43], + 'alarm_name': 'no_barrier_tape', + 'alarmContent': '未检测到警戒线', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '3', + 'handelType': 1, + 'category_order': 3, + 'class_idx': [48], + 'alarm_name': 'no_cone', + 'alarmContent': '未检测到锥桶', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '4', + 'handelType': 1, + 'category_order': 4, + 'class_idx': [4, 5, 16], + 'alarm_name': 'no_board', + 'alarmContent': '未检测到指示牌', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '5', + 'handelType': 2, + 'category_order': -1, + 'class_idx': [18], + 'alarm_name': 
'no_helmet', + 'alarmContent': '未佩戴安全帽', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴安全帽', + }, + # todo 明火 + { + 'alarmCategory': 1, + 'alarmType': '7', + 'handelType': 3, + 'category_order': 1, + 'class_idx': [], + 'alarm_name': 'gas_alarm', + 'alarmContent': '甲烷浓度超限', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 1, + 'alarmType': '8', + 'handelType': 3, + 'category_order': 2, + 'class_idx': [], + 'alarm_name': 'harmful_alarm', + 'alarmContent': '有害气体浓度超标', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 2, + 'alarmType': '9', + 'handelType': 3, + 'category_order': -1, + 'class_idx': [], + 'alarm_name': 'health_alarm', + 'alarmContent': '心率血氧异常', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 3, + 'alarmType': '10', + 'handelType': 2, + 'category_order': 4, + 'class_idx': [24], + 'alarm_name': 'break_in_alarm', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '非法闯入', + }, + +] + +COLOR_RED = (0, 0, 255) +COLOR_BLUE = (255, 0, 0) + + +class BlockSceneHandler(BaseSceneHandler): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + self.__stop_event = Event(loop=main_loop) + self.health_ts_dict = {} + self.harmful_ts_dict = {} + self.object_ts_dict = {} + self.thread_pool = GlobalThreadPool() + + self.alarm_message_center = AlarmMessageCenter(device.id,main_loop=main_loop, tcp_manager=tcp_manager, + category_priority={2: 0, 1: 1, 3: 2, 0: 3}) + self.alarm_record_center = AlarmRecordCenter(save_interval=device.alarm_interval,main_loop=main_loop) + self.harmful_data_manager = HarmfulGasManager() + self.device_status_manager = DeviceStatusManager() + + + self.health_device_codes = 
def get_absolute_range(self):
    """Convert the normalized fence points into absolute pixel coordinates.

    ``self.range_points`` is a string holding a list of (x, y) pairs in
    [0, 1] relative coordinates. Each point is scaled by the stream frame
    size and the set is reduced to its convex hull, ordered along the hull.

    :return: np.ndarray (int32) of hull vertices in pixel coordinates, or
        None when fewer than two points are configured
    """
    import ast  # local import: stdlib, used only here

    # SECURITY: the fence string arrives from external configuration;
    # ast.literal_eval only parses Python literals, unlike eval which
    # would execute arbitrary code.
    fence_info = ast.literal_eval(self.range_points)
    if fence_info and len(fence_info) > 1:
        # Hoist the frame-size conversions out of the loop.
        width = int(self.stream_loader.frame_width)
        height = int(self.stream_loader.frame_height)
        abs_points = np.array(
            [[int(p[0] * width), int(p[1] * height)] for p in fence_info],
            dtype=np.int32,
        )
        hull = ConvexHull(abs_points)
        sorted_coordinates = abs_points[hull.vertices]
        # abs_points = abs_points.reshape((-1, 1, 2))
        return sorted_coordinates
    else:
        return None
(gas_ts - last_ts).total_seconds() > 0: + self.harmful_ts_dict[ts_key] = gas_ts + self.handle_harmful_gas_alarm(harmful_device_code, gas_type, gas_data) + + def health_data_task(self, helmet_code): + while not self.__stop_event.is_set(): + header = { + 'ak': 'fe80b2f021644b1b8c77fda743a83670', + 'sk': '8771ea6e931d4db646a26f67bcb89909', + } + url = f'https://jls.huaweisoft.com//api/ih-log/v1.0/ih-api/helmetInfo/{helmet_code}' + response = get_request(url, headers=header) + if response and response.get('data'): + last_ts = self.health_ts_dict.get(helmet_code) + vitalsigns_data = response.get('data').get('vitalSignsData') + if vitalsigns_data: + upload_timestamp = datetime.strptime(vitalsigns_data.get('uploadTimestamp'), "%Y-%m-%d %H:%M:%S") + if last_ts is None or (upload_timestamp.timestamp() - last_ts) > 0: + self.health_ts_dict[helmet_code] = upload_timestamp.timestamp() + if time.time() - upload_timestamp.timestamp() < 10 * 60: # 10分钟以前的数据不做处理 + self.handle_health_alarm(helmet_code, vitalsigns_data.get('bloodOxygen'), + vitalsigns_data.get('heartRate'),upload_timestamp) + time.sleep(10) + + def handle_health_alarm(self, helmet_code, blood_oxygen, heartrate, upload_timestamp): + logger.debug(f'health_data: {helmet_code}, blood_oxygen = {blood_oxygen}, heartrate = {heartrate}, ' + f'upload_timestamp = {upload_timestamp}') + if heartrate < 60 or heartrate > 120 or blood_oxygen < 85: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 2] + if alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 需要往后台发原始数据吗 + + def handle_harmful_gas_alarm(self, device_code, gas_type, gas_data): + alarm = False + gas_value = gas_data['gas_value'] + if gas_type == 3: # h2s + alarm = gas_value > 120.0 + elif gas_type == 4: # co + alarm = gas_value > 10.0 + elif gas_type == 5: # o2 + alarm = gas_value < 15 + elif gas_type == 50: # ex + alarm = gas_value > 10 + + if alarm: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 1] + if 
alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 + + def model_predict(self, frames): + results_generator = self.model(frames, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) + + pred_ids = [[int(box.cls) for box in sublist] for sublist in result_boxes] + pred_names = [[self.model_classes[int(box.cls)] for box in sublist] for sublist in result_boxes] + return result_boxes, pred_ids, pred_names + + def handle_behave_alarm(self, frames, result_boxes, pred_ids, pred_names): + behave_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 0] + for alarm_dict in behave_alarm_dicts: + for idx, frame_boxes in enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + if alarm_dict['handelType'] == 0: # 检测到就报警 + if object_boxes: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + box_color = COLOR_RED if int(box.cls) in alarm_dict['class_idx'] else COLOR_BLUE + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=box_color, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + elif alarm_dict['handelType'] == 1: # 检测不到报警 + if object_boxes: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + else: + last_ts = self.object_ts_dict.get(alarm_dict['alarm_name'], 0) + if time.time() - last_ts > 5: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + self.alarm_message_center.add_message(alarm_dict) + if 
self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + elif alarm_dict['handelType'] == 2: # 人未穿戴报警 + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) + for helmet in object_boxes) + if not has_helmet: + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + for box in frame_boxes: + box_cls = box.cls + if box_cls != self.PERSON_CLASS_IDX and box_cls != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def handle_break_in_alarm(self, frames, result_boxes, pred_ids, pred_names): + break_in_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 3] + for alarm_dict in break_in_alarm_dicts: + for idx, frame_boxes in 
enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + person_id = person_box.id + if is_within_alert_range(person_bbox, self.abs_range_points): + has_object = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + overlap_ratio = intersection_area(person_bbox, person_head.xyxy.cpu().squeeze()) / bbox_area(person_bbox) + if overlap_ratio < 0.5: # 头占人<0.5,判断是否穿工服。不太准确 + has_object = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), object_boxe.xyxy.cpu().squeeze()) + for object_boxe in object_boxes) + if not has_object: + self.tracking_status[person_box.id] = self.tracking_status.get(person_box.id, 0) + 1 + + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() + for frames in self.stream_loader: + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + if not frames: + continue + + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frames) 
# 结果都是二维数组,对应batch中的每个frame + # print(pred_names) + self.handle_behave_alarm(frames, result_boxes, pred_ids, pred_names) + self.handle_break_in_alarm(frames, result_boxes, pred_ids, pred_names) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index d1c5b9a..a7f1674 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -145,7 +145,7 @@ class LimitSpaceSceneHandler(BaseSceneHandler): - def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) # self.device = device # self.thread_id = thread_id @@ -287,20 +287,20 @@ if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: try: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None: continue self.device_status_manager.set_status(device_id=self.device.id) - result_boxes, pred_ids, pred_names = self.model_predict(frame) + # result_boxes, pred_ids, pred_names = self.model_predict(frames) frame_alarm = {} - self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) - self.process_labor(frame, result_boxes, pred_ids, pred_names) + # self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + # self.process_labor(frame, result_boxes, pred_ids, pred_names) if len(frame_alarm.keys()) > 0: for key in frame_alarm.keys(): diff --git a/services/alarm_record_service.py b/services/alarm_record_service.py new file mode 100644 index 0000000..a0a86f9 --- /dev/null +++ b/services/alarm_record_service.py @@ -0,0 +1,46 @@ +import os +import uuid +from 
class AlarmRecordService:
    """Persists alarm records together with their snapshot images."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def add_alarm(self, alarm_data: AlarmRecordCreate, alarm_np_img):
        """Store the snapshot on disk, then insert the alarm row.

        :param alarm_data: validated alarm payload
        :param alarm_np_img: snapshot image as a NumPy array
        :return: the refreshed AlarmRecord ORM object
        """
        async def save_alarm_file():
            # Group images by day; a random UUID keeps file names collision-free.
            day_dir = datetime.now().strftime('%Y-%m-%d')
            file_name = f"{uuid.uuid4()}.jpeg"
            save_path = os.path.join('./storage/alarms', day_dir, file_name)
            # Create the directory if it does not exist yet.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            # Encode the NumPy image into JPEG bytes.
            _, encoded_image = cv2.imencode('.jpeg', alarm_np_img)
            image_data = encoded_image.tobytes()

            # Asynchronous write keeps the event loop free of disk I/O.
            async with aiofiles.open(save_path, 'wb') as f:
                await f.write(image_data)

            return save_path

        # Save the image file first, then reference its path from the row.
        alarm_img_path = await save_alarm_file()

        record = AlarmRecord.model_validate(alarm_data)
        record.alarm_image = alarm_img_path
        self.db.add(record)
        await self.db.commit()
        await self.db.refresh(record)
        return record
return scene_info - async def add_relation_by_device(self, device_id: int, scene_id: int): - new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) + async def add_relation_by_device(self, device_id: int, scene_id: int, range_points: str = None): + new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id, range_points=range_points) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) @@ -64,8 +65,8 @@ await self.db.commit() return result.rowcount - async def update_relation_by_device(self, device_id: int, scene_id: int): + async def update_relation_by_device(self, device_id: int, scene_id: int, range_points: str = None): await self.delete_relation_by_device(device_id) - new_relation = await self.add_relation_by_device(device_id, scene_id) + new_relation = await self.add_relation_by_device(device_id, scene_id, range_points) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 
@@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + 
location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + 
if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model 
{model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 
100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py 
b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/run.sh b/run.sh index bb36274..815883a 100644 --- a/run.sh +++ b/run.sh @@ -8,4 +8,5 @@ export LC_ALL=C.UTF-8 cd /code/safe-algo-pro -python3 main.py +exec python3 main.py +#tail -f /dev/null \ No newline at end of file diff --git a/scene_handler/alarm_message_center.py b/scene_handler/alarm_message_center.py new file mode 100644 index 0000000..fba1397 --- /dev/null +++ b/scene_handler/alarm_message_center.py @@ -0,0 +1,121 @@ +import asyncio +import copy +import time +from collections import deque +from threading import Thread, Lock + +''' +队列消息取出规则: +- 按 alarmCategory(优先级:2 > 1 > 3 > 0)和 category_order 从小到大排序。 +- 确保每个 alarmCategory 在 10 秒内只能发送一次。 + +消息定时清理: +- 每次取消息时会检查队列中长期堆积、超过一定时间(比如 30 秒)的消息,并清理这些消息。 +''' + + +class AlarmMessageCenter: + def __init__(self, device_id, main_loop, tcp_manager=None, message_send_interval=5, category_interval=10, retention_time=30, + category_priority=None): + self.device_id = device_id + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.queue = deque() + self.last_sent_time = {} # 记录每个 alarmCategory 的最后发送时间 + self.lock = Lock() + self.message_send_interval = message_send_interval # 消息发送间隔(秒) + self.category_interval = category_interval # 类别发送间隔(秒) + self.retention_time = retention_time # 消息最长保留时间(秒) + if category_priority: + self.category_priority = category_priority + self.auto_update_alarm_priority = False + else: + self.category_priority = category_priority + self.auto_update_alarm_priority = 
True + + def add_message(self, message_ori): + message = copy.deepcopy(message_ori) + message['timestamp'] = int(time.time()) # 添加消息放入队列的时间 + with self.lock: + self.queue.append(message) + + # 动态更新优先级映射 + if self.auto_update_alarm_priority: + alarm_category = message['alarmCategory'] + if alarm_category not in self.category_priority: + unique_categories = sorted({msg['alarmCategory'] for msg in self.queue}) + self.category_priority = {cat: idx for idx, cat in enumerate(unique_categories)} + print(f"更新优先级映射: {self.category_priority}") + + def _clean_old_messages(self): + """清理长期堆积的消息""" + print(f'清理长期堆积的消息 (队列长度: {len(self.queue)})') + now = time.time() + with self.lock: + self.queue = deque([msg for msg in self.queue if now - msg['timestamp'] <= self.retention_time]) + print(f'清理后的队列长度: {len(self.queue)}') + + # def _get_next_message(self): + # """按优先级和时间规则取出下一条消息""" + # now = time.time() + # with self.lock: + # # 按照规则排序队列 + # sorted_queue = sorted( + # self.queue, + # key=lambda x: ( + # self.alarm_priority.get(x['alarmCategory'], 4), # 按 alarmCategory 优先级排序 + # x['category_order'], # category_order 小的排前面 + # x['timestamp'] # 时间靠前的排前面 + # ) + # ) + # # 遍历排序后的队列,找到符合规则的消息 + # for msg in sorted_queue: + # alarm_category = msg['alarmCategory'] + # if alarm_category not in self.last_sent_time or now - self.last_sent_time[alarm_category] > 10: + # self.queue.remove(msg) + # self.last_sent_time[alarm_category] = now + # return msg + # return None + + def _get_next_message(self): + """按优先级和时间规则取出下一条消息""" + now = time.time() + with self.lock: + # 按优先级依次检查 + for priority in sorted(self.category_priority.values()): + found_valid_message = False # 用于标记当前优先级是否有消息被处理 + for msg in sorted( + (m for m in self.queue if self.category_priority.get(m['alarmCategory'], 99) == priority), + key=lambda x: (-x['timestamp'], x['category_order']) + ): + alarm_category = msg['alarmCategory'] + # 检查是否符合发送条件 + if alarm_category not in self.last_sent_time or now - self.last_sent_time[ + 
alarm_category] > self.category_interval: + self.queue.remove(msg) + self.last_sent_time[alarm_category] = now + return msg # 找到符合条件的消息立即返回 + found_valid_message = True # 当前优先级存在消息但不符合条件 + # 如果当前优先级的所有消息都被检查过且不符合条件,跳到下一个优先级 + if not found_valid_message: + continue + return None # 如果所有优先级都没有消息符合条件,则返回 None + + def process_messages(self): + while True: + time.sleep(self.message_send_interval) + self._clean_old_messages() # 清理长期堆积的消息 + next_message = self._get_next_message() + if next_message: + self.send_message(next_message) + + def send_message(self, message): + """发送报警消息""" + print(f"发送报警消息: {message['alarmContent']} (类别: {message['alarmCategory']}, 时间: {message['timestamp']})") + if self.tcp_manager: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device_id, + message=message['alarmSoundMessage'], + have_response=False), + self.main_loop) diff --git a/scene_handler/alarm_record_center.py b/scene_handler/alarm_record_center.py new file mode 100644 index 0000000..5373cd5 --- /dev/null +++ b/scene_handler/alarm_record_center.py @@ -0,0 +1,128 @@ +import asyncio +import base64 +import time +from datetime import datetime + +import cv2 + +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from db.database import get_db +from entity.alarm_record import AlarmRecordCreate +from services.alarm_record_service import AlarmRecordService +from services.global_config import GlobalConfig + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + # return f'data:image/{format};base64,{base64_message}' + + +class 
AlarmRecordCenter: + def __init__(self, save_interval=-1, main_loop=None): + """ + 初始化报警记录中心 + :param upload_interval: 报警上传间隔(秒),如果 <= 0,则不报警 + """ + self.main_loop = main_loop + self.save_interval = save_interval + self.thread_pool = GlobalThreadPool() + self.global_config = GlobalConfig() + # self.upload_interval = upload_interval + self.device_alarm_upload_time = {} # key: device_code, value: {alarm_type: last upload time} + self.device_alarm_save_time = {} + + def need_alarm(self, device_code, alarm_dict): + """ + 是否需要报警 + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :return: 是否需要报警 + """ + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + return self.need_save_alarm_record(device_code, current_time, alarm_type) \ + or self.need_upload_alarm_record(device_code, current_time, alarm_type) + + def need_save_alarm_record(self, device_code, current_time, alarm_type): + if self.save_interval <= 0: + return False + + if device_code not in self.device_alarm_save_time: + self.device_alarm_save_time[device_code] = {} + last_save_time = self.device_alarm_save_time[device_code].get(alarm_type) + if last_save_time is None or (current_time - last_save_time) > self.save_interval: + return True + return False + + def need_upload_alarm_record(self, device_code, current_time, alarm_type): + push_config = self.global_config.get_alarm_push_config() + if not push_config or not push_config.push_url: + return False + + if device_code not in self.device_alarm_upload_time: + self.device_alarm_upload_time[device_code] = {} + last_upload_time = self.device_alarm_upload_time[device_code].get(alarm_type) + if last_upload_time is None or (current_time - last_upload_time) > push_config.upload_interval: + return True + return False + + async def save_record(self, alarm_data, alarm_np_img): + async for db in get_db(): + alarm_record_service = AlarmRecordService(db) + await alarm_record_service.add_alarm(alarm_data=alarm_data, alarm_np_img=alarm_np_img) + + def 
upload_alarm_record(self, device_code, alarm_dict, alarm_np_img=None, alarm_value=None): + """ + 上传报警记录 + :param alarm_value: + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :param alarm_np_img: 报警图片(np array) + """ + # 获取当前时间 + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + if self.need_save_alarm_record(device_code, current_time, alarm_type): + alarm_record_data = AlarmRecordCreate( + device_code=device_code, + device_id=100, + alarm_type=alarm_type, + alarm_content=alarm_dict['alarmContent'], + alarm_time=datetime.now(), + alarm_value=alarm_value if alarm_value else None, + ) + asyncio.run_coroutine_threadsafe( + self.save_record(alarm_record_data, alarm_np_img), + self.main_loop) + self.device_alarm_save_time[device_code][alarm_type] = current_time + + if self.need_upload_alarm_record(device_code, current_time, alarm_type): + alarm_record = { + 'devcode': device_code, + 'alarmTime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'alarmType': alarm_dict['alarmType'], + 'alarmContent': alarm_dict['alarmContent'], + } + if alarm_value: + alarm_record['alarmValue'] = alarm_value + if alarm_np_img: + alarm_record['alarmImage'] = image_to_base64(alarm_np_img) + + push_config = self.global_config.get_alarm_push_config() + self.thread_pool.submit_task(send_request, push_config.upload_url, alarm_record) + self.device_alarm_upload_time[device_code][alarm_type] = current_time diff --git a/scene_handler/block_scene_handler.py b/scene_handler/block_scene_handler.py new file mode 100644 index 0000000..3dd0b31 --- /dev/null +++ b/scene_handler/block_scene_handler.py @@ -0,0 +1,426 @@ +import time +import traceback +from asyncio import Event +from copy import deepcopy +from datetime import datetime + + +from flatbuffers.builder import np +from scipy.spatial import ConvexHull + +from algo.stream_loader import OpenCVStreamLoad +from common.detect_utils import is_within_alert_range, get_person_head, intersection_area, bbox_area +from 
common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.harmful_gas_manager import HarmfulGasManager +from common.image_plotting import Annotator +from entity.device import Device +from scene_handler.alarm_message_center import AlarmMessageCenter +from scene_handler.alarm_record_center import AlarmRecordCenter +from scene_handler.base_scene_handler import BaseSceneHandler +from scene_handler.limit_space_scene_handler import is_overlapping +from services.global_config import GlobalConfig +from tcp.tcp_manager import TcpManager + +from entity.device import Device +from common.http_utils import get_request +from ultralytics import YOLO + +''' +alarmCategory: +0 行为监管 +1 环境监管 +2 人员监管 +3 围栏监管 + +handelType: +0 检测到报警 +1 未检测到报警 +2 人未穿戴报警 +3 其他 +''' +ALARM_DICT = [ + { + 'alarmCategory': 0, + 'alarmType': '1', + 'handelType': 1, + 'category_order': 1, + 'class_idx': [34], + 'alarm_name': 'no_fire_extinguisher', + 'alarmContent': '未检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '2', + 'handelType': 1, + 'category_order': 2, + 'class_idx': [43], + 'alarm_name': 'no_barrier_tape', + 'alarmContent': '未检测到警戒线', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '3', + 'handelType': 1, + 'category_order': 3, + 'class_idx': [48], + 'alarm_name': 'no_cone', + 'alarmContent': '未检测到锥桶', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '4', + 'handelType': 1, + 'category_order': 4, + 'class_idx': [4, 5, 16], + 'alarm_name': 'no_board', + 'alarmContent': '未检测到指示牌', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '5', + 'handelType': 2, + 'category_order': -1, + 'class_idx': [18], + 'alarm_name': 
'no_helmet', + 'alarmContent': '未佩戴安全帽', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴安全帽', + }, + # todo 明火 + { + 'alarmCategory': 1, + 'alarmType': '7', + 'handelType': 3, + 'category_order': 1, + 'class_idx': [], + 'alarm_name': 'gas_alarm', + 'alarmContent': '甲烷浓度超限', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 1, + 'alarmType': '8', + 'handelType': 3, + 'category_order': 2, + 'class_idx': [], + 'alarm_name': 'harmful_alarm', + 'alarmContent': '有害气体浓度超标', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 2, + 'alarmType': '9', + 'handelType': 3, + 'category_order': -1, + 'class_idx': [], + 'alarm_name': 'health_alarm', + 'alarmContent': '心率血氧异常', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 3, + 'alarmType': '10', + 'handelType': 2, + 'category_order': 4, + 'class_idx': [24], + 'alarm_name': 'break_in_alarm', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '非法闯入', + }, + +] + +COLOR_RED = (0, 0, 255) +COLOR_BLUE = (255, 0, 0) + + +class BlockSceneHandler(BaseSceneHandler): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + self.__stop_event = Event(loop=main_loop) + self.health_ts_dict = {} + self.harmful_ts_dict = {} + self.object_ts_dict = {} + self.thread_pool = GlobalThreadPool() + + self.alarm_message_center = AlarmMessageCenter(device.id,main_loop=main_loop, tcp_manager=tcp_manager, + category_priority={2: 0, 1: 1, 3: 2, 0: 3}) + self.alarm_record_center = AlarmRecordCenter(save_interval=device.alarm_interval,main_loop=main_loop) + self.harmful_data_manager = HarmfulGasManager() + self.device_status_manager = DeviceStatusManager() + + + self.health_device_codes = 
['HWIH061000056395'] # todo + self.harmful_device_codes = [] # todo + + for helmet_code in self.health_device_codes: + self.thread_pool.submit_task(self.health_data_task, helmet_code) + for harmful_device_code in self.harmful_device_codes: + self.thread_pool.submit_task(self.harmful_data_task, harmful_device_code) + + self.thread_pool.submit_task(self.alarm_message_center.process_messages) + + # todo 明火 + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 48: '路锥', + 58: '鼓风机', + } + self.PERSON_CLASS_IDX = 3 + self.HEAD_CLASS_IDX = 15 + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) + self.range_points = range_points + self.abs_range_points = self.get_absolute_range() + + self.tracking_status = {} # 跟踪每个行人的状态 + self.max_missing_frames = 25 # 报警的阈值 + self.disappear_threshold = 25 * 3 # 移除行人的阈值 + + def get_absolute_range(self): + fence_info = eval(self.range_points) + if fence_info and len(fence_info) > 1: + abs_points = [] + for p in fence_info: + abs_points.append( + [int(p[0] * int(self.stream_loader.frame_width)), int(p[1] * int(self.stream_loader.frame_height))]) + + abs_points = np.array(abs_points, dtype=np.int32) + hull = ConvexHull(abs_points) + sorted_coordinates = abs_points[hull.vertices] + # abs_points = abs_points.reshape((-1, 1, 2)) + return sorted_coordinates + else: + return None + + def harmful_data_task(self, harmful_device_code): + while not self.__stop_event.is_set(): + harmful_gas_data = self.harmful_data_manager.get_device_all_data(harmful_device_code) + for gas_type, gas_data in harmful_gas_data.items(): + ts_key = f'{harmful_device_code}_{gas_type}' + last_ts = self.harmful_ts_dict.get(ts_key) + gas_ts = gas_data.get('gas_ts') + if last_ts is None or 
(gas_ts - last_ts).total_seconds() > 0: + self.harmful_ts_dict[ts_key] = gas_ts + self.handle_harmful_gas_alarm(harmful_device_code, gas_type, gas_data) + + def health_data_task(self, helmet_code): + while not self.__stop_event.is_set(): + header = { + 'ak': 'fe80b2f021644b1b8c77fda743a83670', + 'sk': '8771ea6e931d4db646a26f67bcb89909', + } + url = f'https://jls.huaweisoft.com//api/ih-log/v1.0/ih-api/helmetInfo/{helmet_code}' + response = get_request(url, headers=header) + if response and response.get('data'): + last_ts = self.health_ts_dict.get(helmet_code) + vitalsigns_data = response.get('data').get('vitalSignsData') + if vitalsigns_data: + upload_timestamp = datetime.strptime(vitalsigns_data.get('uploadTimestamp'), "%Y-%m-%d %H:%M:%S") + if last_ts is None or (upload_timestamp.timestamp() - last_ts) > 0: + self.health_ts_dict[helmet_code] = upload_timestamp.timestamp() + if time.time() - upload_timestamp.timestamp() < 10 * 60: # 10分钟以前的数据不做处理 + self.handle_health_alarm(helmet_code, vitalsigns_data.get('bloodOxygen'), + vitalsigns_data.get('heartRate'),upload_timestamp) + time.sleep(10) + + def handle_health_alarm(self, helmet_code, blood_oxygen, heartrate, upload_timestamp): + logger.debug(f'health_data: {helmet_code}, blood_oxygen = {blood_oxygen}, heartrate = {heartrate}, ' + f'upload_timestamp = {upload_timestamp}') + if heartrate < 60 or heartrate > 120 or blood_oxygen < 85: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 2] + if alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 需要往后台发原始数据吗 + + def handle_harmful_gas_alarm(self, device_code, gas_type, gas_data): + alarm = False + gas_value = gas_data['gas_value'] + if gas_type == 3: # h2s + alarm = gas_value > 120.0 + elif gas_type == 4: # co + alarm = gas_value > 10.0 + elif gas_type == 5: # o2 + alarm = gas_value < 15 + elif gas_type == 50: # ex + alarm = gas_value > 10 + + if alarm: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 1] + if 
alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 + + def model_predict(self, frames): + results_generator = self.model(frames, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) + + pred_ids = [[int(box.cls) for box in sublist] for sublist in result_boxes] + pred_names = [[self.model_classes[int(box.cls)] for box in sublist] for sublist in result_boxes] + return result_boxes, pred_ids, pred_names + + def handle_behave_alarm(self, frames, result_boxes, pred_ids, pred_names): + behave_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 0] + for alarm_dict in behave_alarm_dicts: + for idx, frame_boxes in enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + if alarm_dict['handelType'] == 0: # 检测到就报警 + if object_boxes: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + box_color = COLOR_RED if int(box.cls) in alarm_dict['class_idx'] else COLOR_BLUE + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=box_color, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + elif alarm_dict['handelType'] == 1: # 检测不到报警 + if object_boxes: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + else: + last_ts = self.object_ts_dict.get(alarm_dict['alarm_name'], 0) + if time.time() - last_ts > 5: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + self.alarm_message_center.add_message(alarm_dict) + if 
self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + elif alarm_dict['handelType'] == 2: # 人未穿戴报警 + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) + for helmet in object_boxes) + if not has_helmet: + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + for box in frame_boxes: + box_cls = box.cls + if box_cls != self.PERSON_CLASS_IDX and box_cls != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def handle_break_in_alarm(self, frames, result_boxes, pred_ids, pred_names): + break_in_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 3] + for alarm_dict in break_in_alarm_dicts: + for idx, frame_boxes in 
enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + person_id = person_box.id + if is_within_alert_range(person_bbox, self.abs_range_points): + has_object = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + overlap_ratio = intersection_area(person_bbox, person_head.xyxy.cpu().squeeze()) / bbox_area(person_bbox) + if overlap_ratio < 0.5: # 头占人<0.5,判断是否穿工服。不太准确 + has_object = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), object_boxe.xyxy.cpu().squeeze()) + for object_boxe in object_boxes) + if not has_object: + self.tracking_status[person_box.id] = self.tracking_status.get(person_box.id, 0) + 1 + + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() + for frames in self.stream_loader: + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + if not frames: + continue + + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frames) 
# 结果都是二维数组,对应batch中的每个frame + # print(pred_names) + self.handle_behave_alarm(frames, result_boxes, pred_ids, pred_names) + self.handle_break_in_alarm(frames, result_boxes, pred_ids, pred_names) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index d1c5b9a..a7f1674 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -145,7 +145,7 @@ class LimitSpaceSceneHandler(BaseSceneHandler): - def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) # self.device = device # self.thread_id = thread_id @@ -287,20 +287,20 @@ if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: try: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None: continue self.device_status_manager.set_status(device_id=self.device.id) - result_boxes, pred_ids, pred_names = self.model_predict(frame) + # result_boxes, pred_ids, pred_names = self.model_predict(frames) frame_alarm = {} - self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) - self.process_labor(frame, result_boxes, pred_ids, pred_names) + # self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + # self.process_labor(frame, result_boxes, pred_ids, pred_names) if len(frame_alarm.keys()) > 0: for key in frame_alarm.keys(): diff --git a/services/alarm_record_service.py b/services/alarm_record_service.py new file mode 100644 index 0000000..a0a86f9 --- /dev/null +++ b/services/alarm_record_service.py @@ -0,0 +1,46 @@ +import os +import uuid +from 
datetime import datetime + +import aiofiles +import cv2 +from sqlalchemy.ext.asyncio import AsyncSession + +from entity.alarm_record import AlarmRecordCreate, AlarmRecord + + +class AlarmRecordService: + def __init__(self, db: AsyncSession): + self.db = db + + async def add_alarm(self, alarm_data:AlarmRecordCreate, alarm_np_img): + async def save_alarm_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/alarms', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', alarm_np_img) + image_data = encoded_image.tobytes() + + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + alarm_img_path = await save_alarm_file() + + # 创建并保存到数据库中 + alarm_record = AlarmRecord.model_validate(alarm_data) + alarm_record.alarm_image = alarm_img_path + self.db.add(alarm_record) + await self.db.commit() + await self.db.refresh(alarm_record) + return alarm_record diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 3416831..c94c540 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -39,18 +39,19 @@ if result_row: relation, scene = result_row scene_info = DeviceSceneRelationInfo( - id=relation.id, - device_id=relation.device_id, - scene_id=relation.scene_id, - scene_name=scene.name, - scene_version=scene.version, - scene_handle_task=scene.handle_task, - scene_remark=scene.remark, - ) + id=relation.id, + device_id=relation.device_id, + scene_id=relation.scene_id, + scene_name=scene.name, + scene_version=scene.version, + scene_handle_task=scene.handle_task, + scene_remark=scene.remark, + range_points=relation.range_points + ) 
return scene_info - async def add_relation_by_device(self, device_id: int, scene_id: int): - new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) + async def add_relation_by_device(self, device_id: int, scene_id: int, range_points: str = None): + new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id, range_points=range_points) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) @@ -64,8 +65,8 @@ await self.db.commit() return result.rowcount - async def update_relation_by_device(self, device_id: int, scene_id: int): + async def update_relation_by_device(self, device_id: int, scene_id: int, range_points: str = None): await self.delete_relation_by_device(device_id) - new_relation = await self.add_relation_by_device(device_id, scene_id) + new_relation = await self.add_relation_by_device(device_id, scene_id, range_points) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/global_config.py b/services/global_config.py index fa53b2b..8baad1e 100644 --- a/services/global_config.py +++ b/services/global_config.py @@ -23,6 +23,7 @@ self.gas_push_config = None self.algo_result_push_config = None self.alarm_push_config = None + self.harmful_gas_push_config = None self._init_done = False async def _initialize(self): @@ -39,6 +40,7 @@ self.gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.GAS) self.algo_result_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT) self.alarm_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALARM) + self.harmful_gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.HARMFUL_GAS) self._init_done = True async def on_config_change(self, config: PushConfig): @@ -48,6 +50,8 @@ await self.set_algo_result_push_config(config) elif config.push_type == PUSH_TYPE.ALARM: await self.set_alarm_push_config(config) + elif 
config.push_type == PUSH_TYPE.HARMFUL_GAS: + await self.set_harmful_gas_push_config(config) def get_gas_push_config(self): """获取 gas_push_config 配置""" @@ -78,3 +82,14 @@ if config: async with self._lock: self.alarm_push_config = config + + def get_harmful_gas_push_config(self): + """获取 algo_result_push_config 配置""" + return self.harmful_gas_push_config + + async def set_harmful_gas_push_config(self, config): + """设置 algo_result_push_config 配置""" + if config: + async with self._lock: + self.harmful_gas_push_config = config + diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - # frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - 
device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time 
= datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - 
self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = 
handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! 
appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down 
gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from 
tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/run.sh b/run.sh index bb36274..815883a 100644 --- a/run.sh +++ b/run.sh @@ -8,4 +8,5 @@ export LC_ALL=C.UTF-8 cd /code/safe-algo-pro -python3 main.py +exec python3 main.py +#tail -f /dev/null \ No newline at end of file diff --git a/scene_handler/alarm_message_center.py b/scene_handler/alarm_message_center.py new file mode 100644 index 0000000..fba1397 --- /dev/null +++ b/scene_handler/alarm_message_center.py @@ -0,0 +1,121 @@ +import asyncio +import copy +import time +from collections import deque +from threading import Thread, Lock + +''' +队列消息取出规则: +- 按 alarmCategory(优先级:2 > 1 > 3 > 0)和 category_order 从小到大排序。 +- 确保每个 alarmCategory 在 10 秒内只能发送一次。 + +消息定时清理: +- 每次取消息时会检查队列中长期堆积、超过一定时间(比如 30 秒)的消息,并清理这些消息。 +''' + + +class AlarmMessageCenter: + def __init__(self, device_id, main_loop, tcp_manager=None, message_send_interval=5, category_interval=10, retention_time=30, + category_priority=None): + self.device_id = device_id + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.queue = deque() + self.last_sent_time = {} # 记录每个 alarmCategory 的最后发送时间 + self.lock = Lock() + self.message_send_interval = message_send_interval # 消息发送间隔(秒) + self.category_interval = category_interval # 类别发送间隔(秒) + self.retention_time = retention_time # 消息最长保留时间(秒) + if category_priority: + self.category_priority = category_priority + self.auto_update_alarm_priority = False + else: + self.category_priority = category_priority + self.auto_update_alarm_priority = 
True + + def add_message(self, message_ori): + message = copy.deepcopy(message_ori) + message['timestamp'] = int(time.time()) # 添加消息放入队列的时间 + with self.lock: + self.queue.append(message) + + # 动态更新优先级映射 + if self.auto_update_alarm_priority: + alarm_category = message['alarmCategory'] + if alarm_category not in self.category_priority: + unique_categories = sorted({msg['alarmCategory'] for msg in self.queue}) + self.category_priority = {cat: idx for idx, cat in enumerate(unique_categories)} + print(f"更新优先级映射: {self.category_priority}") + + def _clean_old_messages(self): + """清理长期堆积的消息""" + print(f'清理长期堆积的消息 (队列长度: {len(self.queue)})') + now = time.time() + with self.lock: + self.queue = deque([msg for msg in self.queue if now - msg['timestamp'] <= self.retention_time]) + print(f'清理后的队列长度: {len(self.queue)}') + + # def _get_next_message(self): + # """按优先级和时间规则取出下一条消息""" + # now = time.time() + # with self.lock: + # # 按照规则排序队列 + # sorted_queue = sorted( + # self.queue, + # key=lambda x: ( + # self.alarm_priority.get(x['alarmCategory'], 4), # 按 alarmCategory 优先级排序 + # x['category_order'], # category_order 小的排前面 + # x['timestamp'] # 时间靠前的排前面 + # ) + # ) + # # 遍历排序后的队列,找到符合规则的消息 + # for msg in sorted_queue: + # alarm_category = msg['alarmCategory'] + # if alarm_category not in self.last_sent_time or now - self.last_sent_time[alarm_category] > 10: + # self.queue.remove(msg) + # self.last_sent_time[alarm_category] = now + # return msg + # return None + + def _get_next_message(self): + """按优先级和时间规则取出下一条消息""" + now = time.time() + with self.lock: + # 按优先级依次检查 + for priority in sorted(self.category_priority.values()): + found_valid_message = False # 用于标记当前优先级是否有消息被处理 + for msg in sorted( + (m for m in self.queue if self.category_priority.get(m['alarmCategory'], 99) == priority), + key=lambda x: (-x['timestamp'], x['category_order']) + ): + alarm_category = msg['alarmCategory'] + # 检查是否符合发送条件 + if alarm_category not in self.last_sent_time or now - self.last_sent_time[ + 
alarm_category] > self.category_interval: + self.queue.remove(msg) + self.last_sent_time[alarm_category] = now + return msg # 找到符合条件的消息立即返回 + found_valid_message = True # 当前优先级存在消息但不符合条件 + # 如果当前优先级的所有消息都被检查过且不符合条件,跳到下一个优先级 + if not found_valid_message: + continue + return None # 如果所有优先级都没有消息符合条件,则返回 None + + def process_messages(self): + while True: + time.sleep(self.message_send_interval) + self._clean_old_messages() # 清理长期堆积的消息 + next_message = self._get_next_message() + if next_message: + self.send_message(next_message) + + def send_message(self, message): + """发送报警消息""" + print(f"发送报警消息: {message['alarmContent']} (类别: {message['alarmCategory']}, 时间: {message['timestamp']})") + if self.tcp_manager: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device_id, + message=message['alarmSoundMessage'], + have_response=False), + self.main_loop) diff --git a/scene_handler/alarm_record_center.py b/scene_handler/alarm_record_center.py new file mode 100644 index 0000000..5373cd5 --- /dev/null +++ b/scene_handler/alarm_record_center.py @@ -0,0 +1,128 @@ +import asyncio +import base64 +import time +from datetime import datetime + +import cv2 + +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from db.database import get_db +from entity.alarm_record import AlarmRecordCreate +from services.alarm_record_service import AlarmRecordService +from services.global_config import GlobalConfig + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + # return f'data:image/{format};base64,{base64_message}' + + +class 
AlarmRecordCenter:
+    def __init__(self, save_interval=-1, main_loop=None):
+        """
+        初始化报警记录中心
+        :param save_interval: 报警保存间隔(秒),如果 <= 0,则不保存报警记录
+        """
+        self.main_loop = main_loop
+        self.save_interval = save_interval
+        self.thread_pool = GlobalThreadPool()
+        self.global_config = GlobalConfig()
+        # self.upload_interval = upload_interval
+        self.device_alarm_upload_time = {}  # key: device_code, value: {alarm_type: last upload time}
+        self.device_alarm_save_time = {}
+
+    def need_alarm(self, device_code, alarm_dict):
+        """
+        是否需要报警
+        :param device_code: 设备编号
+        :param alarm_dict: 报警类型字典
+        :return: 是否需要报警
+        """
+        current_time = time.time()
+        alarm_type = alarm_dict['alarmType']
+
+        return self.need_save_alarm_record(device_code, current_time, alarm_type) \
+            or self.need_upload_alarm_record(device_code, current_time, alarm_type)
+
+    def need_save_alarm_record(self, device_code, current_time, alarm_type):
+        if self.save_interval <= 0:
+            return False
+
+        if device_code not in self.device_alarm_save_time:
+            self.device_alarm_save_time[device_code] = {}
+        last_save_time = self.device_alarm_save_time[device_code].get(alarm_type)
+        if last_save_time is None or (current_time - last_save_time) > self.save_interval:
+            return True
+        return False
+
+    def need_upload_alarm_record(self, device_code, current_time, alarm_type):
+        push_config = self.global_config.get_alarm_push_config()
+        if not push_config or not push_config.push_url:
+            return False
+
+        if device_code not in self.device_alarm_upload_time:
+            self.device_alarm_upload_time[device_code] = {}
+        last_upload_time = self.device_alarm_upload_time[device_code].get(alarm_type)
+        if last_upload_time is None or (current_time - last_upload_time) > push_config.upload_interval:
+            return True
+        return False
+
+    async def save_record(self, alarm_data, alarm_np_img):
+        async for db in get_db():
+            alarm_record_service = AlarmRecordService(db)
+            await alarm_record_service.add_alarm(alarm_data=alarm_data, alarm_np_img=alarm_np_img)
+
+    def 
upload_alarm_record(self, device_code, alarm_dict, alarm_np_img=None, alarm_value=None): + """ + 上传报警记录 + :param alarm_value: + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :param alarm_np_img: 报警图片(np array) + """ + # 获取当前时间 + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + if self.need_save_alarm_record(device_code, current_time, alarm_type): + alarm_record_data = AlarmRecordCreate( + device_code=device_code, + device_id=100, + alarm_type=alarm_type, + alarm_content=alarm_dict['alarmContent'], + alarm_time=datetime.now(), + alarm_value=alarm_value if alarm_value else None, + ) + asyncio.run_coroutine_threadsafe( + self.save_record(alarm_record_data, alarm_np_img), + self.main_loop) + self.device_alarm_save_time[device_code][alarm_type] = current_time + + if self.need_upload_alarm_record(device_code, current_time, alarm_type): + alarm_record = { + 'devcode': device_code, + 'alarmTime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'alarmType': alarm_dict['alarmType'], + 'alarmContent': alarm_dict['alarmContent'], + } + if alarm_value: + alarm_record['alarmValue'] = alarm_value + if alarm_np_img: + alarm_record['alarmImage'] = image_to_base64(alarm_np_img) + + push_config = self.global_config.get_alarm_push_config() + self.thread_pool.submit_task(send_request, push_config.upload_url, alarm_record) + self.device_alarm_upload_time[device_code][alarm_type] = current_time diff --git a/scene_handler/block_scene_handler.py b/scene_handler/block_scene_handler.py new file mode 100644 index 0000000..3dd0b31 --- /dev/null +++ b/scene_handler/block_scene_handler.py @@ -0,0 +1,426 @@ +import time +import traceback +from asyncio import Event +from copy import deepcopy +from datetime import datetime + + +from flatbuffers.builder import np +from scipy.spatial import ConvexHull + +from algo.stream_loader import OpenCVStreamLoad +from common.detect_utils import is_within_alert_range, get_person_head, intersection_area, bbox_area +from 
common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.harmful_gas_manager import HarmfulGasManager +from common.image_plotting import Annotator +from entity.device import Device +from scene_handler.alarm_message_center import AlarmMessageCenter +from scene_handler.alarm_record_center import AlarmRecordCenter +from scene_handler.base_scene_handler import BaseSceneHandler +from scene_handler.limit_space_scene_handler import is_overlapping +from services.global_config import GlobalConfig +from tcp.tcp_manager import TcpManager + +from entity.device import Device +from common.http_utils import get_request +from ultralytics import YOLO + +''' +alarmCategory: +0 行为监管 +1 环境监管 +2 人员监管 +3 围栏监管 + +handelType: +0 检测到报警 +1 未检测到报警 +2 人未穿戴报警 +3 其他 +''' +ALARM_DICT = [ + { + 'alarmCategory': 0, + 'alarmType': '1', + 'handelType': 1, + 'category_order': 1, + 'class_idx': [34], + 'alarm_name': 'no_fire_extinguisher', + 'alarmContent': '未检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '2', + 'handelType': 1, + 'category_order': 2, + 'class_idx': [43], + 'alarm_name': 'no_barrier_tape', + 'alarmContent': '未检测到警戒线', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '3', + 'handelType': 1, + 'category_order': 3, + 'class_idx': [48], + 'alarm_name': 'no_cone', + 'alarmContent': '未检测到锥桶', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '4', + 'handelType': 1, + 'category_order': 4, + 'class_idx': [4, 5, 16], + 'alarm_name': 'no_board', + 'alarmContent': '未检测到指示牌', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '5', + 'handelType': 2, + 'category_order': -1, + 'class_idx': [18], + 'alarm_name': 
'no_helmet', + 'alarmContent': '未佩戴安全帽', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴安全帽', + }, + # todo 明火 + { + 'alarmCategory': 1, + 'alarmType': '7', + 'handelType': 3, + 'category_order': 1, + 'class_idx': [], + 'alarm_name': 'gas_alarm', + 'alarmContent': '甲烷浓度超限', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 1, + 'alarmType': '8', + 'handelType': 3, + 'category_order': 2, + 'class_idx': [], + 'alarm_name': 'harmful_alarm', + 'alarmContent': '有害气体浓度超标', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 2, + 'alarmType': '9', + 'handelType': 3, + 'category_order': -1, + 'class_idx': [], + 'alarm_name': 'health_alarm', + 'alarmContent': '心率血氧异常', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 3, + 'alarmType': '10', + 'handelType': 2, + 'category_order': 4, + 'class_idx': [24], + 'alarm_name': 'break_in_alarm', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '非法闯入', + }, + +] + +COLOR_RED = (0, 0, 255) +COLOR_BLUE = (255, 0, 0) + + +class BlockSceneHandler(BaseSceneHandler): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + self.__stop_event = Event(loop=main_loop) + self.health_ts_dict = {} + self.harmful_ts_dict = {} + self.object_ts_dict = {} + self.thread_pool = GlobalThreadPool() + + self.alarm_message_center = AlarmMessageCenter(device.id,main_loop=main_loop, tcp_manager=tcp_manager, + category_priority={2: 0, 1: 1, 3: 2, 0: 3}) + self.alarm_record_center = AlarmRecordCenter(save_interval=device.alarm_interval,main_loop=main_loop) + self.harmful_data_manager = HarmfulGasManager() + self.device_status_manager = DeviceStatusManager() + + + self.health_device_codes = 
['HWIH061000056395'] # todo + self.harmful_device_codes = [] # todo + + for helmet_code in self.health_device_codes: + self.thread_pool.submit_task(self.health_data_task, helmet_code) + for harmful_device_code in self.harmful_device_codes: + self.thread_pool.submit_task(self.harmful_data_task, harmful_device_code) + + self.thread_pool.submit_task(self.alarm_message_center.process_messages) + + # todo 明火 + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 48: '路锥', + 58: '鼓风机', + } + self.PERSON_CLASS_IDX = 3 + self.HEAD_CLASS_IDX = 15 + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) + self.range_points = range_points + self.abs_range_points = self.get_absolute_range() + + self.tracking_status = {} # 跟踪每个行人的状态 + self.max_missing_frames = 25 # 报警的阈值 + self.disappear_threshold = 25 * 3 # 移除行人的阈值 + + def get_absolute_range(self): + fence_info = eval(self.range_points) + if fence_info and len(fence_info) > 1: + abs_points = [] + for p in fence_info: + abs_points.append( + [int(p[0] * int(self.stream_loader.frame_width)), int(p[1] * int(self.stream_loader.frame_height))]) + + abs_points = np.array(abs_points, dtype=np.int32) + hull = ConvexHull(abs_points) + sorted_coordinates = abs_points[hull.vertices] + # abs_points = abs_points.reshape((-1, 1, 2)) + return sorted_coordinates + else: + return None + + def harmful_data_task(self, harmful_device_code): + while not self.__stop_event.is_set(): + harmful_gas_data = self.harmful_data_manager.get_device_all_data(harmful_device_code) + for gas_type, gas_data in harmful_gas_data.items(): + ts_key = f'{harmful_device_code}_{gas_type}' + last_ts = self.harmful_ts_dict.get(ts_key) + gas_ts = gas_data.get('gas_ts') + if last_ts is None or 
(gas_ts - last_ts).total_seconds() > 0: + self.harmful_ts_dict[ts_key] = gas_ts + self.handle_harmful_gas_alarm(harmful_device_code, gas_type, gas_data) + + def health_data_task(self, helmet_code): + while not self.__stop_event.is_set(): + header = { + 'ak': 'fe80b2f021644b1b8c77fda743a83670', + 'sk': '8771ea6e931d4db646a26f67bcb89909', + } + url = f'https://jls.huaweisoft.com//api/ih-log/v1.0/ih-api/helmetInfo/{helmet_code}' + response = get_request(url, headers=header) + if response and response.get('data'): + last_ts = self.health_ts_dict.get(helmet_code) + vitalsigns_data = response.get('data').get('vitalSignsData') + if vitalsigns_data: + upload_timestamp = datetime.strptime(vitalsigns_data.get('uploadTimestamp'), "%Y-%m-%d %H:%M:%S") + if last_ts is None or (upload_timestamp.timestamp() - last_ts) > 0: + self.health_ts_dict[helmet_code] = upload_timestamp.timestamp() + if time.time() - upload_timestamp.timestamp() < 10 * 60: # 10分钟以前的数据不做处理 + self.handle_health_alarm(helmet_code, vitalsigns_data.get('bloodOxygen'), + vitalsigns_data.get('heartRate'),upload_timestamp) + time.sleep(10) + + def handle_health_alarm(self, helmet_code, blood_oxygen, heartrate, upload_timestamp): + logger.debug(f'health_data: {helmet_code}, blood_oxygen = {blood_oxygen}, heartrate = {heartrate}, ' + f'upload_timestamp = {upload_timestamp}') + if heartrate < 60 or heartrate > 120 or blood_oxygen < 85: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 2] + if alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 需要往后台发原始数据吗 + + def handle_harmful_gas_alarm(self, device_code, gas_type, gas_data): + alarm = False + gas_value = gas_data['gas_value'] + if gas_type == 3: # h2s + alarm = gas_value > 120.0 + elif gas_type == 4: # co + alarm = gas_value > 10.0 + elif gas_type == 5: # o2 + alarm = gas_value < 15 + elif gas_type == 50: # ex + alarm = gas_value > 10 + + if alarm: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 1] + if 
alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 + + def model_predict(self, frames): + results_generator = self.model(frames, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) + + pred_ids = [[int(box.cls) for box in sublist] for sublist in result_boxes] + pred_names = [[self.model_classes[int(box.cls)] for box in sublist] for sublist in result_boxes] + return result_boxes, pred_ids, pred_names + + def handle_behave_alarm(self, frames, result_boxes, pred_ids, pred_names): + behave_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 0] + for alarm_dict in behave_alarm_dicts: + for idx, frame_boxes in enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + if alarm_dict['handelType'] == 0: # 检测到就报警 + if object_boxes: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + box_color = COLOR_RED if int(box.cls) in alarm_dict['class_idx'] else COLOR_BLUE + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=box_color, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + elif alarm_dict['handelType'] == 1: # 检测不到报警 + if object_boxes: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + else: + last_ts = self.object_ts_dict.get(alarm_dict['alarm_name'], 0) + if time.time() - last_ts > 5: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + self.alarm_message_center.add_message(alarm_dict) + if 
self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + elif alarm_dict['handelType'] == 2: # 人未穿戴报警 + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) + for helmet in object_boxes) + if not has_helmet: + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + for box in frame_boxes: + box_cls = box.cls + if box_cls != self.PERSON_CLASS_IDX and box_cls != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def handle_break_in_alarm(self, frames, result_boxes, pred_ids, pred_names): + break_in_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 3] + for alarm_dict in break_in_alarm_dicts: + for idx, frame_boxes in 
enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + person_id = person_box.id + if is_within_alert_range(person_bbox, self.abs_range_points): + has_object = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + overlap_ratio = intersection_area(person_bbox, person_head.xyxy.cpu().squeeze()) / bbox_area(person_bbox) + if overlap_ratio < 0.5: # 头占人<0.5,判断是否穿工服。不太准确 + has_object = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), object_boxe.xyxy.cpu().squeeze()) + for object_boxe in object_boxes) + if not has_object: + self.tracking_status[person_box.id] = self.tracking_status.get(person_box.id, 0) + 1 + + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() + for frames in self.stream_loader: + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + if not frames: + continue + + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frames) 
# 结果都是二维数组,对应batch中的每个frame + # print(pred_names) + self.handle_behave_alarm(frames, result_boxes, pred_ids, pred_names) + self.handle_break_in_alarm(frames, result_boxes, pred_ids, pred_names) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index d1c5b9a..a7f1674 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -145,7 +145,7 @@ class LimitSpaceSceneHandler(BaseSceneHandler): - def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) # self.device = device # self.thread_id = thread_id @@ -287,20 +287,20 @@ if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: try: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None: continue self.device_status_manager.set_status(device_id=self.device.id) - result_boxes, pred_ids, pred_names = self.model_predict(frame) + # result_boxes, pred_ids, pred_names = self.model_predict(frames) frame_alarm = {} - self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) - self.process_labor(frame, result_boxes, pred_ids, pred_names) + # self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + # self.process_labor(frame, result_boxes, pred_ids, pred_names) if len(frame_alarm.keys()) > 0: for key in frame_alarm.keys(): diff --git a/services/alarm_record_service.py b/services/alarm_record_service.py new file mode 100644 index 0000000..a0a86f9 --- /dev/null +++ b/services/alarm_record_service.py @@ -0,0 +1,46 @@ +import os +import uuid +from 
datetime import datetime + +import aiofiles +import cv2 +from sqlalchemy.ext.asyncio import AsyncSession + +from entity.alarm_record import AlarmRecordCreate, AlarmRecord + + +class AlarmRecordService: + def __init__(self, db: AsyncSession): + self.db = db + + async def add_alarm(self, alarm_data:AlarmRecordCreate, alarm_np_img): + async def save_alarm_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/alarms', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', alarm_np_img) + image_data = encoded_image.tobytes() + + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + alarm_img_path = await save_alarm_file() + + # 创建并保存到数据库中 + alarm_record = AlarmRecord.model_validate(alarm_data) + alarm_record.alarm_image = alarm_img_path + self.db.add(alarm_record) + await self.db.commit() + await self.db.refresh(alarm_record) + return alarm_record diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 3416831..c94c540 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -39,18 +39,19 @@ if result_row: relation, scene = result_row scene_info = DeviceSceneRelationInfo( - id=relation.id, - device_id=relation.device_id, - scene_id=relation.scene_id, - scene_name=scene.name, - scene_version=scene.version, - scene_handle_task=scene.handle_task, - scene_remark=scene.remark, - ) + id=relation.id, + device_id=relation.device_id, + scene_id=relation.scene_id, + scene_name=scene.name, + scene_version=scene.version, + scene_handle_task=scene.handle_task, + scene_remark=scene.remark, + range_points=relation.range_points + ) 
return scene_info - async def add_relation_by_device(self, device_id: int, scene_id: int): - new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) + async def add_relation_by_device(self, device_id: int, scene_id: int, range_points: str = None): + new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id, range_points=range_points) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) @@ -64,8 +65,8 @@ await self.db.commit() return result.rowcount - async def update_relation_by_device(self, device_id: int, scene_id: int): + async def update_relation_by_device(self, device_id: int, scene_id: int, range_points: str = None): await self.delete_relation_by_device(device_id) - new_relation = await self.add_relation_by_device(device_id, scene_id) + new_relation = await self.add_relation_by_device(device_id, scene_id, range_points) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/global_config.py b/services/global_config.py index fa53b2b..8baad1e 100644 --- a/services/global_config.py +++ b/services/global_config.py @@ -23,6 +23,7 @@ self.gas_push_config = None self.algo_result_push_config = None self.alarm_push_config = None + self.harmful_gas_push_config = None self._init_done = False async def _initialize(self): @@ -39,6 +40,7 @@ self.gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.GAS) self.algo_result_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT) self.alarm_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALARM) + self.harmful_gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.HARMFUL_GAS) self._init_done = True async def on_config_change(self, config: PushConfig): @@ -48,6 +50,8 @@ await self.set_algo_result_push_config(config) elif config.push_type == PUSH_TYPE.ALARM: await self.set_alarm_push_config(config) + elif 
config.push_type == PUSH_TYPE.HARMFUL_GAS:
+            await self.set_harmful_gas_push_config(config)
 
     def get_gas_push_config(self):
         """获取 gas_push_config 配置"""
@@ -78,3 +82,14 @@
         if config:
             async with self._lock:
                 self.alarm_push_config = config
+
+    def get_harmful_gas_push_config(self):
+        """获取 harmful_gas_push_config 配置"""
+        return self.harmful_gas_push_config
+
+    async def set_harmful_gas_push_config(self, config):
+        """设置 harmful_gas_push_config 配置"""
+        if config:
+            async with self._lock:
+                self.harmful_gas_push_config = config
+
diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py
index 29d2966..b3b8169 100644
--- a/tcp/tcp_client_connector.py
+++ b/tcp/tcp_client_connector.py
@@ -68,14 +68,16 @@
         self.timeout = timeout  # 连接/发送超时时间
         self.is_connected = False  # 连接状态标志
         self.is_reconnecting = False
+        self.tasks_started = False
         self.message_queue = asyncio.Queue()  #deque()
         self.gas_task = None
         self.read_lock = asyncio.Lock()  # 添加锁
+        self.lock = asyncio.Lock()
         self.push_ts_dict = {}
 
     async def connect(self):
         """连接到设备"""
-        while not self.is_connected:
+        while True:
             try:
                 logger.info(f"正在连接到 {self.ip}:{self.port}...")
                 # 设置连接超时
@@ -85,11 +87,12 @@
                 self.is_connected = True
                 logger.info(f"已连接到 {self.ip}:{self.port}")
 
-                if self.gas_task is None:
-                    self.gas_task = asyncio.create_task(self.process_message_queue())  # Start processing message queue
+                if not self.tasks_started:
+                    asyncio.create_task(self.process_message_queue())  # Start processing message queue
+                    asyncio.create_task(self.start_gas_query())
+                    self.tasks_started = True
+                    break
 
-                # 一旦连接成功,开始发送查询指令
-                await self.start_gas_query()
             except (asyncio.TimeoutError, ConnectionRefusedError, OSError) as e:
                 logger.error(f"连接到 {self.ip}:{self.port} 失败,错误: {e}")
                 logger.info(f"{self.reconnect_interval} 秒后将重连到 {self.ip}:{self.port}")
@@ -102,31 +105,40 @@
             return
         self.is_reconnecting = True
         await self.disconnect()  # 先断开现有连接
-        logger.info(f"Reconnecting to {self.ip}:{self.port} after {self.reconnect_interval} seconds")
+        # 
logger.info(f"Reconnecting to {self.ip}:{self.port} after {self.reconnect_interval} seconds") # await asyncio.sleep(self.reconnect_interval) # 等待n秒后重连 await self.connect() self.is_reconnecting = False async def disconnect(self): """断开设备连接,清理资源""" - if self.writer: - logger.info(f"Disconnecting from {self.ip}:{self.port}...") - try: - self.writer.close() - await self.writer.wait_closed() - except Exception as e: - logger.error(f"Error while disconnecting: {e}") - finally: - self.reader = None - self.writer = None - self.is_connected = False # 设置连接状态为 False - logger.info(f"Disconnected from {self.ip}:{self.port}") + # if self.writer: + # logger.info(f"Disconnecting from {self.ip}:{self.port}...") + # try: + # self.writer.close() + # await self.writer.wait_closed() + # except Exception as e: + # logger.error(f"Error while disconnecting: {e}") + # finally: + # self.reader = None + # self.writer = None + # self.is_connected = False # 设置连接状态为 False + # logger.info(f"Disconnected from {self.ip}:{self.port}") + async with self.lock: + if self.writer: + try: + self.writer.close() + await self.writer.wait_closed() + except (ConnectionResetError, BrokenPipeError) as e: + logger.exception(f"Error during disconnection") + self.reader = self.writer = None + logger.info(f"Disconnected from {self.ip}:{self.port}") async def start_gas_query(self): """启动甲烷查询指令,每n秒发送一次""" try: logger.info(f"Start querying gas from {self.ip}...") - while self.is_connected: + while True: await self.send_message(TREE_COMMAND.GAS_QUERY, have_response=True) await asyncio.sleep(self.query_interval) except (ConnectionResetError, asyncio.IncompleteReadError) as e: @@ -167,32 +179,33 @@ async def send_message(self, message: bytes, have_response=True): """Add a message to the queue for sending""" - self.message_queue.append((message, have_response)) + await self.message_queue.put((message, have_response)) logger.info(f"Message enqueued for {self.ip}:{self.port} {format_bytes(message)}") async def 
process_message_queue(self): """Process messages in the queue, retrying on failures""" - while self.is_connected: + while True: if self.message_queue: - message, have_response = self.message_queue.popleft() + message, have_response = await self.message_queue.get() await self._send_message_with_retry(message, have_response) else: await asyncio.sleep(1) # Small delay to prevent busy-waiting async def _send_message_with_retry(self, message: bytes, have_response): """Send a message with retries on failure""" - retry_attempts = 3 # Maximum retry attempts - for _ in range(retry_attempts): - if not self.is_connected: - await self.reconnect() - if not self.is_connected: - logger.error("Reconnection failed") - continue # Skip this attempt if reconnection fails + # retry_attempts = 3 # Maximum retry attempts + # for _ in range(retry_attempts): + # if not self.is_connected: + # await self.reconnect() + # if not self.is_connected: + # logger.error("Reconnection failed") + # continue # Skip this attempt if reconnection fails - try: - if self.writer is None or self.writer.is_closing(): - raise ConnectionResetError("No active connection or writer is closing") + try: + if self.writer is None or self.writer.is_closing(): + raise ConnectionResetError("No active connection or writer is closing") + async with self.lock: self.writer.write(message) await self.writer.drain() logger.info(f"Sent message to {self.ip}:{self.port}: {message}") @@ -203,13 +216,22 @@ await self.parse_response(data) return # Exit loop on success - except (asyncio.TimeoutError, ConnectionResetError, asyncio.IncompleteReadError, RuntimeError, - BrokenPipeError, OSError, EOFError, ConnectionAbortedError, ConnectionRefusedError) as e: - logger.exception("Failed to send message") - self.is_connected = False # Mark connection as disconnected - await self.reconnect() + except (asyncio.TimeoutError, ConnectionResetError, asyncio.IncompleteReadError, RuntimeError, + BrokenPipeError, OSError, EOFError, 
ConnectionAbortedError, ConnectionRefusedError) as e: + logger.exception("Failed to send message") + # self.is_connected = False # Mark connection as disconnected + await self.requeue_data(message, have_response) + await self.reconnect() + except Exception as e: + logger.exception(f"Unexpected error. Reconnecting and requeueing data...") + await self.requeue_data(message, have_response) + await self.reconnect() - logger.error("Max retry attempts reached, message sending failed") + # logger.error("Max retry attempts reached, message sending failed") + + async def requeue_data(self, data, have_response): + """Requeue the data that couldn't be sent to avoid data loss.""" + await self.send_message(data, have_response) # async def send_message(self, message: bytes, have_response=True): # """发送自定义消息的接口,供其他类调用""" diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py deleted file mode 100644 index 9109d75..0000000 --- a/algo/algo_runner_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -from algo.algo_runner import AlgoRunner - -# algo_runner = AlgoRunner() - - -def get_algo_runner(): - pass - # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 6a86830..d2c647e 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -3,7 +3,7 @@ from copy import deepcopy from dataclasses import dataclass import importlib -from datetime import datetime +from datetime import datetime, timedelta from threading import Event from typing import List, Dict @@ -48,6 +48,8 @@ self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None self.push_ts = None + self.frames_detected = 0 + self.fps_ts = None self.thread_id = thread_id self.device_frame_service = DeviceFrameService(db_session) @@ -77,37 +79,40 @@ return True return False - def save_frame_results(self, frame, results_map): - if not self.check_frame_interval(): - return - # device_frame = self.device_frame_service.add_frame(self.device.id, frame) - 
# frame_id = device_frame.id + def save_frame_results(self, frames, result_maps): + for idx,frame in enumerate(frames): + results_map = result_maps[idx] - future = asyncio.run_coroutine_threadsafe( - self.device_frame_service.add_frame(self.device.id, frame), self.main_loop - ) - device_frame = future.result() - frame_id = device_frame.id + if not self.check_frame_interval(): + return + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id - logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') - frame_results = [] - for model_exec_id, results in results_map.items(): - for r in results: - frame_result = FrameAnalysisResultCreate( - device_id=self.device.id, - frame_id=frame_id, - algo_model_id=model_exec_id, - object_class_id=r.object_class_id, - object_class_name=r.object_class_name, - confidence=round(r.confidence, 4), - location=r.location, - ) - frame_results.append(frame_result) - asyncio.run_coroutine_threadsafe( - self.frame_analysis_result_service.add_frame_analysis_results(frame_results), - self.main_loop - ) - self.thread_pool.submit_task(self.push_frame_results, frame_results) + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() + frame_id = device_frame.id + + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=round(r.confidence, 4), + location=r.location, + ) + frame_results.append(frame_result) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) + 
self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): global_config = GlobalConfig() @@ -124,31 +129,63 @@ ) self.push_ts = current_time # 更新推送时间戳 + def log_fps(self, frame_count): + self.frames_detected += frame_count + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_detected / 10.0 + self.frames_detected = 0 + logger.info(f"FPS (detect) for device {self.device.code}: {fps}") + self.fps_ts = current_time + def run(self): while not self.stream_loader.init: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None or len(frames)<=0: continue self.device_status_manager.set_status(device_id=self.device.id) - results_map = {} - annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + + result_maps = [] # 保存每个frame的result map + annotators = [] + for frame in frames: + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") + annotators.append(annotator) + for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, annotator) - results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + frames, results = handler_instance.run(frames, annotators) + + # 遍历检测结果,按帧存储 + for frame_idx, r in enumerate(results): # 遍历每帧结果 + if len(result_maps) <= frame_idx: # 初始化 frame_results_map + result_maps.append({}) + frame_results_map = result_maps[frame_idx] + # 为当前模型存储检测结果 + if model_exec.algo_model_id not in frame_results_map: + 
frame_results_map[model_exec.algo_model_id] = [] + # 添加检测结果 + frame_results_map[model_exec.algo_model_id].extend( + DetectionResult.from_dict(box) for box in r + ) # 结果处理 - self.thread_pool.submit_task(self.save_frame_results, frame, results_map) - self.display_frame_manager.add_frame(self.device.id, annotator.result()) + self.thread_pool.submit_task(self.save_frame_results, frames, result_maps) + for annotator in annotators: + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + + self.log_fps(len(frames)) + + # future = asyncio.run_coroutine_threadsafe( # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop # ) diff --git a/algo/model_manager.py b/algo/model_manager.py index 1c30867..e84676f 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,11 +19,12 @@ class ModelManager: - def __init__(self, model_service: ModelService, model_warm_up=5): + def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4): # self.db = db self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up + self.batch_size=4 async def query_model_inuse(self): algo_model_list = list(await self.model_service.get_models_in_use()) @@ -58,8 +59,11 @@ if not isinstance(imgsz, list): imgsz = [imgsz, imgsz] dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + dummy_inputs = [] + for _ in range(self.batch_size): + dummy_inputs.append(dummy_input) for i in range(self.model_warm_up): - algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + algo_model_exec.algo_model_exec.predict(source=dummy_inputs, imgsz=imgsz, verbose=False,save=False,save_txt=False) logger.info(f'warm up model {model_name} success!') logger.info(f'load model {model_name} success!') diff --git a/algo/scene_runner.py b/algo/scene_runner.py index e5c067e..5120f96 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -69,7 +69,7 @@ try: 
handle_task_name = scene.scene_handle_task handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop,scene.range_points) future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) self.device_tasks[device.id] = handler_instance logger.info(f'start thread {thread_id}, device info: {device}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 4133388..202213c 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -1,13 +1,20 @@ +from datetime import datetime, timedelta + import cv2 import time import numpy as np from threading import Thread, Event + +import queue + from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: def __init__(self, camera_url, camera_code, device_thread_id = '', + batch_size=4, + queue_size=100, retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -19,10 +26,15 @@ self.init = False self.frame = None + self.frame_queue = queue.Queue(maxsize=queue_size) + self.fps_ts = None self.cap = None + self.batch_size = batch_size + self.frames_read = 0 self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() + self.init_cap() # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() @@ -44,7 +56,14 @@ 尝试创建视频流捕获对象。 """ try: - cap = cv2.VideoCapture(self.url) + # cap = cv2.VideoCapture(self.url) + gst_pipeline = ( + f"rtspsrc location={self.url} ! " + f"rtph264depay ! h264parse ! " + f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! 
" + f"appsink" + ) + cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) # 可以在这里设置cap的一些属性,如果需要的话 return cap except Exception as e: @@ -74,9 +93,13 @@ cap = cv2.VideoCapture(self.url) # cap = cv2.VideoCapture(self.url, cv2.CAP_FFMPEG) - # 使用 GStreamer pipeline 作为 VideoCapture 的输入 - # gst_str = f"rtspsrc location={self.url} ! decodebin ! videoconvert ! appsink" - # cap = cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) + # gst_pipeline = ( + # f"rtspsrc location={self.url} ! " + # f"rtph264depay ! h264parse ! " + # f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! " + # f"appsink" + # ) + # cap = cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER) if cap.isOpened(): logger.info(f"{self.url} connected successfully!") @@ -91,6 +114,15 @@ return cap + def log_fps(self): + current_time = datetime.now() + # 每秒输出 FPS + if self.fps_ts is None or current_time - self.fps_ts >= timedelta(seconds=10): + fps = self.frames_read / 10.0 + self.frames_read = 0 + logger.info(f"FPS (read) for device {self.camera_code}: {fps}") + self.fps_ts = current_time + def update(self): vid_n = 0 log_n = 0 @@ -111,6 +143,9 @@ else: vid_n += 1 self.frame = frame + self.frames_read += 1 + if not self.frame_queue.full(): + self.frame_queue.put(frame) if log_n % 1000 == 0: logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 @@ -120,12 +155,22 @@ self.cap.release() self.frame = None self.cap = self.get_connect() # 尝试重新连接 + self.log_fps() def __iter__(self): return self def __next__(self): - return self.frame + batch_frames = [] + + queue_length = self.frame_queue.qsize() + if queue_length < self.batch_size: + return [] + + while not self.frame_queue.empty() and len(batch_frames) < self.batch_size: + frame = self.frame_queue.get() + batch_frames.append(frame) + return batch_frames def stop(self): logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') diff --git a/apis/control.py b/apis/control.py index 9c6d6fb..2d351d8 
100644 --- a/apis/control.py +++ b/apis/control.py @@ -9,6 +9,7 @@ from fastapi import APIRouter import docker import socket +import signal from apis.base import standard_error_response, standard_response from common.global_logger import logger @@ -47,6 +48,23 @@ return None +def handle_sigterm(signum, frame): + print(f"Received signal: {signum}, shutting down gracefully...") + os._exit(0) + + +signal.signal(signal.SIGTERM, handle_sigterm) + + +def shutdown(): + # 等待短暂时间以确保响应已发送 + time.sleep(1) + logger.info("Shutting down the application...") + # sys.exit(0) + # os.kill(os.getpid(), signal.SIGTERM) + os._exit(0) + + @router.get("/restart") async def restart(): try: @@ -92,17 +110,20 @@ logger.error(f"Failed to restart container asynchronously: {ex}") # 在新线程中启动重启操作 - threading.Thread(target=restart_container_async).start() + # threading.Thread(target=restart_container_async).start() + threading.Thread(target=shutdown).start() return standard_response() except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + @router.get("/sync_test") def sync_test(): return standard_response() + @router.get("/async_test") async def async_test(): return standard_response() diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index d9aa3a2..77d3677 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -22,8 +22,8 @@ @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -async def update_by_device(device_id: int, scene_id: int, +async def update_by_device(device_id: int, scene_id: int, range_points: str = None, db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = await service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id,range_points) return standard_response(data=relation) diff --git a/app_instance.py 
b/app_instance.py index 82a750d..c648f6e 100644 --- a/app_instance.py +++ b/app_instance.py @@ -1,4 +1,5 @@ import asyncio + from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from services.scene_service import SceneService from services.schedule_job import start_scheduler from tcp.tcp_manager import TcpManager +from tcp.tcp_server import start_server _app = None # 创建一个私有变量来存储 app 实例 @@ -69,6 +71,8 @@ app.state.scene_runner = scene_runner await scene_runner.start() + main_loop.create_task(start_server()) + main_loop.create_task(start_scheduler()) yield # 允许请求处理 diff --git a/common/byte_utils.py b/common/byte_utils.py index 301bb8a..ea3f4ea 100644 --- a/common/byte_utils.py +++ b/common/byte_utils.py @@ -1,3 +1,11 @@ def format_bytes(data: bytes): - return ''.join(f'\\x{byte:02X}' for byte in data) \ No newline at end of file + return ''.join(f'\\x{byte:02X}' for byte in data) + + +if __name__ == '__main__': + msg = b'\xaa\x01\x00\x93\x12\x00\xA6' + print(format_bytes(msg)) + + msg = b'\xAA\x01\x00\x95\x00\x00\x96' + print(format_bytes(msg)) \ No newline at end of file diff --git a/common/consts.py b/common/consts.py index 5a3286c..6deaedf 100644 --- a/common/consts.py +++ b/common/consts.py @@ -36,3 +36,4 @@ GAS = 1 ALGO_RESULT = 2 ALARM = 3 + HARMFUL_GAS = 4 diff --git a/common/detect_utils.py b/common/detect_utils.py new file mode 100644 index 0000000..33c340c --- /dev/null +++ b/common/detect_utils.py @@ -0,0 +1,72 @@ +import matplotlib.path as mpath + + +def is_within_alert_range(bbox, range): + if range is None: + return True + bottom_left = [bbox[0], bbox[3]] + bottom_right = [bbox[2], bbox[3]] + + in_fence = is_point_in_polygon(bottom_left, range) or is_point_in_polygon(bottom_right, range) + return in_fence + + +def is_point_in_polygon(point, polygon): + """ + 判断点是否在多边形内 + + 参数: + - point: 待判断的点,格式为(x, y) + - polygon: 多边形的顶点坐标列表,格式为[(x1, y1), (x2, y2), ...] 
+ + 返回: + - 如果点在多边形内,返回True;否则,返回False。 + """ + # 创建多边形路径对象 + # polygon = np.array(polygon, dtype=np.int32) + # + # hull = ConvexHull(polygon) + # sorted_coordinates = polygon[hull.vertices] + + poly_path = mpath.Path(polygon) + # 使用contains_point方法检查点是否在多边形内 + return poly_path.contains_point(point) + + +def intersection_area(bbox1, bbox2): + # 计算两个坐标框的重叠面积 + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + + xi1 = max(x1, x3) + yi1 = max(y1, y3) + xi2 = min(x2, x4) + yi2 = min(y2, y4) + + width = max(0, xi2 - xi1) + height = max(0, yi2 - yi1) + + return width * height + + +def bbox_area(bbox): + # 计算坐标框的面积 + x1, y1, x2, y2 = bbox + return (x2 - x1) * (y2 - y1) + + +def get_person_head(person_bbox, head_bboxes): + best_head = None + max_overlap = 0 + + for head in head_bboxes: + head_bbox = head.xyxy.cpu().squeeze() + overlap_area = intersection_area(person_bbox, head_bbox) + head_area = bbox_area(head_bbox) + overlap_ratio = overlap_area / head_area + + if overlap_ratio >= 0.8 and (best_head is None or overlap_area > max_overlap or ( + overlap_area == max_overlap and float(head.conf) > float(best_head.conf))): + best_head = head + max_overlap = overlap_area + return best_head diff --git a/common/harmful_gas_manager.py b/common/harmful_gas_manager.py new file mode 100644 index 0000000..297e02c --- /dev/null +++ b/common/harmful_gas_manager.py @@ -0,0 +1,44 @@ +import threading +from datetime import datetime + + +class HarmfulGasManager: + + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + # 确保线程安全的单例模式 + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(HarmfulGasManager, cls).__new__(cls) + return cls._instance + + def __init__(self): + # 初始化一次,避免重复初始化 + if not hasattr(self, "device_data"): + self.device_data = {} + self.lock = threading.Lock() + + def get_device_all_data(self, device_code): + """获取指定设备的在线状态""" + with self.lock: + return self.device_data.get(device_code, None) + + def 
get_device_data(self, device_code, gas_type): + """获取指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if device_data: + return device_data.get(gas_type, None) + return None + + def set_device_data(self, device_code, gas_type, gas_data): + """设置指定设备的在线状态""" + with self.lock: + device_data = self.device_data.get(device_code, None) + if not device_data: + device_data = {} + self.device_data[device_code] = device_data + device_data[gas_type] = gas_data \ No newline at end of file diff --git a/common/http_utils.py b/common/http_utils.py index 5f15707..a49c6ea 100644 --- a/common/http_utils.py +++ b/common/http_utils.py @@ -22,3 +22,11 @@ logger.info(f"Response: {response.status_code}, {response.text}") except requests.RequestException as e: logger.error(f"Failed to push data: {e}") + + +def get_request(url, headers=None): + try: + response = requests.get(url, headers=headers) + return response.json() + except requests.RequestException as e: + logger.error(f"Failed to get data: {e}") diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index d1acdd4..694b825 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/alarm_record.py b/entity/alarm_record.py new file mode 100644 index 0000000..1ebe474 --- /dev/null +++ b/entity/alarm_record.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class AlarmRecordBase(SQLModel): + # todo tree code?? alarm_category?? 
+ device_code: str + alarm_type: int + alarm_content: str + alarm_value: Optional[float] = None + alarm_image: Optional[str] = None + alarm_time: datetime = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } + + +class AlarmRecord(AlarmRecordBase, table=True): + __tablename__ = 'alarm_record' + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlarmRecordCreate(AlarmRecordBase): + pass + + +class AlarmRecordInfo(AlarmRecordBase): + id: int diff --git a/entity/device_scene_relation.py b/entity/device_scene_relation.py index c4543d7..38481ef 100644 --- a/entity/device_scene_relation.py +++ b/entity/device_scene_relation.py @@ -8,6 +8,7 @@ class DeviceSceneRelationBase(SQLModel): scene_id: int device_id: int + range_points: Optional[str] = None class DeviceSceneRelation(DeviceSceneRelationBase, TimestampMixin, table=True): diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 46e0d53..c7c3490 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -8,41 +8,48 @@ def __init__(self, model: AlgoModelExec): self.model = model self.model_names = model.algo_model_exec.names + self.model_ids = list(self.model_names.keys()) def pre_process(self, frame): return frame - def model_inference(self, frame): - results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size, + def model_inference(self, frames): + results_generator = self.model.algo_model_exec.predict(source=frames, imgsz=self.model.input_size, save_txt=False, save=False, - verbose=False, stream=True) - results = list(results_generator) # 确保生成器转换为列表 - result = results[0] - # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") - return result + verbose=True, stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) - def post_process(self, frame, 
model_result, annotator): + return result_boxes + + def post_process(self, frames, model_results, annotators): results = [] - for box in model_result.boxes: - results.append( - { - 'object_class_id': int(box.cls), - 'object_class_name': self.model_names[int(box.cls)], - 'confidence': float(box.conf), - 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) - } - ) - if annotator is not None: - for s_box in model_result.boxes: - annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", - color=colors(int(s_box.cls)), - rotated=False) + for idx,frame in enumerate(frames): + frame_boxes = model_results[idx] + annotator = annotators[idx] + frame_result = [] + for box in frame_boxes: + frame_result.append( + { + 'object_class_id': int(box.cls), + 'object_class_name': self.model_names[int(box.cls)], + 'confidence': float(box.conf), + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) + } + ) + results.append(frame_result) + if annotator is not None: + for s_box in frame_boxes: + annotator.box_label(s_box.xyxy.cpu().squeeze(), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), + rotated=False) return results - def run(self, frame, annotator): - processed_frame = self.pre_process(frame=frame) - model_result = self.model_inference(frame=processed_frame) - result = self.post_process(frame=frame, model_result=model_result, annotator=annotator) - return frame, result + def run(self, frames, annotators): + processed_frames = [self.pre_process(frame=frame) for frame in frames] + result_boxes = self.model_inference(frames=processed_frames) + results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators) + return frames, results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py index e6023d0..ec8006f 100644 --- 
a/model_handler/coco_engine_model_handler.py +++ b/model_handler/coco_engine_model_handler.py @@ -88,3 +88,4 @@ 78: 'hair_drier', 79: 'toothbrush', } + self.model_ids = list(self.model_names.keys()) diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py index d0fefba..8892fcb 100644 --- a/model_handler/labor_engine_model_handler.py +++ b/model_handler/labor_engine_model_handler.py @@ -68,3 +68,4 @@ 58: '鼓风机', } + self.model_ids = list(self.model_names.keys()) diff --git a/run.sh b/run.sh index bb36274..815883a 100644 --- a/run.sh +++ b/run.sh @@ -8,4 +8,5 @@ export LC_ALL=C.UTF-8 cd /code/safe-algo-pro -python3 main.py +exec python3 main.py +#tail -f /dev/null \ No newline at end of file diff --git a/scene_handler/alarm_message_center.py b/scene_handler/alarm_message_center.py new file mode 100644 index 0000000..fba1397 --- /dev/null +++ b/scene_handler/alarm_message_center.py @@ -0,0 +1,121 @@ +import asyncio +import copy +import time +from collections import deque +from threading import Thread, Lock + +''' +队列消息取出规则: +- 按 alarmCategory(优先级:2 > 1 > 3 > 0)和 category_order 从小到大排序。 +- 确保每个 alarmCategory 在 10 秒内只能发送一次。 + +消息定时清理: +- 每次取消息时会检查队列中长期堆积、超过一定时间(比如 30 秒)的消息,并清理这些消息。 +''' + + +class AlarmMessageCenter: + def __init__(self, device_id, main_loop, tcp_manager=None, message_send_interval=5, category_interval=10, retention_time=30, + category_priority=None): + self.device_id = device_id + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.queue = deque() + self.last_sent_time = {} # 记录每个 alarmCategory 的最后发送时间 + self.lock = Lock() + self.message_send_interval = message_send_interval # 消息发送间隔(秒) + self.category_interval = category_interval # 类别发送间隔(秒) + self.retention_time = retention_time # 消息最长保留时间(秒) + if category_priority: + self.category_priority = category_priority + self.auto_update_alarm_priority = False + else: + self.category_priority = category_priority + self.auto_update_alarm_priority = 
True + + def add_message(self, message_ori): + message = copy.deepcopy(message_ori) + message['timestamp'] = int(time.time()) # 添加消息放入队列的时间 + with self.lock: + self.queue.append(message) + + # 动态更新优先级映射 + if self.auto_update_alarm_priority: + alarm_category = message['alarmCategory'] + if alarm_category not in self.category_priority: + unique_categories = sorted({msg['alarmCategory'] for msg in self.queue}) + self.category_priority = {cat: idx for idx, cat in enumerate(unique_categories)} + print(f"更新优先级映射: {self.category_priority}") + + def _clean_old_messages(self): + """清理长期堆积的消息""" + print(f'清理长期堆积的消息 (队列长度: {len(self.queue)})') + now = time.time() + with self.lock: + self.queue = deque([msg for msg in self.queue if now - msg['timestamp'] <= self.retention_time]) + print(f'清理后的队列长度: {len(self.queue)}') + + # def _get_next_message(self): + # """按优先级和时间规则取出下一条消息""" + # now = time.time() + # with self.lock: + # # 按照规则排序队列 + # sorted_queue = sorted( + # self.queue, + # key=lambda x: ( + # self.alarm_priority.get(x['alarmCategory'], 4), # 按 alarmCategory 优先级排序 + # x['category_order'], # category_order 小的排前面 + # x['timestamp'] # 时间靠前的排前面 + # ) + # ) + # # 遍历排序后的队列,找到符合规则的消息 + # for msg in sorted_queue: + # alarm_category = msg['alarmCategory'] + # if alarm_category not in self.last_sent_time or now - self.last_sent_time[alarm_category] > 10: + # self.queue.remove(msg) + # self.last_sent_time[alarm_category] = now + # return msg + # return None + + def _get_next_message(self): + """按优先级和时间规则取出下一条消息""" + now = time.time() + with self.lock: + # 按优先级依次检查 + for priority in sorted(self.category_priority.values()): + found_valid_message = False # 用于标记当前优先级是否有消息被处理 + for msg in sorted( + (m for m in self.queue if self.category_priority.get(m['alarmCategory'], 99) == priority), + key=lambda x: (-x['timestamp'], x['category_order']) + ): + alarm_category = msg['alarmCategory'] + # 检查是否符合发送条件 + if alarm_category not in self.last_sent_time or now - self.last_sent_time[ + 
alarm_category] > self.category_interval: + self.queue.remove(msg) + self.last_sent_time[alarm_category] = now + return msg # 找到符合条件的消息立即返回 + found_valid_message = True # 当前优先级存在消息但不符合条件 + # 如果当前优先级的所有消息都被检查过且不符合条件,跳到下一个优先级 + if not found_valid_message: + continue + return None # 如果所有优先级都没有消息符合条件,则返回 None + + def process_messages(self): + while True: + time.sleep(self.message_send_interval) + self._clean_old_messages() # 清理长期堆积的消息 + next_message = self._get_next_message() + if next_message: + self.send_message(next_message) + + def send_message(self, message): + """发送报警消息""" + print(f"发送报警消息: {message['alarmContent']} (类别: {message['alarmCategory']}, 时间: {message['timestamp']})") + if self.tcp_manager: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device_id, + message=message['alarmSoundMessage'], + have_response=False), + self.main_loop) diff --git a/scene_handler/alarm_record_center.py b/scene_handler/alarm_record_center.py new file mode 100644 index 0000000..5373cd5 --- /dev/null +++ b/scene_handler/alarm_record_center.py @@ -0,0 +1,128 @@ +import asyncio +import base64 +import time +from datetime import datetime + +import cv2 + +from common.global_thread_pool import GlobalThreadPool +from common.http_utils import send_request +from db.database import get_db +from entity.alarm_record import AlarmRecordCreate +from services.alarm_record_service import AlarmRecordService +from services.global_config import GlobalConfig + + +def image_to_base64(numpy_image, format='jpg'): + # 将NumPy数组转换为图片格式的字节流 + # format指定图片格式,'png'或'jpeg' + success, encoded_image = cv2.imencode(f'.{format}', numpy_image) + if not success: + raise ValueError("Could not encode image") + + # 将字节流转换为Base64编码 + base64_encoded_image = base64.b64encode(encoded_image) + + # 将bytes类型转换为UTF-8字符串 + base64_message = base64_encoded_image.decode('utf-8') + + return base64_message + # return f'data:image/{format};base64,{base64_message}' + + +class 
AlarmRecordCenter: + def __init__(self, save_interval=-1, main_loop=None): + """ + 初始化报警记录中心 + :param upload_interval: 报警上传间隔(秒),如果 <= 0,则不报警 + """ + self.main_loop = main_loop + self.save_interval = save_interval + self.thread_pool = GlobalThreadPool() + self.global_config = GlobalConfig() + # self.upload_interval = upload_interval + self.device_alarm_upload_time = {} # key: device_code, value: {alarm_type: last upload time} + self.device_alarm_save_time = {} + + def need_alarm(self, device_code, alarm_dict): + """ + 是否需要报警 + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :return: 是否需要报警 + """ + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + return self.need_save_alarm_record(device_code, current_time, alarm_type) \ + or self.need_upload_alarm_record(device_code, current_time, alarm_type) + + def need_save_alarm_record(self, device_code, current_time, alarm_type): + if self.save_interval <= 0: + return False + + if device_code not in self.device_alarm_save_time: + self.device_alarm_save_time[device_code] = {} + last_save_time = self.device_alarm_save_time[device_code].get(alarm_type) + if last_save_time is None or (current_time - last_save_time) > self.save_interval: + return True + return False + + def need_upload_alarm_record(self, device_code, current_time, alarm_type): + push_config = self.global_config.get_alarm_push_config() + if not push_config or not push_config.push_url: + return False + + if device_code not in self.device_alarm_upload_time: + self.device_alarm_upload_time[device_code] = {} + last_upload_time = self.device_alarm_upload_time[device_code].get(alarm_type) + if last_upload_time is None or (current_time - last_upload_time) > push_config.upload_interval: + return True + return False + + async def save_record(self, alarm_data, alarm_np_img): + async for db in get_db(): + alarm_record_service = AlarmRecordService(db) + await alarm_record_service.add_alarm(alarm_data=alarm_data, alarm_np_img=alarm_np_img) + + def 
upload_alarm_record(self, device_code, alarm_dict, alarm_np_img=None, alarm_value=None): + """ + 上传报警记录 + :param alarm_value: + :param device_code: 设备编号 + :param alarm_dict: 报警类型字典 + :param alarm_np_img: 报警图片(np array) + """ + # 获取当前时间 + current_time = time.time() + alarm_type = alarm_dict['alarmType'] + + if self.need_save_alarm_record(device_code, current_time, alarm_type): + alarm_record_data = AlarmRecordCreate( + device_code=device_code, + device_id=100, + alarm_type=alarm_type, + alarm_content=alarm_dict['alarmContent'], + alarm_time=datetime.now(), + alarm_value=alarm_value if alarm_value else None, + ) + asyncio.run_coroutine_threadsafe( + self.save_record(alarm_record_data, alarm_np_img), + self.main_loop) + self.device_alarm_save_time[device_code][alarm_type] = current_time + + if self.need_upload_alarm_record(device_code, current_time, alarm_type): + alarm_record = { + 'devcode': device_code, + 'alarmTime': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'alarmType': alarm_dict['alarmType'], + 'alarmContent': alarm_dict['alarmContent'], + } + if alarm_value: + alarm_record['alarmValue'] = alarm_value + if alarm_np_img: + alarm_record['alarmImage'] = image_to_base64(alarm_np_img) + + push_config = self.global_config.get_alarm_push_config() + self.thread_pool.submit_task(send_request, push_config.upload_url, alarm_record) + self.device_alarm_upload_time[device_code][alarm_type] = current_time diff --git a/scene_handler/block_scene_handler.py b/scene_handler/block_scene_handler.py new file mode 100644 index 0000000..3dd0b31 --- /dev/null +++ b/scene_handler/block_scene_handler.py @@ -0,0 +1,426 @@ +import time +import traceback +from asyncio import Event +from copy import deepcopy +from datetime import datetime + + +from flatbuffers.builder import np +from scipy.spatial import ConvexHull + +from algo.stream_loader import OpenCVStreamLoad +from common.detect_utils import is_within_alert_range, get_person_head, intersection_area, bbox_area +from 
common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.harmful_gas_manager import HarmfulGasManager +from common.image_plotting import Annotator +from entity.device import Device +from scene_handler.alarm_message_center import AlarmMessageCenter +from scene_handler.alarm_record_center import AlarmRecordCenter +from scene_handler.base_scene_handler import BaseSceneHandler +from scene_handler.limit_space_scene_handler import is_overlapping +from services.global_config import GlobalConfig +from tcp.tcp_manager import TcpManager + +from entity.device import Device +from common.http_utils import get_request +from ultralytics import YOLO + +''' +alarmCategory: +0 行为监管 +1 环境监管 +2 人员监管 +3 围栏监管 + +handelType: +0 检测到报警 +1 未检测到报警 +2 人未穿戴报警 +3 其他 +''' +ALARM_DICT = [ + { + 'alarmCategory': 0, + 'alarmType': '1', + 'handelType': 1, + 'category_order': 1, + 'class_idx': [34], + 'alarm_name': 'no_fire_extinguisher', + 'alarmContent': '未检测到灭火器', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '2', + 'handelType': 1, + 'category_order': 2, + 'class_idx': [43], + 'alarm_name': 'no_barrier_tape', + 'alarmContent': '未检测到警戒线', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '3', + 'handelType': 1, + 'category_order': 3, + 'class_idx': [48], + 'alarm_name': 'no_cone', + 'alarmContent': '未检测到锥桶', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '4', + 'handelType': 1, + 'category_order': 4, + 'class_idx': [4, 5, 16], + 'alarm_name': 'no_board', + 'alarmContent': '未检测到指示牌', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 0, + 'alarmType': '5', + 'handelType': 2, + 'category_order': -1, + 'class_idx': [18], + 'alarm_name': 
'no_helmet', + 'alarmContent': '未佩戴安全帽', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '未佩戴安全帽', + }, + # todo 明火 + { + 'alarmCategory': 1, + 'alarmType': '7', + 'handelType': 3, + 'category_order': 1, + 'class_idx': [], + 'alarm_name': 'gas_alarm', + 'alarmContent': '甲烷浓度超限', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 1, + 'alarmType': '8', + 'handelType': 3, + 'category_order': 2, + 'class_idx': [], + 'alarm_name': 'harmful_alarm', + 'alarmContent': '有害气体浓度超标', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 2, + 'alarmType': '9', + 'handelType': 3, + 'category_order': -1, + 'class_idx': [], + 'alarm_name': 'health_alarm', + 'alarmContent': '心率血氧异常', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '', + }, + { + 'alarmCategory': 3, + 'alarmType': '10', + 'handelType': 2, + 'category_order': 4, + 'class_idx': [24], + 'alarm_name': 'break_in_alarm', + 'alarmContent': '非法闯入', + 'alarmSoundMessage': b'\xaa\x01\x00\x93\x12\x00\xA6', + 'label': '非法闯入', + }, + +] + +COLOR_RED = (0, 0, 255) +COLOR_BLUE = (255, 0, 0) + + +class BlockSceneHandler(BaseSceneHandler): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + self.__stop_event = Event(loop=main_loop) + self.health_ts_dict = {} + self.harmful_ts_dict = {} + self.object_ts_dict = {} + self.thread_pool = GlobalThreadPool() + + self.alarm_message_center = AlarmMessageCenter(device.id,main_loop=main_loop, tcp_manager=tcp_manager, + category_priority={2: 0, 1: 1, 3: 2, 0: 3}) + self.alarm_record_center = AlarmRecordCenter(save_interval=device.alarm_interval,main_loop=main_loop) + self.harmful_data_manager = HarmfulGasManager() + self.device_status_manager = DeviceStatusManager() + + + self.health_device_codes = 
['HWIH061000056395'] # todo + self.harmful_device_codes = [] # todo + + for helmet_code in self.health_device_codes: + self.thread_pool.submit_task(self.health_data_task, helmet_code) + for harmful_device_code in self.harmful_device_codes: + self.thread_pool.submit_task(self.harmful_data_task, harmful_device_code) + + self.thread_pool.submit_task(self.alarm_message_center.process_messages) + + # todo 明火 + self.model = YOLO('weights/labor-v8-20241114.pt') + self.model_classes = { + 0: '三脚架', + 3: '人', + 4: '作业信息公示牌', + 6: '危险告知牌', + 9: '反光衣', + 11: '呼吸面罩', + 13: '四合一', + 15: '头', + 16: '安全告知牌', + 18: '安全帽', + 20: '安全标识牌', + 24: '工服', + 34: '灭火器', + 43: '警戒线', + 48: '路锥', + 58: '鼓风机', + } + self.PERSON_CLASS_IDX = 3 + self.HEAD_CLASS_IDX = 15 + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) + self.range_points = range_points + self.abs_range_points = self.get_absolute_range() + + self.tracking_status = {} # 跟踪每个行人的状态 + self.max_missing_frames = 25 # 报警的阈值 + self.disappear_threshold = 25 * 3 # 移除行人的阈值 + + def get_absolute_range(self): + fence_info = eval(self.range_points) + if fence_info and len(fence_info) > 1: + abs_points = [] + for p in fence_info: + abs_points.append( + [int(p[0] * int(self.stream_loader.frame_width)), int(p[1] * int(self.stream_loader.frame_height))]) + + abs_points = np.array(abs_points, dtype=np.int32) + hull = ConvexHull(abs_points) + sorted_coordinates = abs_points[hull.vertices] + # abs_points = abs_points.reshape((-1, 1, 2)) + return sorted_coordinates + else: + return None + + def harmful_data_task(self, harmful_device_code): + while not self.__stop_event.is_set(): + harmful_gas_data = self.harmful_data_manager.get_device_all_data(harmful_device_code) + for gas_type, gas_data in harmful_gas_data.items(): + ts_key = f'{harmful_device_code}_{gas_type}' + last_ts = self.harmful_ts_dict.get(ts_key) + gas_ts = gas_data.get('gas_ts') + if last_ts is None or 
(gas_ts - last_ts).total_seconds() > 0: + self.harmful_ts_dict[ts_key] = gas_ts + self.handle_harmful_gas_alarm(harmful_device_code, gas_type, gas_data) + + def health_data_task(self, helmet_code): + while not self.__stop_event.is_set(): + header = { + 'ak': 'fe80b2f021644b1b8c77fda743a83670', + 'sk': '8771ea6e931d4db646a26f67bcb89909', + } + url = f'https://jls.huaweisoft.com//api/ih-log/v1.0/ih-api/helmetInfo/{helmet_code}' + response = get_request(url, headers=header) + if response and response.get('data'): + last_ts = self.health_ts_dict.get(helmet_code) + vitalsigns_data = response.get('data').get('vitalSignsData') + if vitalsigns_data: + upload_timestamp = datetime.strptime(vitalsigns_data.get('uploadTimestamp'), "%Y-%m-%d %H:%M:%S") + if last_ts is None or (upload_timestamp.timestamp() - last_ts) > 0: + self.health_ts_dict[helmet_code] = upload_timestamp.timestamp() + if time.time() - upload_timestamp.timestamp() < 10 * 60: # 10分钟以前的数据不做处理 + self.handle_health_alarm(helmet_code, vitalsigns_data.get('bloodOxygen'), + vitalsigns_data.get('heartRate'),upload_timestamp) + time.sleep(10) + + def handle_health_alarm(self, helmet_code, blood_oxygen, heartrate, upload_timestamp): + logger.debug(f'health_data: {helmet_code}, blood_oxygen = {blood_oxygen}, heartrate = {heartrate}, ' + f'upload_timestamp = {upload_timestamp}') + if heartrate < 60 or heartrate > 120 or blood_oxygen < 85: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 2] + if alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 需要往后台发原始数据吗 + + def handle_harmful_gas_alarm(self, device_code, gas_type, gas_data): + alarm = False + gas_value = gas_data['gas_value'] + if gas_type == 3: # h2s + alarm = gas_value > 120.0 + elif gas_type == 4: # co + alarm = gas_value > 10.0 + elif gas_type == 5: # o2 + alarm = gas_value < 15 + elif gas_type == 50: # ex + alarm = gas_value > 10 + + if alarm: + alarm_dict = [d for d in ALARM_DICT if d['alarmCategory'] == 1] + if 
alarm_dict: + self.alarm_message_center.add_message(alarm_dict[0]) + # todo 需要生成报警记录吗 + + def model_predict(self, frames): + results_generator = self.model(frames, save_txt=False, save=False, verbose=False, conf=0.5, + classes=list(self.model_classes.keys()), + imgsz=640, + stream=True) + result_boxes = [] + for r in results_generator: + result_boxes.append(r.boxes) + + pred_ids = [[int(box.cls) for box in sublist] for sublist in result_boxes] + pred_names = [[self.model_classes[int(box.cls)] for box in sublist] for sublist in result_boxes] + return result_boxes, pred_ids, pred_names + + def handle_behave_alarm(self, frames, result_boxes, pred_ids, pred_names): + behave_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 0] + for alarm_dict in behave_alarm_dicts: + for idx, frame_boxes in enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + if alarm_dict['handelType'] == 0: # 检测到就报警 + if object_boxes: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + box_color = COLOR_RED if int(box.cls) in alarm_dict['class_idx'] else COLOR_BLUE + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=box_color, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + elif alarm_dict['handelType'] == 1: # 检测不到报警 + if object_boxes: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + else: + last_ts = self.object_ts_dict.get(alarm_dict['alarm_name'], 0) + if time.time() - last_ts > 5: + self.object_ts_dict[alarm_dict['alarm_name']] = time.time() + self.alarm_message_center.add_message(alarm_dict) + if 
self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) + for box in frame_boxes: + if int(box.cls) != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + elif alarm_dict['handelType'] == 2: # 人未穿戴报警 + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + # 检查这个人是否佩戴了安全帽 + has_helmet = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + has_helmet = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), helmet.xyxy.cpu().squeeze()) + for helmet in object_boxes) + if not has_helmet: + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + for box in frame_boxes: + box_cls = box.cls + if box_cls != self.PERSON_CLASS_IDX and box_cls != self.HEAD_CLASS_IDX: + annotator.box_label(box.xyxy.cpu().squeeze(), + f"{self.model_classes[int(box.cls)]}", + color=COLOR_BLUE, + rotated=False) + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def handle_break_in_alarm(self, frames, result_boxes, pred_ids, pred_names): + break_in_alarm_dicts = [d for d in ALARM_DICT if d['alarmCategory'] == 3] + for alarm_dict in break_in_alarm_dicts: + for idx, frame_boxes in 
enumerate(result_boxes): + frame = frames[idx] + frame_ids, frame_names = pred_ids[idx], pred_names[idx] + person_boxes = [box for box in frame_boxes if int(box.cls) == self.PERSON_CLASS_IDX] + head_boxes = [box for box in frame_boxes if int(box.cls) == self.HEAD_CLASS_IDX] + object_boxes = [box for box in frame_boxes if int(box.cls) in alarm_dict['class_idx']] + has_alarm = False + annotator = None + for person_box in person_boxes: + person_bbox = person_box.xyxy.cpu().squeeze() + person_id = person_box.id + if is_within_alert_range(person_bbox, self.abs_range_points): + has_object = True + person_head = get_person_head(person_bbox, head_boxes) + if person_head is not None: + overlap_ratio = intersection_area(person_bbox, person_head.xyxy.cpu().squeeze()) / bbox_area(person_bbox) + if overlap_ratio < 0.5: # 头占人<0.5,判断是否穿工服。不太准确 + has_object = any( + is_overlapping(person_head.xyxy.cpu().squeeze(), object_boxe.xyxy.cpu().squeeze()) + for object_boxe in object_boxes) + if not has_object: + self.tracking_status[person_box.id] = self.tracking_status.get(person_box.id, 0) + 1 + + has_alarm = True + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + annotator = Annotator(deepcopy(frame)) if annotator is None else annotator + annotator.box_label(person_bbox, alarm_dict['label'], color=COLOR_RED, rotated=False) + + if has_alarm: + self.alarm_message_center.add_message(alarm_dict) + if self.alarm_record_center.need_alarm(self.device.code, alarm_dict): + self.alarm_record_center.upload_alarm_record(self.device.code, alarm_dict, + annotator.result()) + + def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() + for frames in self.stream_loader: + try: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + if not frames: + continue + + self.device_status_manager.set_status(device_id=self.device.id) + result_boxes, pred_ids, pred_names = self.model_predict(frames) 
# 结果都是二维数组,对应batch中的每个frame + # print(pred_names) + self.handle_behave_alarm(frames, result_boxes, pred_ids, pred_names) + self.handle_break_in_alarm(frames, result_boxes, pred_ids, pred_names) + + except Exception as ex: + traceback.print_exc() + logger.error(ex) diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py index d1c5b9a..a7f1674 100644 --- a/scene_handler/limit_space_scene_handler.py +++ b/scene_handler/limit_space_scene_handler.py @@ -145,7 +145,7 @@ class LimitSpaceSceneHandler(BaseSceneHandler): - def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop): + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points): super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) # self.device = device # self.thread_id = thread_id @@ -287,20 +287,20 @@ if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 self.stream_loader.init_cap() - for frame in self.stream_loader: + for frames in self.stream_loader: try: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 # print('frame') - if frame is None: + if frames is None: continue self.device_status_manager.set_status(device_id=self.device.id) - result_boxes, pred_ids, pred_names = self.model_predict(frame) + # result_boxes, pred_ids, pred_names = self.model_predict(frames) frame_alarm = {} - self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) - self.process_labor(frame, result_boxes, pred_ids, pred_names) + # self.process_alarm(frame, result_boxes, pred_ids, pred_names, frame_alarm) + # self.process_labor(frame, result_boxes, pred_ids, pred_names) if len(frame_alarm.keys()) > 0: for key in frame_alarm.keys(): diff --git a/services/alarm_record_service.py b/services/alarm_record_service.py new file mode 100644 index 0000000..a0a86f9 --- /dev/null +++ b/services/alarm_record_service.py @@ -0,0 +1,46 @@ +import os +import uuid +from 
datetime import datetime + +import aiofiles +import cv2 +from sqlalchemy.ext.asyncio import AsyncSession + +from entity.alarm_record import AlarmRecordCreate, AlarmRecord + + +class AlarmRecordService: + def __init__(self, db: AsyncSession): + self.db = db + + async def add_alarm(self, alarm_data:AlarmRecordCreate, alarm_np_img): + async def save_alarm_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/alarms', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', alarm_np_img) + image_data = encoded_image.tobytes() + + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + alarm_img_path = await save_alarm_file() + + # 创建并保存到数据库中 + alarm_record = AlarmRecord.model_validate(alarm_data) + alarm_record.alarm_image = alarm_img_path + self.db.add(alarm_record) + await self.db.commit() + await self.db.refresh(alarm_record) + return alarm_record diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 3416831..c94c540 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -39,18 +39,19 @@ if result_row: relation, scene = result_row scene_info = DeviceSceneRelationInfo( - id=relation.id, - device_id=relation.device_id, - scene_id=relation.scene_id, - scene_name=scene.name, - scene_version=scene.version, - scene_handle_task=scene.handle_task, - scene_remark=scene.remark, - ) + id=relation.id, + device_id=relation.device_id, + scene_id=relation.scene_id, + scene_name=scene.name, + scene_version=scene.version, + scene_handle_task=scene.handle_task, + scene_remark=scene.remark, + range_points=relation.range_points + ) 
return scene_info - async def add_relation_by_device(self, device_id: int, scene_id: int): - new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) + async def add_relation_by_device(self, device_id: int, scene_id: int, range_points: str = None): + new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id, range_points=range_points) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) @@ -64,8 +65,8 @@ await self.db.commit() return result.rowcount - async def update_relation_by_device(self, device_id: int, scene_id: int): + async def update_relation_by_device(self, device_id: int, scene_id: int, range_points: str = None): await self.delete_relation_by_device(device_id) - new_relation = await self.add_relation_by_device(device_id, scene_id) + new_relation = await self.add_relation_by_device(device_id, scene_id, range_points) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/global_config.py b/services/global_config.py index fa53b2b..8baad1e 100644 --- a/services/global_config.py +++ b/services/global_config.py @@ -23,6 +23,7 @@ self.gas_push_config = None self.algo_result_push_config = None self.alarm_push_config = None + self.harmful_gas_push_config = None self._init_done = False async def _initialize(self): @@ -39,6 +40,7 @@ self.gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.GAS) self.algo_result_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT) self.alarm_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALARM) + self.harmful_gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.HARMFUL_GAS) self._init_done = True async def on_config_change(self, config: PushConfig): @@ -48,6 +50,8 @@ await self.set_algo_result_push_config(config) elif config.push_type == PUSH_TYPE.ALARM: await self.set_alarm_push_config(config) + elif 
config.push_type == PUSH_TYPE.HARMFUL_GAS: + await self.set_harmful_gas_push_config(config) def get_gas_push_config(self): """获取 gas_push_config 配置""" @@ -78,3 +82,14 @@ if config: async with self._lock: self.alarm_push_config = config + + def get_harmful_gas_push_config(self): + """获取 algo_result_push_config 配置""" + return self.harmful_gas_push_config + + async def set_harmful_gas_push_config(self, config): + """设置 algo_result_push_config 配置""" + if config: + async with self._lock: + self.harmful_gas_push_config = config + diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py index 29d2966..b3b8169 100644 --- a/tcp/tcp_client_connector.py +++ b/tcp/tcp_client_connector.py @@ -68,14 +68,16 @@ self.timeout = timeout # 连接/发送超时时间 self.is_connected = False # 连接状态标志 self.is_reconnecting = False + self.tasks_started = False self.message_queue = asyncio.Queue() #deque() self.gas_task = None self.read_lock = asyncio.Lock() # 添加锁 + self.lock = asyncio.Lock() self.push_ts_dict = {} async def connect(self): """连接到设备""" - while not self.is_connected: + while True: try: logger.info(f"正在连接到 {self.ip}:{self.port}...") # 设置连接超时 @@ -85,11 +87,12 @@ self.is_connected = True logger.info(f"已连接到 {self.ip}:{self.port}") - if self.gas_task is None: - self.gas_task = asyncio.create_task(self.process_message_queue()) # Start processing message queue + if not self.tasks_started: + asyncio.create_task(self.process_message_queue()) # Start processing message queue + asyncio.create_task(self.start_gas_query()) + self.tasks_started = True + break - # 一旦连接成功,开始发送查询指令 - await self.start_gas_query() except (asyncio.TimeoutError, ConnectionRefusedError, OSError) as e: logger.error(f"连接到 {self.ip}:{self.port} 失败,错误: {e}") logger.info(f"{self.reconnect_interval} 秒后将重连到 {self.ip}:{self.port}") @@ -102,31 +105,40 @@ return self.is_reconnecting = True await self.disconnect() # 先断开现有连接 - logger.info(f"Reconnecting to {self.ip}:{self.port} after {self.reconnect_interval} seconds") + # 
logger.info(f"Reconnecting to {self.ip}:{self.port} after {self.reconnect_interval} seconds") # await asyncio.sleep(self.reconnect_interval) # 等待n秒后重连 await self.connect() self.is_reconnecting = False async def disconnect(self): """断开设备连接,清理资源""" - if self.writer: - logger.info(f"Disconnecting from {self.ip}:{self.port}...") - try: - self.writer.close() - await self.writer.wait_closed() - except Exception as e: - logger.error(f"Error while disconnecting: {e}") - finally: - self.reader = None - self.writer = None - self.is_connected = False # 设置连接状态为 False - logger.info(f"Disconnected from {self.ip}:{self.port}") + # if self.writer: + # logger.info(f"Disconnecting from {self.ip}:{self.port}...") + # try: + # self.writer.close() + # await self.writer.wait_closed() + # except Exception as e: + # logger.error(f"Error while disconnecting: {e}") + # finally: + # self.reader = None + # self.writer = None + # self.is_connected = False # 设置连接状态为 False + # logger.info(f"Disconnected from {self.ip}:{self.port}") + async with self.lock: + if self.writer: + try: + self.writer.close() + await self.writer.wait_closed() + except (ConnectionResetError, BrokenPipeError) as e: + logger.exception(f"Error during disconnection") + self.reader = self.writer = None + logger.info(f"Disconnected from {self.ip}:{self.port}") async def start_gas_query(self): """启动甲烷查询指令,每n秒发送一次""" try: logger.info(f"Start querying gas from {self.ip}...") - while self.is_connected: + while True: await self.send_message(TREE_COMMAND.GAS_QUERY, have_response=True) await asyncio.sleep(self.query_interval) except (ConnectionResetError, asyncio.IncompleteReadError) as e: @@ -167,32 +179,33 @@ async def send_message(self, message: bytes, have_response=True): """Add a message to the queue for sending""" - self.message_queue.append((message, have_response)) + await self.message_queue.put((message, have_response)) logger.info(f"Message enqueued for {self.ip}:{self.port} {format_bytes(message)}") async def 
process_message_queue(self): """Process messages in the queue, retrying on failures""" - while self.is_connected: + while True: if self.message_queue: - message, have_response = self.message_queue.popleft() + message, have_response = await self.message_queue.get() await self._send_message_with_retry(message, have_response) else: await asyncio.sleep(1) # Small delay to prevent busy-waiting async def _send_message_with_retry(self, message: bytes, have_response): """Send a message with retries on failure""" - retry_attempts = 3 # Maximum retry attempts - for _ in range(retry_attempts): - if not self.is_connected: - await self.reconnect() - if not self.is_connected: - logger.error("Reconnection failed") - continue # Skip this attempt if reconnection fails + # retry_attempts = 3 # Maximum retry attempts + # for _ in range(retry_attempts): + # if not self.is_connected: + # await self.reconnect() + # if not self.is_connected: + # logger.error("Reconnection failed") + # continue # Skip this attempt if reconnection fails - try: - if self.writer is None or self.writer.is_closing(): - raise ConnectionResetError("No active connection or writer is closing") + try: + if self.writer is None or self.writer.is_closing(): + raise ConnectionResetError("No active connection or writer is closing") + async with self.lock: self.writer.write(message) await self.writer.drain() logger.info(f"Sent message to {self.ip}:{self.port}: {message}") @@ -203,13 +216,22 @@ await self.parse_response(data) return # Exit loop on success - except (asyncio.TimeoutError, ConnectionResetError, asyncio.IncompleteReadError, RuntimeError, - BrokenPipeError, OSError, EOFError, ConnectionAbortedError, ConnectionRefusedError) as e: - logger.exception("Failed to send message") - self.is_connected = False # Mark connection as disconnected - await self.reconnect() + except (asyncio.TimeoutError, ConnectionResetError, asyncio.IncompleteReadError, RuntimeError, + BrokenPipeError, OSError, EOFError, 
ConnectionAbortedError, ConnectionRefusedError) as e: + logger.exception("Failed to send message") + # self.is_connected = False # Mark connection as disconnected + await self.requeue_data(message, have_response) + await self.reconnect() + except Exception as e: + logger.exception(f"Unexpected error. Reconnecting and requeueing data...") + await self.requeue_data(message, have_response) + await self.reconnect() - logger.error("Max retry attempts reached, message sending failed") + # logger.error("Max retry attempts reached, message sending failed") + + async def requeue_data(self, data, have_response): + """Requeue the data that couldn't be sent to avoid data loss.""" + await self.send_message(data, have_response) # async def send_message(self, message: bytes, have_response=True): # """发送自定义消息的接口,供其他类调用""" diff --git a/tcp/tcp_server.py b/tcp/tcp_server.py new file mode 100644 index 0000000..366949d --- /dev/null +++ b/tcp/tcp_server.py @@ -0,0 +1,167 @@ +import asyncio +from datetime import datetime +import re +import json + +from common.harmful_gas_manager import HarmfulGasManager +from common.http_utils import send_request_async +from services.global_config import GlobalConfig + +HOST = '0.0.0.0' +PORT = 12345 + +harmful_gas_manager = HarmfulGasManager() +push_ts_dict = {} + +gas_units = { + 0: "%LEL", + 1: "%VOL", + 2: "PPM", + 3: "umol/mol", + 4: "mg/m3", + 5: "ug/m3", + 6: "℃", + 7: "%" +} + +decimals = { + 0: "没有小数点", + 1: "有一位小数", + 2: "有两位小数", + 3: "有三位小数" +} + +gas_statuses = { + 0: "预热", + 1: "正常", + 3: "传感器故障", + 5: "低限报警", + 6: "高限报警" +} + +gas_types = { + 3: "硫化氢 (H2S)", + 4: "一氧化碳 (CO)", + 5: "氧气 (O2)", + 50: "可燃气体 (Ex)", +} + + +def handle_precision(gas_value, gas_dec): + if gas_dec == 0: + return gas_value + elif gas_dec == 1: + return gas_value / 10 + elif gas_dec == 2: + return gas_value / 100 + elif gas_dec == 3: + return gas_value / 1000 + else: + return gas_value + + +# 示例数据解析函数 +def parse_sensor_data(device_code, sensor_data): + for data in 
sensor_data: + flag = data.get("flag") + gas_value = data.get("gas_value") + gas_dec = data.get("gas_dec") + gas_status = data.get("gas_status") + gas_type = data.get("gas_type") + gas_unit = data.get("gas_unit") + + # 获取单位、精度、状态和气体类型的描述 + unit = gas_units.get(gas_unit, "未知单位") + precision = decimals.get(gas_dec, "未知精度") + status = gas_statuses.get(gas_status, "未知状态") + gas_type_name = gas_types.get(gas_type, "未知气体") + + # 格式化气体浓度(根据精度进行转换) + gas_value = handle_precision(gas_value, gas_dec) + + gas_data = { + "flag": flag, + "gas_value": gas_value, + "gas_unit": unit, + "gas_status": status, + "gas_type": gas_type_name, + "precision": precision, + 'gas_ts': datetime.now() + } + + harmful_gas_manager.set_device_data(device_code, gas_type, gas_data) + print(harmful_gas_manager.get_device_all_data(device_code)) + + +async def data_push(device_code, message): + global_config = GlobalConfig() + harmful_push_config = global_config.get_gas_push_config() + if harmful_push_config and harmful_push_config.push_url: + last_ts = push_ts_dict.get(device_code) + current_time = datetime.now() + + # 检查是否需要推送数据 + if last_ts is None or (current_time - last_ts).total_seconds() > harmful_push_config.push_interval: + asyncio.create_task(send_request_async(harmful_push_config.push_url, message)) + push_ts_dict[device_code] = current_time # 更新推送时间戳 + else: + print('no harmful push config') + + +async def handle_message(message): + message = message.replace('\r\n', '\n').replace('\r', '\n') + + # 处理消息 + harmful_gas_pattern = r"^([A-Za-z0-9]+)\{(\"sensorDatas\":\[(.*?)\])\}$" + match = re.match(harmful_gas_pattern, message, re.DOTALL) + if match: + device_code = match.group(1) # 设备号 + print(f"设备号: {device_code}") + sensor_data_str = "{" + match.group(2) + "}" # JSON数组部分 + sensor_data_json = json.loads(sensor_data_str) + + if sensor_data_json.get('sensorDatas'): + sensor_data = sensor_data_json.get('sensorDatas') + parse_sensor_data(device_code, sensor_data) + + await data_push(device_code, 
message) + else: + print("无法解析消息") + + +# 处理客户端连接 +async def handle_client(reader, writer): + client_address = writer.get_extra_info('peername') + print(f"新连接: {client_address}") + + try: + while True: + # 接收数据 + data = await reader.read(1024) + if not data: + print(f"连接关闭: {client_address}") + break + + message = data.decode('utf-8') + print(f"收到数据({client_address}): {message}") + + await handle_message(message) + + except ConnectionResetError: + print(f"客户端断开: {client_address}") + finally: + writer.close() + await writer.wait_closed() + + +# 主服务器函数 +async def start_server(): + server = await asyncio.start_server(handle_client, HOST, PORT) + print(f"服务器启动,监听地址: {HOST}:{PORT}") + + async with server: + await server.serve_forever() + + +if __name__ == "__main__": + asyncio.run(start_server())