diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? DeviceFrame: + def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y/%m/%d') + # 生成随机 UUID 作为文件名 + file_name = str(uuid.uuid4()) + ".png" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + # 保存图片 + cv2.imwrite(save_path, frame_data) + return save_path + + # 保存图片文件 + file_path = save_frame_file() + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + self.db.commit() + self.db.refresh(device_frame) + return device_frame diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? DeviceFrame: + def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y/%m/%d') + # 生成随机 UUID 作为文件名 + file_name = str(uuid.uuid4()) + ".png" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + # 保存图片 + cv2.imwrite(save_path, frame_data) + return save_path + + # 保存图片文件 + file_path = save_frame_file() + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + self.db.commit() + self.db.refresh(device_frame) + return device_frame diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index 0f8d9a7..49e167a 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -12,16 +12,16 @@ class DeviceModelRelationService: def __init__(self, db: Session): self.db = db - self.relation_change_callbacks = [] # 用于存储回调函数 + self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() def register_change_callback(self, callback): """注册设备变化回调函数""" - self.relation_change_callbacks.append(callback) + self.__relation_change_callbacks.append(callback) def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" - for callback in self.relation_change_callbacks: + for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? DeviceFrame: + def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y/%m/%d') + # 生成随机 UUID 作为文件名 + file_name = str(uuid.uuid4()) + ".png" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + # 保存图片 + cv2.imwrite(save_path, frame_data) + return save_path + + # 保存图片文件 + file_path = save_frame_file() + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + self.db.commit() + self.db.refresh(device_frame) + return device_frame diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index 0f8d9a7..49e167a 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -12,16 +12,16 @@ class DeviceModelRelationService: def __init__(self, db: Session): self.db = db - self.relation_change_callbacks = [] # 用于存储回调函数 + self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() def register_change_callback(self, callback): """注册设备变化回调函数""" - self.relation_change_callbacks.append(callback) + self.__relation_change_callbacks.append(callback) def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" - for callback in self.relation_change_callbacks: + for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: diff --git a/services/device_service.py b/services/device_service.py index 0284c4e..d3a2f10 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -13,16 +13,16 @@ class DeviceService: def __init__(self, db: Session): self.db = db - self.device_change_callbacks = [] # 用于存储回调函数 + self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() def register_change_callback(self, callback): """注册设备变化回调函数""" - self.device_change_callbacks.append(callback) + self.__device_change_callbacks.append(callback) def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" - for callback in self.device_change_callbacks: + for callback in self.__device_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) def get_device_list(self, diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? DeviceFrame: + def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y/%m/%d') + # 生成随机 UUID 作为文件名 + file_name = str(uuid.uuid4()) + ".png" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + # 保存图片 + cv2.imwrite(save_path, frame_data) + return save_path + + # 保存图片文件 + file_path = save_frame_file() + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + self.db.commit() + self.db.refresh(device_frame) + return device_frame diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index 0f8d9a7..49e167a 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -12,16 +12,16 @@ class DeviceModelRelationService: def __init__(self, db: Session): self.db = db - self.relation_change_callbacks = [] # 用于存储回调函数 + self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() def register_change_callback(self, callback): """注册设备变化回调函数""" - self.relation_change_callbacks.append(callback) + self.__relation_change_callbacks.append(callback) def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" - for callback in self.relation_change_callbacks: + for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: diff --git a/services/device_service.py b/services/device_service.py index 0284c4e..d3a2f10 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -13,16 +13,16 @@ class DeviceService: def __init__(self, db: Session): self.db = db - self.device_change_callbacks = [] # 用于存储回调函数 + self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() def register_change_callback(self, callback): """注册设备变化回调函数""" - self.device_change_callbacks.append(callback) + self.__device_change_callbacks.append(callback) def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" - for callback in self.device_change_callbacks: + for callback in self.__device_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) def get_device_list(self, diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py new file mode 100644 index 0000000..3884fa3 --- /dev/null +++ b/services/frame_analysis_result_service.py @@ -0,0 +1,18 @@ +from typing import List + +from sqlmodel import Session +from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult + + +class FrameAnalysisResultService: + + def __init__(self, db: Session): + self.db = db + + def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): + new_results = [FrameAnalysisResult.model_validate(result) for result in results] + self.db.add_all(new_results) + self.db.commit() + for result in new_results: + self.db.refresh(result) + return new_results diff --git a/.gitignore b/.gitignore index e530a86..6c5fd4c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.py -/logs/* \ No newline at end of file +/logs/* +.idea \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index ee07eaf..3b07e0a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -6,25 +7,27 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_service import DeviceService from services.model_service import ModelService - +from common.global_logger import logger class AlgoRunner: def __init__(self): - self.db = get_db() - self.device_service = DeviceService(self.db) - self.model_service = ModelService(self.db) - self.relation_service = DeviceModelRelationService(self.db) - self.model_manager = ModelManager(self.db) + with next(get_db()) as db: + self.device_service = DeviceService(db) + self.model_service = ModelService(db) + self.relation_service = DeviceModelRelationService(db) + self.model_manager = ModelManager(db) + self.thread_pool = GlobalThreadPool() - self.threads = {} # 用于存储设备对应的线程 + self.device_tasks = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 self.device_service.register_change_callback(self.on_device_change) self.model_service.register_change_callback(self.on_model_change) - self.relation_service.relation_change_callbacks(self.on_relation_change) + self.relation_service.register_change_callback(self.on_relation_change) def start(self): + logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() self.load_and_start_devices() @@ -32,56 +35,73 @@ def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" devices = self.device_service.get_device_list() - devices = [device for device in devices if device.input_stream_url] for device in devices: self.start_device_thread(device) - def run_detection(self, device: Device): - """todo 设备目标检测主逻辑""" - print(f"设备 {device.device_id} 的检测线程启动") - video_stream = self.device_service.get_video_stream(device.device_id) - - while True: - try: - frame = video_stream.read_frame() - for model in self.model_manager.models: - # 调用目标检测模型 - result = self.model_service.run_inference(model, frame) - # 在此处进行后处理,例如标记视频、生成告警等 - self.process_result(result, device.device_id) - except Exception as e: - print(f"设备 {device.device_id} 处理时出错: {e}") - break - video_stream.close() - print(f"设备 {device.device_id} 的检测线程结束") + @staticmethod + def get_device_thread_id(device_id): + return f'device_{device_id}' def start_device_thread(self, device: Device): """为单个设备启动检测线程""" if not device.input_stream_url: - print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - if device.device_id in self.threads: - print(f"设备 {device.device_id} 已经在运行中") + thread_id = self.get_device_thread_id(device.id) + + if self.thread_pool.check_task_is_running(thread_id=thread_id): + logger.info(f"设备 {device.code} 已经在运行中") return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.device_id) - self.device_model_relations[device.device_id] = relations + relations = self.relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations if not relations: - print(f"设备 {device.code} 未绑定模型,无法启动检测") + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return - future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task def stop_device_thread(self, device_id): - """todo 控制标志位 停止指定设备的检测线程""" - if device_id in self.threads: - self.threads[device_id].cancel() # 尝试取消线程 - del self.threads[device_id] - print(f"设备 {device_id} 的检测线程已停止") + """停止指定设备的检测线程""" + if device_id in self.device_tasks: + self.device_tasks[device_id].stop_detection_task() + del self.device_tasks[device_id] + + # def stop_device_thread(self, device_id): + # """停止指定设备的检测线程,并确保其成功停止""" + # if device_id in self.device_tasks: + # # 获取线程 ID 和 future 对象 + # thread_id = self.get_device_thread_id(device_id) + # future = self.thread_pool.check_task_stopped(thread_id=thread_id) + # + # if future: + # # 调用 stop_detection_task 停止任务 + # self.device_tasks[device_id].stop_detection_task() + # + # try: + # # 设置超时时间等待任务停止(例如10秒) + # result = future.result(timeout=10) + # logger.info(f"Task for device {device_id} stopped successfully.") + # except TimeoutError: + # logger.error(f"Task for device {device_id} did not stop within the timeout.") + # except Exception as e: + # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") + # finally: + # # 确保无论任务是否停止,都将其从任务列表中移除 + # del self.device_tasks[device_id] + # else: + # logger.warning(f"No task found for device {device_id}.") + # else: + # logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) @@ -97,6 +117,7 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: + #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): @@ -113,5 +134,6 @@ def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" + # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py new file mode 100644 index 0000000..ff2d513 --- /dev/null +++ b/algo/device_detection_task.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import importlib +from datetime import datetime +from threading import Event +from typing import List, Dict + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import camel_to_snake +from db.database import get_db +from entity.device import Device +from entity.frame_analysis_result import FrameAnalysisResultCreate +from services.device_frame_service import DeviceFrameService +from services.frame_analysis_result_service import FrameAnalysisResultService + + +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + + +@dataclass +class DetectionResult: + object_class_id: int + object_class_name: str + confidence: float + location: str + + @classmethod + def from_dict(cls, data: Dict) -> 'DetectionResult': + return DetectionResult( + object_class_id=data['object_class_id'], + object_class_name=data['object_class_name'], + confidence=data['confidence'], + location=data['location'] + ) + + +class DeviceDetectionTask: + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + self.device = device + self.model_exec_list = model_exec_list + self.__stop_event = Event() # 使用 Event 控制线程的运行状态 + self.frame_ts = None + + with next(get_db()) as db: + self.device_frame_service = DeviceFrameService(db) + self.frame_analysis_result_service = FrameAnalysisResultService(db) + + self.thread_pool = GlobalThreadPool() + + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + + def stop_detection_task(self): + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def check_frame_interval(self): + if self.device.image_save_interval < 0: + return False + if self.frame_ts is None: + self.frame_ts = datetime.now() + return True + if datetime.now() - self.frame_ts > self.device.image_save_interval: + self.frame_ts = datetime.now() + return True + return False + + def save_frame_results(self, frame, results_map): + if not self.check_frame_interval(): + return + device_frame = self.device_frame_service.add_frame(self.device.id, frame) + frame_id = device_frame.id + frame_results = [] + for model_exec_id, results in results_map.items(): + for r in results: + frame_result = FrameAnalysisResultCreate( + device_id=self.device.id, + frame_id=frame_id, + algo_model_id=model_exec_id, + object_class_id=r.object_class_id, + object_class_name=r.object_class_name, + confidence=r.conf, + location=r.location, + ) + frame_results.append(frame_result) + self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + + def run(self): + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + + results_map = {} + for model_exec in self.model_exec_list: + handle_task_name = model_exec.algo_model_info.handle_task + handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(model_exec) + frame, results = handler_instance.run(frame, None) + results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] + + # 结果处理 + self.thread_pool.submit_task(self.save_frame_results, frame, results_map) diff --git a/algo/model_manager.py b/algo/model_manager.py index 3ab0709..535c666 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -1,5 +1,6 @@ +import os.path from dataclasses import dataclass -from typing import Optional +from typing import Optional, Dict import numpy as np from ultralytics import YOLO @@ -13,7 +14,7 @@ class AlgoModelExec: algo_model_id: int algo_model_info: AlgoModel - algo_model_exec: Optional[object] = None + algo_model_exec: Optional[YOLO] = None input_size: int = 640 @@ -21,11 +22,9 @@ def __init__(self, db, model_warm_up=5): self.db = db self.model_service = ModelService(self.db) - self.models = {} + self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - self.load_models() - def query_model_inuse(self): algo_model_list = list(self.model_service.get_models_in_use()) for algo_model in algo_model_list: @@ -35,6 +34,7 @@ ) def load_models(self): + logger.info('loading models') self.models = {} self.query_model_inuse() for algo_model_id, algo_model_exec in self.models.items(): @@ -45,6 +45,10 @@ model_path = algo_model_exec.algo_model_info.path logger.info(f'loading model {model_name}: {model_path}') + if not os.path.exists(model_path): + logger.info(f'model path:{model_path} not exists') + return + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') if self.model_warm_up > 0: logger.info(f'warming up model {model_name}') diff --git a/algo/stream_loader.py b/algo/stream_loader.py new file mode 100644 index 0000000..04faeb2 --- /dev/null +++ b/algo/stream_loader.py @@ -0,0 +1,92 @@ +import cv2 +import time +import numpy as np +from threading import Thread, Event +from common.global_logger import logger + + +class OpenCVStreamLoad: + def __init__(self, camera_url, camera_code, + retry_interval=1, + vid_stride=1): + assert camera_url is not None and camera_url != '' + self.url = camera_url + self.camera_code = camera_code + self.retry_interval = retry_interval + self.vid_stride = vid_stride + + self.cap = self.get_connect() + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + + _, self.frame = self.cap.read() + self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 + self.thread = Thread(target=self.update, daemon=True) + self.thread.start() + + def create_capture(self): + """ + 尝试创建视频流捕获对象。 + """ + try: + cap = cv2.VideoCapture(self.url) + # 可以在这里设置cap的一些属性,如果需要的话 + return cap + except Exception as e: + logger.error(e) + return None + + def get_connect(self): + """ + 尝试重新连接,直到成功。 + """ + cap = None + while cap is None or not cap.isOpened(): + logger.info(f"{self.url} try to connect...") + cap = self.create_capture() + if cap is None or not cap.isOpened(): + logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") + time.sleep(self.retry_interval) # 等待一段时间后重试 + else: + logger.info(f"{self.url} connect success!") + return cap + + def update(self): + vid_n = 0 + log_n = 0 + while True: + vid_n += 1 + if vid_n % self.vid_stride == 0: + try: + ret, frame = self.cap.read() + if not ret: + logger.info("disconnect, try to reconnect...") + self.cap.release() # 释放当前的捕获对象 + self.cap = self.get_connect() # 尝试重新连接 + self.frame = np.zeros_like(self.frame) + continue # 跳过当前循环的剩余部分 + else: + vid_n += 1 + self.frame = frame + # cv2.imwrite('cv_test.jpg', frame) + if log_n % 1000 == 0: + logger.debug('cap success') + log_n = (log_n + 1) % 250 + except Exception as e: + logger.error("update fail", e) + if self.cap is not None: + self.cap.release() + self.cap = self.get_connect() # 尝试重新连接 + + def __iter__(self): + return self + + def __next__(self): + return self.frame + + def stop(self): + """ 停止视频流读取线程 """ + self.__stop_event.set() + self.thread.join() # 确保线程已完全终止 + self.cap.release() diff --git a/common/biz_exception.py b/common/biz_exception.py index 946a794..4659571 100644 --- a/common/biz_exception.py +++ b/common/biz_exception.py @@ -11,4 +11,7 @@ class BizExceptionHandlers: @staticmethod async def biz_exception_handler(request: Request, exc: HTTPException): - standard_error_response(code=exc.status_code, message=exc.detail) + return standard_error_response(code=exc.status_code, message=exc.detail) + # 使用 JSONResponse 返回响应 + # return JSONResponse(status_code=exc.status_code, content=response_data) + diff --git a/common/global_logger.py b/common/global_logger.py index 1d65173..3c4861e 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -3,7 +3,7 @@ import os # 确保日志目录存在 -log_dir = '../logs' +log_dir = 'logs' if not os.path.exists(log_dir): os.makedirs(log_dir) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 21d7fd5..f55359b 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -2,6 +2,8 @@ from concurrent.futures import ThreadPoolExecutor import threading +from common.global_logger import logger + def generate_thread_id(): """生成唯一的线程 ID""" @@ -44,15 +46,33 @@ else: return False + def check_task_stopped(self, thread_id): + """判断任务是否已停止""" + future = self.task_map.get(thread_id) + if future: + if future.done(): + try: + # 确保任务是正常完成的,而不是因为异常停止 + future.result() # 如果任务抛出异常,这里会捕获 + logger.info(f"Task {thread_id} has stopped successfully.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error: {e}") + return True # 无论成功还是失败,任务已停止 + else: + return False # 任务仍在运行 + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return True # 如果找不到该任务,认为它已经停止(或者不存在) + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) if future: future.cancel() # 尝试取消任务 - print(f"任务 {thread_id} 已取消") + logger.info(f"任务 {thread_id} 已取消") del self.task_map[thread_id] # 从任务映射中删除 else: - print(f"未找到线程 ID {thread_id}") + logger.info(f"未找到线程 ID {thread_id}") def shutdown(self, wait=True): """关闭线程池""" diff --git a/common/string_utils.py b/common/string_utils.py new file mode 100644 index 0000000..551a209 --- /dev/null +++ b/common/string_utils.py @@ -0,0 +1,6 @@ +import re + + +def camel_to_snake(name): + # 将大写字母前加上下划线,并将整个字符串转换为小写 + return re.sub(r'(? DeviceFrame: + def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y/%m/%d') + # 生成随机 UUID 作为文件名 + file_name = str(uuid.uuid4()) + ".png" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + # 保存图片 + cv2.imwrite(save_path, frame_data) + return save_path + + # 保存图片文件 + file_path = save_frame_file() + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + self.db.commit() + self.db.refresh(device_frame) + return device_frame diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index 0f8d9a7..49e167a 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -12,16 +12,16 @@ class DeviceModelRelationService: def __init__(self, db: Session): self.db = db - self.relation_change_callbacks = [] # 用于存储回调函数 + self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() def register_change_callback(self, callback): """注册设备变化回调函数""" - self.relation_change_callbacks.append(callback) + self.__relation_change_callbacks.append(callback) def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" - for callback in self.relation_change_callbacks: + for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: diff --git a/services/device_service.py b/services/device_service.py index 0284c4e..d3a2f10 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -13,16 +13,16 @@ class DeviceService: def __init__(self, db: Session): self.db = db - self.device_change_callbacks = [] # 用于存储回调函数 + self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() def register_change_callback(self, callback): """注册设备变化回调函数""" - self.device_change_callbacks.append(callback) + self.__device_change_callbacks.append(callback) def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" - for callback in self.device_change_callbacks: + for callback in self.__device_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) def get_device_list(self, diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py new file mode 100644 index 0000000..3884fa3 --- /dev/null +++ b/services/frame_analysis_result_service.py @@ -0,0 +1,18 @@ +from typing import List + +from sqlmodel import Session +from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult + + +class FrameAnalysisResultService: + + def __init__(self, db: Session): + self.db = db + + def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): + new_results = [FrameAnalysisResult.model_validate(result) for result in results] + self.db.add_all(new_results) + self.db.commit() + for result in new_results: + self.db.refresh(result) + return new_results diff --git a/services/model_service.py b/services/model_service.py index 568895e..de387c0 100644 --- a/services/model_service.py +++ b/services/model_service.py @@ -14,16 +14,16 @@ class ModelService: def __init__(self, db: Session): self.db = db - self.model_change_callbacks = [] # 用于存储回调函数 + self.__model_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() def register_change_callback(self, callback): """注册设备变化回调函数""" - self.model_change_callbacks.append(callback) + self.__model_change_callbacks.append(callback) def notify_change(self, algo_model_id, change_type): """当设备发生变化时,调用回调通知变化""" - for callback in self.model_change_callbacks: + for callback in self.__model_change_callbacks: self.thread_pool.executor.submit(callback, algo_model_id, change_type) def get_model_list(self,