diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 62f6546..765b724 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 62f6546..765b724 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/device.py b/entity/device.py index f09b52d..4e30fbd 100644 --- a/entity/device.py +++ b/entity/device.py @@ -17,6 +17,11 @@ class Device(DeviceBase, TimestampMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) + def __str__(self): + return f"id={self.id} name={self.name} code={self.code} ip={self.ip}" \ + f" input_stream_url={self.input_stream_url} output_stream_url={self.output_stream_url}" \ + f" image_save_interval={self.image_save_interval}" + class DeviceCreate(DeviceBase): pass diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 62f6546..765b724 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/device.py b/entity/device.py index f09b52d..4e30fbd 100644 --- a/entity/device.py +++ b/entity/device.py @@ -17,6 +17,11 @@ class Device(DeviceBase, TimestampMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) + def __str__(self): + return f"id={self.id} name={self.name} code={self.code} ip={self.ip}" \ + f" input_stream_url={self.input_stream_url} output_stream_url={self.output_stream_url}" \ + f" image_save_interval={self.image_save_interval}" + class DeviceCreate(DeviceBase): pass diff --git a/entity/device_frame.py b/entity/device_frame.py index 5612725..7c9aa89 100644 --- a/entity/device_frame.py +++ b/entity/device_frame.py @@ -10,7 +10,7 @@ time: datetime = Field(default_factory=datetime.now) -class DeviceFrame(DeviceFrameBase): +class DeviceFrame(DeviceFrameBase, table = True): __tablename__ = 'device_frame' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 62f6546..765b724 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/device.py b/entity/device.py index f09b52d..4e30fbd 100644 --- a/entity/device.py +++ b/entity/device.py @@ -17,6 +17,11 @@ class Device(DeviceBase, TimestampMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) + def __str__(self): + return f"id={self.id} name={self.name} code={self.code} ip={self.ip}" \ + f" input_stream_url={self.input_stream_url} output_stream_url={self.output_stream_url}" \ + f" image_save_interval={self.image_save_interval}" + class DeviceCreate(DeviceBase): pass diff --git a/entity/device_frame.py b/entity/device_frame.py index 5612725..7c9aa89 100644 --- a/entity/device_frame.py +++ b/entity/device_frame.py @@ -10,7 +10,7 @@ time: datetime = Field(default_factory=datetime.now) -class DeviceFrame(DeviceFrameBase): +class DeviceFrame(DeviceFrameBase, table = True): __tablename__ = 'device_frame' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/entity/frame_analysis_result.py b/entity/frame_analysis_result.py index 2185e4b..ca8793e 100644 --- a/entity/frame_analysis_result.py +++ b/entity/frame_analysis_result.py @@ -15,7 +15,7 @@ time: datetime = Field(default_factory=datetime.now) -class FrameAnalysisResult(FrameAnalysisResultBase): +class FrameAnalysisResult(FrameAnalysisResultBase, table = True): __tablename__ = 'frame_analysis_result' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 62f6546..765b724 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/device.py b/entity/device.py index f09b52d..4e30fbd 100644 --- a/entity/device.py +++ b/entity/device.py @@ -17,6 +17,11 @@ class Device(DeviceBase, TimestampMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) + def __str__(self): + return f"id={self.id} name={self.name} code={self.code} ip={self.ip}" \ + f" input_stream_url={self.input_stream_url} output_stream_url={self.output_stream_url}" \ + f" image_save_interval={self.image_save_interval}" + class DeviceCreate(DeviceBase): pass diff --git a/entity/device_frame.py b/entity/device_frame.py index 5612725..7c9aa89 100644 --- a/entity/device_frame.py +++ b/entity/device_frame.py @@ -10,7 +10,7 @@ time: datetime = Field(default_factory=datetime.now) -class DeviceFrame(DeviceFrameBase): +class DeviceFrame(DeviceFrameBase, table = True): __tablename__ = 'device_frame' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/entity/frame_analysis_result.py b/entity/frame_analysis_result.py index 2185e4b..ca8793e 100644 --- a/entity/frame_analysis_result.py +++ b/entity/frame_analysis_result.py @@ -15,7 +15,7 @@ time: datetime = Field(default_factory=datetime.now) -class FrameAnalysisResult(FrameAnalysisResultBase): +class FrameAnalysisResult(FrameAnalysisResultBase, table = True): __tablename__ = 'frame_analysis_result' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/main.py b/main.py index fdbfb7d..1121638 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException -from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner from apis.router import router import uvicorn import logging @@ -22,7 +22,7 @@ # async def startup_event(): # algo_runner.start() -algo_runner = AlgoRunner() +algo_runner = get_algo_runner() @asynccontextmanager diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 62f6546..765b724 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/device.py b/entity/device.py index f09b52d..4e30fbd 100644 --- a/entity/device.py +++ b/entity/device.py @@ -17,6 +17,11 @@ class Device(DeviceBase, TimestampMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) + def __str__(self): + return f"id={self.id} name={self.name} code={self.code} ip={self.ip}" \ + f" input_stream_url={self.input_stream_url} output_stream_url={self.output_stream_url}" \ + f" image_save_interval={self.image_save_interval}" + class DeviceCreate(DeviceBase): pass diff --git a/entity/device_frame.py b/entity/device_frame.py index 5612725..7c9aa89 100644 --- a/entity/device_frame.py +++ b/entity/device_frame.py @@ -10,7 +10,7 @@ time: datetime = Field(default_factory=datetime.now) -class DeviceFrame(DeviceFrameBase): +class DeviceFrame(DeviceFrameBase, table = True): __tablename__ = 'device_frame' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/entity/frame_analysis_result.py b/entity/frame_analysis_result.py index 2185e4b..ca8793e 100644 --- a/entity/frame_analysis_result.py +++ b/entity/frame_analysis_result.py @@ -15,7 +15,7 @@ time: datetime = Field(default_factory=datetime.now) -class FrameAnalysisResult(FrameAnalysisResultBase): +class FrameAnalysisResult(FrameAnalysisResultBase, table = True): __tablename__ = 'frame_analysis_result' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/main.py b/main.py index fdbfb7d..1121638 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException -from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner from apis.router import router import uvicorn import logging @@ -22,7 +22,7 @@ # async def startup_event(): # algo_runner.start() -algo_runner = AlgoRunner() +algo_runner = get_algo_runner() @asynccontextmanager diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 7dd4999..68c04ac 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -5,7 +5,7 @@ def __init__(self, model: AlgoModelExec): self.model = model - self.model_names = model.algo_model_exec.model_name + self.model_names = model.algo_model_exec.names def pre_process(self, frame): return frame @@ -27,7 +27,7 @@ 'object_class_id': int(box.cls), 'object_class_name': self.model_names[int(box.cls)], 'confidence': float(box.conf), - 'location': box.xyxyn.cpu().squeeze().tolist() + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) } ) if annotator is not None: diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 62f6546..765b724 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/device.py b/entity/device.py index f09b52d..4e30fbd 100644 --- a/entity/device.py +++ b/entity/device.py @@ -17,6 +17,11 @@ class Device(DeviceBase, TimestampMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) + def __str__(self): + return f"id={self.id} name={self.name} code={self.code} ip={self.ip}" \ + f" input_stream_url={self.input_stream_url} output_stream_url={self.output_stream_url}" \ + f" image_save_interval={self.image_save_interval}" + class DeviceCreate(DeviceBase): pass diff --git a/entity/device_frame.py b/entity/device_frame.py index 5612725..7c9aa89 100644 --- a/entity/device_frame.py +++ b/entity/device_frame.py @@ -10,7 +10,7 @@ time: datetime = Field(default_factory=datetime.now) -class DeviceFrame(DeviceFrameBase): +class DeviceFrame(DeviceFrameBase, table = True): __tablename__ = 'device_frame' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/entity/frame_analysis_result.py b/entity/frame_analysis_result.py index 2185e4b..ca8793e 100644 --- a/entity/frame_analysis_result.py +++ b/entity/frame_analysis_result.py @@ -15,7 +15,7 @@ time: datetime = Field(default_factory=datetime.now) -class FrameAnalysisResult(FrameAnalysisResultBase): +class FrameAnalysisResult(FrameAnalysisResultBase, table = True): __tablename__ = 'frame_analysis_result' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/main.py b/main.py index fdbfb7d..1121638 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException -from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner from apis.router import router import uvicorn import logging @@ -22,7 +22,7 @@ # async def startup_event(): # algo_runner.start() -algo_runner = AlgoRunner() +algo_runner = get_algo_runner() @asynccontextmanager diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 7dd4999..68c04ac 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -5,7 +5,7 @@ def __init__(self, model: AlgoModelExec): self.model = model - self.model_names = model.algo_model_exec.model_name + self.model_names = model.algo_model_exec.names def pre_process(self, frame): return frame @@ -27,7 +27,7 @@ 'object_class_id': int(box.cls), 'object_class_name': self.model_names[int(box.cls)], 'confidence': float(box.conf), - 'location': box.xyxyn.cpu().squeeze().tolist() + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) } ) if annotator is not None: diff --git a/services/device_frame_service.py b/services/device_frame_service.py index b062dd4..6b9c308 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -16,9 +16,9 @@ def add_frame(self, device_id, frame_data) -> DeviceFrame: def save_frame_file(): # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y/%m/%d') + current_date = datetime.now().strftime('%Y-%m-%d') # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".png" + file_name = str(uuid.uuid4()) + ".jpeg" # 创建保存图片的完整路径 save_path = os.path.join('./storage/frames', current_date, file_name) # 创建目录(如果不存在) diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 62f6546..765b724 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/device.py b/entity/device.py index f09b52d..4e30fbd 100644 --- a/entity/device.py +++ b/entity/device.py @@ -17,6 +17,11 @@ class Device(DeviceBase, TimestampMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) + def __str__(self): + return f"id={self.id} name={self.name} code={self.code} ip={self.ip}" \ + f" input_stream_url={self.input_stream_url} output_stream_url={self.output_stream_url}" \ + f" image_save_interval={self.image_save_interval}" + class DeviceCreate(DeviceBase): pass diff --git a/entity/device_frame.py b/entity/device_frame.py index 5612725..7c9aa89 100644 --- a/entity/device_frame.py +++ b/entity/device_frame.py @@ -10,7 +10,7 @@ time: datetime = Field(default_factory=datetime.now) -class DeviceFrame(DeviceFrameBase): +class DeviceFrame(DeviceFrameBase, table = True): __tablename__ = 'device_frame' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/entity/frame_analysis_result.py b/entity/frame_analysis_result.py index 2185e4b..ca8793e 100644 --- a/entity/frame_analysis_result.py +++ b/entity/frame_analysis_result.py @@ -15,7 +15,7 @@ time: datetime = Field(default_factory=datetime.now) -class FrameAnalysisResult(FrameAnalysisResultBase): +class FrameAnalysisResult(FrameAnalysisResultBase, table = True): __tablename__ = 'frame_analysis_result' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/main.py b/main.py index fdbfb7d..1121638 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException -from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner from apis.router import router import uvicorn import logging @@ -22,7 +22,7 @@ # async def startup_event(): # algo_runner.start() -algo_runner = AlgoRunner() +algo_runner = get_algo_runner() @asynccontextmanager diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 7dd4999..68c04ac 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -5,7 +5,7 @@ def __init__(self, model: AlgoModelExec): self.model = model - self.model_names = model.algo_model_exec.model_name + self.model_names = model.algo_model_exec.names def pre_process(self, frame): return frame @@ -27,7 +27,7 @@ 'object_class_id': int(box.cls), 'object_class_name': self.model_names[int(box.cls)], 'confidence': float(box.conf), - 'location': box.xyxyn.cpu().squeeze().tolist() + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) } ) if annotator is not None: diff --git a/services/device_frame_service.py b/services/device_frame_service.py index b062dd4..6b9c308 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -16,9 +16,9 @@ def add_frame(self, device_id, frame_data) -> DeviceFrame: def save_frame_file(): # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y/%m/%d') + current_date = datetime.now().strftime('%Y-%m-%d') # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".png" + file_name = str(uuid.uuid4()) + ".jpeg" # 创建保存图片的完整路径 save_path = os.path.join('./storage/frames', current_date, file_name) # 创建目录(如果不存在) diff --git a/services/device_service.py b/services/device_service.py index d3a2f10..918caae 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -76,6 +76,7 @@ return device def update_device(self, device_data: DeviceUpdate): + device_old = self.db.get(Device, device_data.id) device = self.db.get(Device, device_data.id) if not device: return None diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 3b07e0a..a97aee9 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,7 @@ +import copy +import uuid +from typing import Dict + from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager from common.consts import NotifyChangeType @@ -9,6 +13,7 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self): with next(get_db()) as db: @@ -18,7 +23,7 @@ self.model_manager = ModelManager(db) self.thread_pool = GlobalThreadPool() - self.device_tasks = {} # 用于存储设备对应的线程 + self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 @@ -38,22 +43,24 @@ for device in devices: self.start_device_thread(device) - @staticmethod - def get_device_thread_id(device_id): - return f'device_{device_id}' + # @staticmethod + # def get_device_thread_id(device_id): + # return f'device_{device_id}' def start_device_thread(self, device: Device): + device = copy.deepcopy(device) """为单个设备启动检测线程""" if not device.input_stream_url: logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = self.get_device_thread_id(device.id) + thread_id = f'device_{device.id}_{uuid.uuid4()}' + logger.info(f'start thread {thread_id}, device info: {device}') - if self.thread_pool.check_task_is_running(thread_id=thread_id): - logger.info(f"设备 {device.code} 已经在运行中") - return + # if self.thread_pool.check_task_is_running(thread_id=thread_id): + # logger.info(f"设备 {device.code} 已经在运行中") + # return # 获取设备绑定的模型列表 relations = self.relation_service.get_device_models(device.id) @@ -64,52 +71,62 @@ logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") return + for r in relations: + if r.algo_model_id not in self.model_manager.models: + self.model_manager.load_new_model(r.algo_model_id) + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list) + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) self.device_tasks[device.id] = device_detection_task - def stop_device_thread(self, device_id): - """停止指定设备的检测线程""" - if device_id in self.device_tasks: - self.device_tasks[device_id].stop_detection_task() - del self.device_tasks[device_id] - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程,并确保其成功停止""" + # """停止指定设备的检测线程""" # if device_id in self.device_tasks: - # # 获取线程 ID 和 future 对象 - # thread_id = self.get_device_thread_id(device_id) - # future = self.thread_pool.check_task_stopped(thread_id=thread_id) - # - # if future: - # # 调用 stop_detection_task 停止任务 - # self.device_tasks[device_id].stop_detection_task() - # - # try: - # # 设置超时时间等待任务停止(例如10秒) - # result = future.result(timeout=10) - # logger.info(f"Task for device {device_id} stopped successfully.") - # except TimeoutError: - # logger.error(f"Task for device {device_id} did not stop within the timeout.") - # except Exception as e: - # logger.error(f"Task for device {device_id} encountered an error while stopping: {e}") - # finally: - # # 确保无论任务是否停止,都将其从任务列表中移除 - # del self.device_tasks[device_id] - # else: - # logger.warning(f"No task found for device {device_id}.") - # else: - # logger.warning(f"No task exists for device {device_id} in device_tasks.") + # logger.info(f'stop device {device_id} thread') + # self.device_tasks[device_id].stop_detection_task() + # del self.device_tasks[device_id] + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_detection_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=10) + logger.info(f"Task {thread_id} stopped successfully.") + except TimeoutError: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - # todo 需要留个关闭时间吗 device = self.device_service.get_device(device_id) self.start_device_thread(device) def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: device = self.device_service.get_device(device_id) @@ -117,7 +134,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - #todo 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + self.restart_device_thread(device_id) def on_model_change(self, model_id, change_type): diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py new file mode 100644 index 0000000..d7d8015 --- /dev/null +++ b/algo/algo_runner_manager.py @@ -0,0 +1,7 @@ +from algo.algo_runner import AlgoRunner + +algo_runner = AlgoRunner() + + +def get_algo_runner(): + return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index ff2d513..5e1c75d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -6,6 +6,7 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad +from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import camel_to_snake from db.database import get_db @@ -40,11 +41,12 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec]): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 self.frame_ts = None + self.thread_id = thread_id with next(get_db()) as db: self.device_frame_service = DeviceFrameService(db) @@ -52,9 +54,11 @@ self.thread_pool = GlobalThreadPool() - self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code) + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) def stop_detection_task(self): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') self.__stop_event.set() self.stream_loader.stop() # 停止视频流加载的线程 @@ -64,7 +68,7 @@ if self.frame_ts is None: self.frame_ts = datetime.now() return True - if datetime.now() - self.frame_ts > self.device.image_save_interval: + if (datetime.now() - self.frame_ts).total_seconds() > self.device.image_save_interval: self.frame_ts = datetime.now() return True return False @@ -74,6 +78,7 @@ return device_frame = self.device_frame_service.add_frame(self.device.id, frame) frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): for r in results: @@ -83,7 +88,7 @@ algo_model_id=model_exec_id, object_class_id=r.object_class_id, object_class_name=r.object_class_name, - confidence=r.conf, + confidence=round(r.confidence, 4), location=r.location, ) frame_results.append(frame_result) @@ -93,6 +98,9 @@ for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue results_map = {} for model_exec in self.model_exec_list: diff --git a/algo/model_manager.py b/algo/model_manager.py index 535c666..94a058c 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -66,3 +66,15 @@ if algo_model_exec: algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) + + def load_new_model(self, model_id): + algo_model = self.model_service.get_model_by_id(model_id) + if algo_model: + algo_model_exec = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + self.models[algo_model.id] = algo_model_exec + self.load_model(algo_model_exec) + return algo_model_exec + return None diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 04faeb2..9a1199e 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -3,10 +3,11 @@ import numpy as np from threading import Thread, Event from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool class OpenCVStreamLoad: - def __init__(self, camera_url, camera_code, + def __init__(self, camera_url, camera_code, device_thread_id = '', retry_interval=1, vid_stride=1): assert camera_url is not None and camera_url != '' @@ -14,16 +15,26 @@ self.camera_code = camera_code self.retry_interval = retry_interval self.vid_stride = vid_stride + self.device_thread_id = device_thread_id - self.cap = self.get_connect() - self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + self.init = False + self.frame = None + self.cap = None - _, self.frame = self.cap.read() self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 - self.thread = Thread(target=self.update, daemon=True) - self.thread.start() + self.thread_pool = GlobalThreadPool() + self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) + # self.thread.start() + + def init_cap(self): + self.cap = self.get_connect() + if self.cap: + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + _, self.frame = self.cap.read() + self.init = True def create_capture(self): """ @@ -43,7 +54,11 @@ """ cap = None while cap is None or not cap.isOpened(): + if self.__stop_event.is_set(): # 检查是否收到停止信号 + logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") + return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") + print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -55,13 +70,18 @@ def update(self): vid_n = 0 log_n = 0 - while True: + while not self.__stop_event.is_set(): + print('update') + if not self.init: + self.init_cap() + if self.cap is None: + continue vid_n += 1 if vid_n % self.vid_stride == 0: try: ret, frame = self.cap.read() if not ret: - logger.info("disconnect, try to reconnect...") + logger.info(f"{self.url} disconnect, try to reconnect...") self.cap.release() # 释放当前的捕获对象 self.cap = self.get_connect() # 尝试重新连接 self.frame = np.zeros_like(self.frame) @@ -69,12 +89,11 @@ else: vid_n += 1 self.frame = frame - # cv2.imwrite('cv_test.jpg', frame) if log_n % 1000 == 0: - logger.debug('cap success') + logger.debug(f'{self.url} cap success') log_n = (log_n + 1) % 250 except Exception as e: - logger.error("update fail", e) + logger.error(f"{self.url} update fail", e) if self.cap is not None: self.cap.release() self.cap = self.get_connect() # 尝试重新连接 @@ -86,7 +105,9 @@ return self.frame def stop(self): + logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped') """ 停止视频流读取线程 """ self.__stop_event.set() - self.thread.join() # 确保线程已完全终止 - self.cap.release() + # self.thread.join() # 确保线程已完全终止 + if self.cap: + self.cap.release() diff --git a/apis/device.py b/apis/device.py index 7eeaa78..3be394f 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,11 +3,14 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session +from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -41,15 +44,15 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): - service = DeviceService(db) +def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): - service = DeviceService(db) +def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -57,8 +60,8 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, db: Session = Depends(get_db)): - service = DeviceService(db) +def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.device_service device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 7108571..9152eca 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,6 +8,9 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -32,8 +35,8 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - db: Session = Depends(get_db)): - service = DeviceModelRelationService(db) + algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.relation_service relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) diff --git a/apis/model.py b/apis/model.py index 8adcc14..da63a61 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,14 +1,17 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from sqlmodel import Session + from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response -from common.biz_exception import BizException from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService +from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner + router = APIRouter() @@ -47,8 +50,8 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): - service = ModelService(db) +def update_model(model_data: AlgoModelUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): + service = algo_runner.model_service model = service.update_model(model_data) if not model: return standard_error_response(data=model_data, message="Model not found") diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index f55359b..7f3be9f 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import traceback import uuid from concurrent.futures import ThreadPoolExecutor import threading @@ -33,6 +34,7 @@ raise ValueError(f"线程 ID {thread_id} 已存在") future = self.executor.submit(fn, *args, **kwargs) self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id def check_task_is_running(self, thread_id): @@ -64,6 +66,15 @@ logger.warning(f"No task found with thread ID {thread_id}.") return True # 如果找不到该任务,认为它已经停止(或者不存在) + def get_task_future(self, thread_id): + """获取指定线程 ID 的 future 对象""" + future = self.task_map.get(thread_id) + if future: + return future + else: + logger.warning(f"No task found with thread ID {thread_id}.") + return None + def stop_task(self, thread_id): """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" future = self.task_map.get(thread_id) @@ -78,3 +89,16 @@ """关闭线程池""" self.executor.shutdown(wait=wait) GlobalThreadPool._instance = None + + def _handle_exception(self, future, thread_id): + """ + 处理任务完成时的异常 + :param future: 完成的任务 future 对象 + :param camera_id: 对应的摄像头 ID + """ + try: + # 获取任务结果,如果任务有异常,这里会抛出 + result = future.result() + except Exception as e: + logger.error(f"Task for thread {thread_id} raised an exception: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 62f6546..765b724 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/device.py b/entity/device.py index f09b52d..4e30fbd 100644 --- a/entity/device.py +++ b/entity/device.py @@ -17,6 +17,11 @@ class Device(DeviceBase, TimestampMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) + def __str__(self): + return f"id={self.id} name={self.name} code={self.code} ip={self.ip}" \ + f" input_stream_url={self.input_stream_url} output_stream_url={self.output_stream_url}" \ + f" image_save_interval={self.image_save_interval}" + class DeviceCreate(DeviceBase): pass diff --git a/entity/device_frame.py b/entity/device_frame.py index 5612725..7c9aa89 100644 --- a/entity/device_frame.py +++ b/entity/device_frame.py @@ -10,7 +10,7 @@ time: datetime = Field(default_factory=datetime.now) -class DeviceFrame(DeviceFrameBase): +class DeviceFrame(DeviceFrameBase, table = True): __tablename__ = 'device_frame' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/entity/frame_analysis_result.py b/entity/frame_analysis_result.py index 2185e4b..ca8793e 100644 --- a/entity/frame_analysis_result.py +++ b/entity/frame_analysis_result.py @@ -15,7 +15,7 @@ time: datetime = Field(default_factory=datetime.now) -class FrameAnalysisResult(FrameAnalysisResultBase): +class FrameAnalysisResult(FrameAnalysisResultBase, table = True): __tablename__ = 'frame_analysis_result' id: Optional[int] = Field(default=None, primary_key=True) diff --git a/main.py b/main.py index fdbfb7d..1121638 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException -from algo.algo_runner import AlgoRunner +from algo.algo_runner_manager import get_algo_runner from apis.router import router import uvicorn import logging @@ -22,7 +22,7 @@ # async def startup_event(): # algo_runner.start() -algo_runner = AlgoRunner() +algo_runner = get_algo_runner() @asynccontextmanager diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 7dd4999..68c04ac 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -5,7 +5,7 @@ def __init__(self, model: AlgoModelExec): self.model = model - self.model_names = model.algo_model_exec.model_name + self.model_names = model.algo_model_exec.names def pre_process(self, frame): return frame @@ -27,7 +27,7 @@ 'object_class_id': int(box.cls), 'object_class_name': self.model_names[int(box.cls)], 'confidence': float(box.conf), - 'location': box.xyxyn.cpu().squeeze().tolist() + 'location': ", ".join([f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()]) } ) if annotator is not None: diff --git a/services/device_frame_service.py b/services/device_frame_service.py index b062dd4..6b9c308 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -16,9 +16,9 @@ def add_frame(self, device_id, frame_data) -> DeviceFrame: def save_frame_file(): # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y/%m/%d') + current_date = datetime.now().strftime('%Y-%m-%d') # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".png" + file_name = str(uuid.uuid4()) + ".jpeg" # 创建保存图片的完整路径 save_path = os.path.join('./storage/frames', current_date, file_name) # 创建目录(如果不存在) diff --git a/services/device_service.py b/services/device_service.py index d3a2f10..918caae 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -76,6 +76,7 @@ return device def update_device(self, device_data: DeviceUpdate): + device_old = self.db.get(Device, device_data.id) device = self.db.get(Device, device_data.id) if not device: return None diff --git a/services/model_service.py b/services/model_service.py index de387c0..8de08bb 100644 --- a/services/model_service.py +++ b/services/model_service.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Sequence, Optional, Tuple +from typing import List, Sequence, Optional, Tuple, Type from sqlalchemy import func from sqlmodel import Session, select @@ -118,3 +118,6 @@ ) results = self.db.exec(statement).all() return results + + def get_model_by_id(self,model_id): + return self.db.get(AlgoModel, model_id)