diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index cdb97bf..8724c0d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -39,7 +39,7 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 @@ -47,12 +47,12 @@ self.push_ts = None self.thread_id = thread_id - with next(get_db()) as db: - self.device_frame_service = DeviceFrameService(db) - self.frame_analysis_result_service = FrameAnalysisResultService(db) + self.device_frame_service = DeviceFrameService(db_session) + self.frame_analysis_result_service = FrameAnalysisResultService(db_session) self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) @@ -76,8 +76,15 @@ def save_frame_results(self, frame, results_map): if not self.check_frame_interval(): return - device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id + + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): @@ -92,7 +99,10 @@ location=r.location, ) frame_results.append(frame_result) - self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index cdb97bf..8724c0d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -39,7 +39,7 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 @@ -47,12 +47,12 @@ self.push_ts = None self.thread_id = thread_id - with next(get_db()) as db: - self.device_frame_service = DeviceFrameService(db) - self.frame_analysis_result_service = FrameAnalysisResultService(db) + self.device_frame_service = DeviceFrameService(db_session) + self.frame_analysis_result_service = FrameAnalysisResultService(db_session) self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) @@ -76,8 +76,15 @@ def save_frame_results(self, frame, results_map): if not self.check_frame_interval(): return - device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id + + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): @@ -92,7 +99,10 @@ location=r.location, ) frame_results.append(frame_result) - self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): diff --git a/algo/model_manager.py b/algo/model_manager.py index e3bc4c9..1c30867 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -25,18 +25,18 @@ self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - def query_model_inuse(self): - algo_model_list = list(self.model_service.get_models_in_use()) + async def query_model_inuse(self): + algo_model_list = list(await self.model_service.get_models_in_use()) for algo_model in algo_model_list: self.models[algo_model.id] = AlgoModelExec( algo_model_id=algo_model.id, algo_model_info=algo_model ) - def load_models(self): + async def load_models(self): logger.info('loading models') self.models = {} - self.query_model_inuse() + await self.query_model_inuse() if not self.models: logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): @@ -69,8 +69,8 @@ algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) - def load_new_model(self, model_id): - algo_model = self.model_service.get_model_by_id(model_id) + async def load_new_model(self, model_id): + algo_model = await self.model_service.get_model_by_id(model_id) if algo_model: algo_model_exec = AlgoModelExec( algo_model_id=algo_model.id, diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index cdb97bf..8724c0d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -39,7 +39,7 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 @@ -47,12 +47,12 @@ self.push_ts = None self.thread_id = thread_id - with next(get_db()) as db: - self.device_frame_service = DeviceFrameService(db) - self.frame_analysis_result_service = FrameAnalysisResultService(db) + self.device_frame_service = DeviceFrameService(db_session) + self.frame_analysis_result_service = FrameAnalysisResultService(db_session) self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) @@ -76,8 +76,15 @@ def save_frame_results(self, frame, results_map): if not self.check_frame_interval(): return - device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id + + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): @@ -92,7 +99,10 @@ location=r.location, ) frame_results.append(frame_result) - self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): diff --git a/algo/model_manager.py b/algo/model_manager.py index e3bc4c9..1c30867 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -25,18 +25,18 @@ self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - def query_model_inuse(self): - algo_model_list = list(self.model_service.get_models_in_use()) + async def query_model_inuse(self): + algo_model_list = list(await self.model_service.get_models_in_use()) for algo_model in algo_model_list: self.models[algo_model.id] = AlgoModelExec( algo_model_id=algo_model.id, algo_model_info=algo_model ) - def load_models(self): + async def load_models(self): logger.info('loading models') self.models = {} - self.query_model_inuse() + await self.query_model_inuse() if not self.models: logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): @@ -69,8 +69,8 @@ algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) - def load_new_model(self, model_id): - algo_model = self.model_service.get_model_by_id(model_id) + async def load_new_model(self, model_id): + algo_model = await self.model_service.get_model_by_id(model_id) if algo_model: algo_model_exec = AlgoModelExec( algo_model_id=algo_model.id, diff --git a/algo/scene_runner.py b/algo/scene_runner.py index 2fb2b2b..e5c067e 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import copy import uuid @@ -7,6 +8,7 @@ from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import get_class, camel_to_snake +from db.database import get_db from entity.device import Device from scene_handler.base_scene_handler import BaseSceneHandler from services.device_scene_relation_service import DeviceSceneRelationService @@ -38,16 +40,16 @@ # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): - self.load_and_start_devices() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -59,18 +61,20 @@ return thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' - scene = self.relation_service.get_device_scene(device.id) - if scene: - self.device_scene_relations[device.id] = scene - try: - handle_task_name = scene.scene_handle_task - handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) - future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) - self.device_tasks[device.id] = handler_instance - logger.info(f'start thread {thread_id}, device info: {device}') - except Exception as e: - logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + async for db in get_db(): + relation_service = DeviceSceneRelationService(db) + scene = await relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -104,20 +108,20 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): try: self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) except Exception as e: logger.error(e) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -126,9 +130,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_scene_change(self, scene_id, change_type): + async def on_scene_change(self, scene_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.SCENE_UPDATE: @@ -137,10 +141,10 @@ if relation_info.scene_id == scene_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index cdb97bf..8724c0d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -39,7 +39,7 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 @@ -47,12 +47,12 @@ self.push_ts = None self.thread_id = thread_id - with next(get_db()) as db: - self.device_frame_service = DeviceFrameService(db) - self.frame_analysis_result_service = FrameAnalysisResultService(db) + self.device_frame_service = DeviceFrameService(db_session) + self.frame_analysis_result_service = FrameAnalysisResultService(db_session) self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) @@ -76,8 +76,15 @@ def save_frame_results(self, frame, results_map): if not self.check_frame_interval(): return - device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id + + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): @@ -92,7 +99,10 @@ location=r.location, ) frame_results.append(frame_result) - self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): diff --git a/algo/model_manager.py b/algo/model_manager.py index e3bc4c9..1c30867 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -25,18 +25,18 @@ self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - def query_model_inuse(self): - algo_model_list = list(self.model_service.get_models_in_use()) + async def query_model_inuse(self): + algo_model_list = list(await self.model_service.get_models_in_use()) for algo_model in algo_model_list: self.models[algo_model.id] = AlgoModelExec( algo_model_id=algo_model.id, algo_model_info=algo_model ) - def load_models(self): + async def load_models(self): logger.info('loading models') self.models = {} - self.query_model_inuse() + await self.query_model_inuse() if not self.models: logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): @@ -69,8 +69,8 @@ algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) - def load_new_model(self, model_id): - algo_model = self.model_service.get_model_by_id(model_id) + async def load_new_model(self, model_id): + algo_model = await self.model_service.get_model_by_id(model_id) if algo_model: algo_model_exec = AlgoModelExec( algo_model_id=algo_model.id, diff --git a/algo/scene_runner.py b/algo/scene_runner.py index 2fb2b2b..e5c067e 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import copy import uuid @@ -7,6 +8,7 @@ from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import get_class, camel_to_snake +from db.database import get_db from entity.device import Device from scene_handler.base_scene_handler import BaseSceneHandler from services.device_scene_relation_service import DeviceSceneRelationService @@ -38,16 +40,16 @@ # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): - self.load_and_start_devices() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -59,18 +61,20 @@ return thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' - scene = self.relation_service.get_device_scene(device.id) - if scene: - self.device_scene_relations[device.id] = scene - try: - handle_task_name = scene.scene_handle_task - handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) - future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) - self.device_tasks[device.id] = handler_instance - logger.info(f'start thread {thread_id}, device info: {device}') - except Exception as e: - logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + async for db in get_db(): + relation_service = DeviceSceneRelationService(db) + scene = await relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -104,20 +108,20 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): try: self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) except Exception as e: logger.error(e) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -126,9 +130,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_scene_change(self, scene_id, change_type): + async def on_scene_change(self, scene_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.SCENE_UPDATE: @@ -137,10 +141,10 @@ if relation_info.scene_id == scene_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/app_instance.py b/app_instance.py index f9d1ad9..33f5e80 100644 --- a/app_instance.py +++ b/app_instance.py @@ -56,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -66,9 +66,7 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - # await scene_runner.start() - - + await scene_runner.start() yield # 允许请求处理 diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index cdb97bf..8724c0d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -39,7 +39,7 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 @@ -47,12 +47,12 @@ self.push_ts = None self.thread_id = thread_id - with next(get_db()) as db: - self.device_frame_service = DeviceFrameService(db) - self.frame_analysis_result_service = FrameAnalysisResultService(db) + self.device_frame_service = DeviceFrameService(db_session) + self.frame_analysis_result_service = FrameAnalysisResultService(db_session) self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) @@ -76,8 +76,15 @@ def save_frame_results(self, frame, results_map): if not self.check_frame_interval(): return - device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id + + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): @@ -92,7 +99,10 @@ location=r.location, ) frame_results.append(frame_result) - self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): diff --git a/algo/model_manager.py b/algo/model_manager.py index e3bc4c9..1c30867 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -25,18 +25,18 @@ self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - def query_model_inuse(self): - algo_model_list = list(self.model_service.get_models_in_use()) + async def query_model_inuse(self): + algo_model_list = list(await self.model_service.get_models_in_use()) for algo_model in algo_model_list: self.models[algo_model.id] = AlgoModelExec( algo_model_id=algo_model.id, algo_model_info=algo_model ) - def load_models(self): + async def load_models(self): logger.info('loading models') self.models = {} - self.query_model_inuse() + await self.query_model_inuse() if not self.models: logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): @@ -69,8 +69,8 @@ algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) - def load_new_model(self, model_id): - algo_model = self.model_service.get_model_by_id(model_id) + async def load_new_model(self, model_id): + algo_model = await self.model_service.get_model_by_id(model_id) if algo_model: algo_model_exec = AlgoModelExec( algo_model_id=algo_model.id, diff --git a/algo/scene_runner.py b/algo/scene_runner.py index 2fb2b2b..e5c067e 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import copy import uuid @@ -7,6 +8,7 @@ from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import get_class, camel_to_snake +from db.database import get_db from entity.device import Device from scene_handler.base_scene_handler import BaseSceneHandler from services.device_scene_relation_service import DeviceSceneRelationService @@ -38,16 +40,16 @@ # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): - self.load_and_start_devices() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -59,18 +61,20 @@ return thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' - scene = self.relation_service.get_device_scene(device.id) - if scene: - self.device_scene_relations[device.id] = scene - try: - handle_task_name = scene.scene_handle_task - handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) - future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) - self.device_tasks[device.id] = handler_instance - logger.info(f'start thread {thread_id}, device info: {device}') - except Exception as e: - logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + async for db in get_db(): + relation_service = DeviceSceneRelationService(db) + scene = await relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -104,20 +108,20 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): try: self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) except Exception as e: logger.error(e) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -126,9 +130,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_scene_change(self, scene_id, change_type): + async def on_scene_change(self, scene_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.SCENE_UPDATE: @@ -137,10 +141,10 @@ if relation_info.scene_id == scene_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/app_instance.py b/app_instance.py index f9d1ad9..33f5e80 100644 --- a/app_instance.py +++ b/app_instance.py @@ -56,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -66,9 +66,7 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - # await scene_runner.start() - - + await scene_runner.start() yield # 允许请求处理 diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 7f3be9f..ef9a356 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import asyncio import traceback import uuid from concurrent.futures import ThreadPoolExecutor @@ -11,20 +12,37 @@ return str(uuid.uuid4()) +def wrapper(func, *args, **kwargs): + return func(*args, **kwargs) + + class GlobalThreadPool: _instance = None _lock = threading.Lock() - def __new__(cls, *args, **kwargs): - with cls._lock: - if cls._instance is None: - # 第一次创建实例时调用父类的 __new__ 来创建实例 - cls._instance = super(GlobalThreadPool, cls).__new__(cls) - # 在此进行一次性的初始化,比如线程池的创建 - max_workers = kwargs.get('max_workers', 10) - cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) - cls._instance.task_map = {} # 初始化任务映射 - return cls._instance + def __new__(cls, max_workers=5): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + cls._instance._initialize(max_workers) + return cls._instance + + def _initialize(self, max_workers): + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.loop = asyncio.get_running_loop() # 获取当前的事件循环 + self.task_map = {} + + # def __new__(cls, *args, **kwargs): + # with cls._lock: + # if cls._instance is None: + # # 第一次创建实例时调用父类的 __new__ 来创建实例 + # cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # # 在此进行一次性的初始化,比如线程池的创建 + # max_workers = kwargs.get('max_workers', 10) + # cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + # cls._instance.task_map = {} # 初始化任务映射 + # return cls._instance def submit_task(self, fn, *args, thread_id=None, **kwargs): """提交任务到线程池,并记录线程 ID""" @@ -32,7 +50,9 @@ thread_id = generate_thread_id() if self.check_task_is_running(thread_id): raise ValueError(f"线程 ID {thread_id} 已存在") - future = self.executor.submit(fn, *args, **kwargs) + # future = self.executor.submit(fn, *args, **kwargs) + future = self.loop.run_in_executor(None, wrapper, fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id @@ -101,4 +121,4 @@ result = future.result() except Exception as e: logger.error(f"Task for thread {thread_id} raised an exception: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file + logger.error(f"Traceback: {traceback.format_exc()}") diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index cdb97bf..8724c0d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -39,7 +39,7 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 @@ -47,12 +47,12 @@ self.push_ts = None self.thread_id = thread_id - with next(get_db()) as db: - self.device_frame_service = DeviceFrameService(db) - self.frame_analysis_result_service = FrameAnalysisResultService(db) + self.device_frame_service = DeviceFrameService(db_session) + self.frame_analysis_result_service = FrameAnalysisResultService(db_session) self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) @@ -76,8 +76,15 @@ def save_frame_results(self, frame, results_map): if not self.check_frame_interval(): return - device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id + + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): @@ -92,7 +99,10 @@ location=r.location, ) frame_results.append(frame_result) - self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): diff --git a/algo/model_manager.py b/algo/model_manager.py index e3bc4c9..1c30867 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -25,18 +25,18 @@ self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - def query_model_inuse(self): - algo_model_list = list(self.model_service.get_models_in_use()) + async def query_model_inuse(self): + algo_model_list = list(await self.model_service.get_models_in_use()) for algo_model in algo_model_list: self.models[algo_model.id] = AlgoModelExec( algo_model_id=algo_model.id, algo_model_info=algo_model ) - def load_models(self): + async def load_models(self): logger.info('loading models') self.models = {} - self.query_model_inuse() + await self.query_model_inuse() if not self.models: logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): @@ -69,8 +69,8 @@ algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) - def load_new_model(self, model_id): - algo_model = self.model_service.get_model_by_id(model_id) + async def load_new_model(self, model_id): + algo_model = await self.model_service.get_model_by_id(model_id) if algo_model: algo_model_exec = AlgoModelExec( algo_model_id=algo_model.id, diff --git a/algo/scene_runner.py b/algo/scene_runner.py index 2fb2b2b..e5c067e 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import copy import uuid @@ -7,6 +8,7 @@ from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import get_class, camel_to_snake +from db.database import get_db from entity.device import Device from scene_handler.base_scene_handler import BaseSceneHandler from services.device_scene_relation_service import DeviceSceneRelationService @@ -38,16 +40,16 @@ # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): - self.load_and_start_devices() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -59,18 +61,20 @@ return thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' - scene = self.relation_service.get_device_scene(device.id) - if scene: - self.device_scene_relations[device.id] = scene - try: - handle_task_name = scene.scene_handle_task - handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) - future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) - self.device_tasks[device.id] = handler_instance - logger.info(f'start thread {thread_id}, device info: {device}') - except Exception as e: - logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + async for db in get_db(): + relation_service = DeviceSceneRelationService(db) + scene = await relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -104,20 +108,20 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): try: self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) except Exception as e: logger.error(e) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -126,9 +130,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_scene_change(self, scene_id, change_type): + async def on_scene_change(self, scene_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.SCENE_UPDATE: @@ -137,10 +141,10 @@ if relation_info.scene_id == scene_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/app_instance.py b/app_instance.py index f9d1ad9..33f5e80 100644 --- a/app_instance.py +++ b/app_instance.py @@ -56,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -66,9 +66,7 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - # await scene_runner.start() - - + await scene_runner.start() yield # 允许请求处理 diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 7f3be9f..ef9a356 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import asyncio import traceback import uuid from concurrent.futures import ThreadPoolExecutor @@ -11,20 +12,37 @@ return str(uuid.uuid4()) +def wrapper(func, *args, **kwargs): + return func(*args, **kwargs) + + class GlobalThreadPool: _instance = None _lock = threading.Lock() - def __new__(cls, *args, **kwargs): - with cls._lock: - if cls._instance is None: - # 第一次创建实例时调用父类的 __new__ 来创建实例 - cls._instance = super(GlobalThreadPool, cls).__new__(cls) - # 在此进行一次性的初始化,比如线程池的创建 - max_workers = kwargs.get('max_workers', 10) - cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) - cls._instance.task_map = {} # 初始化任务映射 - return cls._instance + def __new__(cls, max_workers=5): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + cls._instance._initialize(max_workers) + return cls._instance + + def _initialize(self, max_workers): + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.loop = asyncio.get_running_loop() # 获取当前的事件循环 + self.task_map = {} + + # def __new__(cls, *args, **kwargs): + # with cls._lock: + # if cls._instance is None: + # # 第一次创建实例时调用父类的 __new__ 来创建实例 + # cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # # 在此进行一次性的初始化,比如线程池的创建 + # max_workers = kwargs.get('max_workers', 10) + # cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + # cls._instance.task_map = {} # 初始化任务映射 + # return cls._instance def submit_task(self, fn, *args, thread_id=None, **kwargs): """提交任务到线程池,并记录线程 ID""" @@ -32,7 +50,9 @@ thread_id = generate_thread_id() if self.check_task_is_running(thread_id): raise ValueError(f"线程 ID {thread_id} 已存在") - future = self.executor.submit(fn, *args, **kwargs) + # future = self.executor.submit(fn, *args, **kwargs) + future = self.loop.run_in_executor(None, wrapper, fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id @@ -101,4 +121,4 @@ result = future.result() except Exception as e: logger.error(f"Task for thread {thread_id} raised an exception: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file + logger.error(f"Traceback: {traceback.format_exc()}") diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8e0ca4b..8797f0d 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -18,7 +18,7 @@ verbose=False, stream=True) results = list(results_generator) # 确保生成器转换为列表 result = results[0] - logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") + # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") return result def post_process(self, frame, model_result, annotator): diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index cdb97bf..8724c0d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -39,7 +39,7 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 @@ -47,12 +47,12 @@ self.push_ts = None self.thread_id = thread_id - with next(get_db()) as db: - self.device_frame_service = DeviceFrameService(db) - self.frame_analysis_result_service = FrameAnalysisResultService(db) + self.device_frame_service = DeviceFrameService(db_session) + self.frame_analysis_result_service = FrameAnalysisResultService(db_session) self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) @@ -76,8 +76,15 @@ def save_frame_results(self, frame, results_map): if not self.check_frame_interval(): return - device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id + + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): @@ -92,7 +99,10 @@ location=r.location, ) frame_results.append(frame_result) - self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): diff --git a/algo/model_manager.py b/algo/model_manager.py index e3bc4c9..1c30867 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -25,18 +25,18 @@ self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - def query_model_inuse(self): - algo_model_list = list(self.model_service.get_models_in_use()) + async def query_model_inuse(self): + algo_model_list = list(await self.model_service.get_models_in_use()) for algo_model in algo_model_list: self.models[algo_model.id] = AlgoModelExec( algo_model_id=algo_model.id, algo_model_info=algo_model ) - def load_models(self): + async def load_models(self): logger.info('loading models') self.models = {} - self.query_model_inuse() + await self.query_model_inuse() if not self.models: logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): @@ -69,8 +69,8 @@ algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) - def load_new_model(self, model_id): - algo_model = self.model_service.get_model_by_id(model_id) + async def load_new_model(self, model_id): + algo_model = await self.model_service.get_model_by_id(model_id) if algo_model: algo_model_exec = AlgoModelExec( algo_model_id=algo_model.id, diff --git a/algo/scene_runner.py b/algo/scene_runner.py index 2fb2b2b..e5c067e 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import copy import uuid @@ -7,6 +8,7 @@ from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import get_class, camel_to_snake +from db.database import get_db from entity.device import Device from scene_handler.base_scene_handler import BaseSceneHandler from services.device_scene_relation_service import DeviceSceneRelationService @@ -38,16 +40,16 @@ # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): - self.load_and_start_devices() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -59,18 +61,20 @@ return thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' - scene = self.relation_service.get_device_scene(device.id) - if scene: - self.device_scene_relations[device.id] = scene - try: - handle_task_name = scene.scene_handle_task - handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) - future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) - self.device_tasks[device.id] = handler_instance - logger.info(f'start thread {thread_id}, device info: {device}') - except Exception as e: - logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + async for db in get_db(): + relation_service = DeviceSceneRelationService(db) + scene = await relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -104,20 +108,20 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): try: self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) except Exception as e: logger.error(e) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -126,9 +130,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_scene_change(self, scene_id, change_type): + async def on_scene_change(self, scene_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.SCENE_UPDATE: @@ -137,10 +141,10 @@ if relation_info.scene_id == scene_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/app_instance.py b/app_instance.py index f9d1ad9..33f5e80 100644 --- a/app_instance.py +++ b/app_instance.py @@ -56,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -66,9 +66,7 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - # await scene_runner.start() - - + await scene_runner.start() yield # 允许请求处理 diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 7f3be9f..ef9a356 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import asyncio import traceback import uuid from concurrent.futures import ThreadPoolExecutor @@ -11,20 +12,37 @@ return str(uuid.uuid4()) +def wrapper(func, *args, **kwargs): + return func(*args, **kwargs) + + class GlobalThreadPool: _instance = None _lock = threading.Lock() - def __new__(cls, *args, **kwargs): - with cls._lock: - if cls._instance is None: - # 第一次创建实例时调用父类的 __new__ 来创建实例 - cls._instance = super(GlobalThreadPool, cls).__new__(cls) - # 在此进行一次性的初始化,比如线程池的创建 - max_workers = kwargs.get('max_workers', 10) - cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) - cls._instance.task_map = {} # 初始化任务映射 - return cls._instance + def __new__(cls, max_workers=5): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + cls._instance._initialize(max_workers) + return cls._instance + + def _initialize(self, max_workers): + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.loop = asyncio.get_running_loop() # 获取当前的事件循环 + self.task_map = {} + + # def __new__(cls, *args, **kwargs): + # with cls._lock: + # if cls._instance is None: + # # 第一次创建实例时调用父类的 __new__ 来创建实例 + # cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # # 在此进行一次性的初始化,比如线程池的创建 + # max_workers = kwargs.get('max_workers', 10) + # cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + # cls._instance.task_map = {} # 初始化任务映射 + # return cls._instance def submit_task(self, fn, *args, thread_id=None, **kwargs): """提交任务到线程池,并记录线程 ID""" @@ -32,7 +50,9 @@ thread_id = generate_thread_id() if self.check_task_is_running(thread_id): raise ValueError(f"线程 ID {thread_id} 已存在") - future = self.executor.submit(fn, *args, **kwargs) + # future = self.executor.submit(fn, *args, **kwargs) + future = self.loop.run_in_executor(None, wrapper, fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id @@ -101,4 +121,4 @@ result = future.result() except Exception as e: logger.error(f"Task for thread {thread_id} raised an exception: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file + logger.error(f"Traceback: {traceback.format_exc()}") diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8e0ca4b..8797f0d 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -18,7 +18,7 @@ verbose=False, stream=True) results = list(results_generator) # 确保生成器转换为列表 result = results[0] - logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") + # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") return result def post_process(self, frame, model_result, annotator): diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 1a468ae..929e39a 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -27,36 +27,36 @@ self.db = db async def add_frame(self, device_id, frame_data) -> DeviceFrame: - async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': - async def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = f"{uuid.uuid4()}.jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) - _, encoded_image = cv2.imencode('.jpeg', frame_data) - image_data = encoded_image.tobytes() + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 使用 aiofiles 进行异步写入 - async with aiofiles.open(save_path, 'wb') as f: - await f.write(image_data) + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - return save_path + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) - # 异步保存图片文件 - file_path = await save_frame_file() + return save_path - # 创建并保存到数据库中 - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - await self.db.commit() - await self.db.refresh(device_frame) - return device_frame + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame async def get_frame_page(self, device_name: Optional[str] = None, @@ -79,6 +79,7 @@ statement = statement.where(DeviceFrame.time >= frame_start_time) if frame_start_time: statement = statement.where(DeviceFrame.time <= frame_end_time) + statement = statement.order_by(DeviceFrame.time.desc()) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index cdb97bf..8724c0d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -39,7 +39,7 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 @@ -47,12 +47,12 @@ self.push_ts = None self.thread_id = thread_id - with next(get_db()) as db: - self.device_frame_service = DeviceFrameService(db) - self.frame_analysis_result_service = FrameAnalysisResultService(db) + self.device_frame_service = DeviceFrameService(db_session) + self.frame_analysis_result_service = FrameAnalysisResultService(db_session) self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) @@ -76,8 +76,15 @@ def save_frame_results(self, frame, results_map): if not self.check_frame_interval(): return - device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id + + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): @@ -92,7 +99,10 @@ location=r.location, ) frame_results.append(frame_result) - self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): diff --git a/algo/model_manager.py b/algo/model_manager.py index e3bc4c9..1c30867 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -25,18 +25,18 @@ self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - def query_model_inuse(self): - algo_model_list = list(self.model_service.get_models_in_use()) + async def query_model_inuse(self): + algo_model_list = list(await self.model_service.get_models_in_use()) for algo_model in algo_model_list: self.models[algo_model.id] = AlgoModelExec( algo_model_id=algo_model.id, algo_model_info=algo_model ) - def load_models(self): + async def load_models(self): logger.info('loading models') self.models = {} - self.query_model_inuse() + await self.query_model_inuse() if not self.models: logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): @@ -69,8 +69,8 @@ algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) - def load_new_model(self, model_id): - algo_model = self.model_service.get_model_by_id(model_id) + async def load_new_model(self, model_id): + algo_model = await self.model_service.get_model_by_id(model_id) if algo_model: algo_model_exec = AlgoModelExec( algo_model_id=algo_model.id, diff --git a/algo/scene_runner.py b/algo/scene_runner.py index 2fb2b2b..e5c067e 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import copy import uuid @@ -7,6 +8,7 @@ from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import get_class, camel_to_snake +from db.database import get_db from entity.device import Device from scene_handler.base_scene_handler import BaseSceneHandler from services.device_scene_relation_service import DeviceSceneRelationService @@ -38,16 +40,16 @@ # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): - self.load_and_start_devices() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -59,18 +61,20 @@ return thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' - scene = self.relation_service.get_device_scene(device.id) - if scene: - self.device_scene_relations[device.id] = scene - try: - handle_task_name = scene.scene_handle_task - handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) - future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) - self.device_tasks[device.id] = handler_instance - logger.info(f'start thread {thread_id}, device info: {device}') - except Exception as e: - logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + async for db in get_db(): + relation_service = DeviceSceneRelationService(db) + scene = await relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -104,20 +108,20 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): try: self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) except Exception as e: logger.error(e) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -126,9 +130,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_scene_change(self, scene_id, change_type): + async def on_scene_change(self, scene_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.SCENE_UPDATE: @@ -137,10 +141,10 @@ if relation_info.scene_id == scene_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/app_instance.py b/app_instance.py index f9d1ad9..33f5e80 100644 --- a/app_instance.py +++ b/app_instance.py @@ -56,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -66,9 +66,7 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - # await scene_runner.start() - - + await scene_runner.start() yield # 允许请求处理 diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 7f3be9f..ef9a356 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import asyncio import traceback import uuid from concurrent.futures import ThreadPoolExecutor @@ -11,20 +12,37 @@ return str(uuid.uuid4()) +def wrapper(func, *args, **kwargs): + return func(*args, **kwargs) + + class GlobalThreadPool: _instance = None _lock = threading.Lock() - def __new__(cls, *args, **kwargs): - with cls._lock: - if cls._instance is None: - # 第一次创建实例时调用父类的 __new__ 来创建实例 - cls._instance = super(GlobalThreadPool, cls).__new__(cls) - # 在此进行一次性的初始化,比如线程池的创建 - max_workers = kwargs.get('max_workers', 10) - cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) - cls._instance.task_map = {} # 初始化任务映射 - return cls._instance + def __new__(cls, max_workers=5): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + cls._instance._initialize(max_workers) + return cls._instance + + def _initialize(self, max_workers): + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.loop = asyncio.get_running_loop() # 获取当前的事件循环 + self.task_map = {} + + # def __new__(cls, *args, **kwargs): + # with cls._lock: + # if cls._instance is None: + # # 第一次创建实例时调用父类的 __new__ 来创建实例 + # cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # # 在此进行一次性的初始化,比如线程池的创建 + # max_workers = kwargs.get('max_workers', 10) + # cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + # cls._instance.task_map = {} # 初始化任务映射 + # return cls._instance def submit_task(self, fn, *args, thread_id=None, **kwargs): """提交任务到线程池,并记录线程 ID""" @@ -32,7 +50,9 @@ thread_id = generate_thread_id() if self.check_task_is_running(thread_id): raise ValueError(f"线程 ID {thread_id} 已存在") - future = self.executor.submit(fn, *args, **kwargs) + # future = self.executor.submit(fn, *args, **kwargs) + future = self.loop.run_in_executor(None, wrapper, fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id @@ -101,4 +121,4 @@ result = future.result() except Exception as e: logger.error(f"Task for thread {thread_id} raised an exception: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file + logger.error(f"Traceback: {traceback.format_exc()}") diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8e0ca4b..8797f0d 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -18,7 +18,7 @@ verbose=False, stream=True) results = list(results_generator) # 确保生成器转换为列表 result = results[0] - logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") + # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") return result def post_process(self, frame, model_result, annotator): diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 1a468ae..929e39a 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -27,36 +27,36 @@ self.db = db async def add_frame(self, device_id, frame_data) -> DeviceFrame: - async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': - async def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = f"{uuid.uuid4()}.jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) - _, encoded_image = cv2.imencode('.jpeg', frame_data) - image_data = encoded_image.tobytes() + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 使用 aiofiles 进行异步写入 - async with aiofiles.open(save_path, 'wb') as f: - await f.write(image_data) + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - return save_path + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) - # 异步保存图片文件 - file_path = await save_frame_file() + return save_path - # 创建并保存到数据库中 - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - await self.db.commit() - await self.db.refresh(device_frame) - return device_frame + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame async def get_frame_page(self, device_name: Optional[str] = None, @@ -79,6 +79,7 @@ statement = statement.where(DeviceFrame.time >= frame_start_time) if frame_start_time: statement = statement.where(DeviceFrame.time <= frame_end_time) + statement = statement.order_by(DeviceFrame.time.desc()) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) diff --git a/services/scene_service.py b/services/scene_service.py index dc4c67b..3fbdf0a 100644 --- a/services/scene_service.py +++ b/services/scene_service.py @@ -204,7 +204,7 @@ DeviceSceneRelation.scene_id == scene_id, ) ) - result = await self.db.exec(statement) + result = await self.db.execute(statement) rows = result.all() return len(rows) > 0 diff --git a/algo/algo_runner.py b/algo/algo_runner.py index 40b8d1e..3541a1c 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent.futures import copy import uuid @@ -14,9 +15,10 @@ from services.model_service import ModelService from common.global_logger import logger + class AlgoRunner: def __init__(self, - device_service : DeviceService, + device_service: DeviceService, model_service: ModelService, relation_service: DeviceModelRelationService): @@ -27,8 +29,11 @@ self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 + self.task_futures: Dict[str, concurrent.futures.Future] = {} # 用于存储线程的 future 对象 self.device_model_relations = {} + self.main_loop = asyncio.get_running_loop() + # 注册设备和模型的变化回调 # self.device_service.register_change_callback(self.on_device_change) # self.model_service.register_change_callback(self.on_model_change) @@ -37,16 +42,16 @@ async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" - self.model_manager.load_models() - self.load_and_start_devices() + await self.model_manager.load_models() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -65,30 +70,27 @@ # return # 获取设备绑定的模型列表 - relations = self.relation_service.get_device_models(device.id) - relations = [r for r in relations if r.is_use] - self.device_model_relations[device.id] = relations + async for db in get_db(): + relation_service = DeviceModelRelationService(db) + relations = await relation_service.get_device_models(device.id) + relations = [r for r in relations if r.is_use] + self.device_model_relations[device.id] = relations - if not relations: - logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") - return + if not relations: + logger.info(f"设备 {device.code} 未绑定模型,无法启动检测") + return - for r in relations: - if r.algo_model_id not in self.model_manager.models: - self.model_manager.load_new_model(r.algo_model_id) + for r in relations: + if r.algo_model_id not in self.model_manager.models: + await self.model_manager.load_new_model(r.algo_model_id) - model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations - if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] - device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, thread_id=thread_id) - future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) - self.device_tasks[device.id] = device_detection_task - - # def stop_device_thread(self, device_id): - # """停止指定设备的检测线程""" - # if device_id in self.device_tasks: - # logger.info(f'stop device {device_id} thread') - # self.device_tasks[device_id].stop_detection_task() - # del self.device_tasks[device_id] + model_exec_list = [self.model_manager.models[r.algo_model_id] for r in relations + if self.model_manager.models[r.algo_model_id].algo_model_exec is not None] + device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list, + thread_id=thread_id, db_session=db, main_loop=self.main_loop) + future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id) + self.device_tasks[device.id] = device_detection_task + self.task_futures[thread_id] = future def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -96,7 +98,8 @@ logger.info(f'stop device {device_id} thread') # 获取线程 ID 和 future 对象 thread_id = self.device_tasks[device_id].thread_id - future = self.thread_pool.get_task_future(thread_id=thread_id) + # future = self.thread_pool.get_task_future(thread_id=thread_id) + future = self.task_futures[thread_id] if future: if not future.done(): @@ -122,17 +125,17 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -141,9 +144,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_model_change(self, model_id, change_type): + async def on_model_change(self, model_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.MODEL_UPDATE: self.model_manager.reload_model(model_id) @@ -153,10 +156,10 @@ if relation_info.model_id == model_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # todo 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index cdb97bf..8724c0d 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -39,7 +39,7 @@ class DeviceDetectionTask: - def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id): + def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop): self.device = device self.model_exec_list = model_exec_list self.__stop_event = Event() # 使用 Event 控制线程的运行状态 @@ -47,12 +47,12 @@ self.push_ts = None self.thread_id = thread_id - with next(get_db()) as db: - self.device_frame_service = DeviceFrameService(db) - self.frame_analysis_result_service = FrameAnalysisResultService(db) + self.device_frame_service = DeviceFrameService(db_session) + self.frame_analysis_result_service = FrameAnalysisResultService(db_session) self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, device_thread_id=thread_id) @@ -76,8 +76,15 @@ def save_frame_results(self, frame, results_map): if not self.check_frame_interval(): return - device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # device_frame = self.device_frame_service.add_frame(self.device.id, frame) + # frame_id = device_frame.id + + future = asyncio.run_coroutine_threadsafe( + self.device_frame_service.add_frame(self.device.id, frame), self.main_loop + ) + device_frame = future.result() frame_id = device_frame.id + logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}') frame_results = [] for model_exec_id, results in results_map.items(): @@ -92,7 +99,10 @@ location=r.location, ) frame_results.append(frame_result) - self.frame_analysis_result_service.add_frame_analysis_results(frame_results) + asyncio.run_coroutine_threadsafe( + self.frame_analysis_result_service.add_frame_analysis_results(frame_results), + self.main_loop + ) self.thread_pool.submit_task(self.push_frame_results, frame_results) def push_frame_results(self, frame_results): diff --git a/algo/model_manager.py b/algo/model_manager.py index e3bc4c9..1c30867 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -25,18 +25,18 @@ self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up - def query_model_inuse(self): - algo_model_list = list(self.model_service.get_models_in_use()) + async def query_model_inuse(self): + algo_model_list = list(await self.model_service.get_models_in_use()) for algo_model in algo_model_list: self.models[algo_model.id] = AlgoModelExec( algo_model_id=algo_model.id, algo_model_info=algo_model ) - def load_models(self): + async def load_models(self): logger.info('loading models') self.models = {} - self.query_model_inuse() + await self.query_model_inuse() if not self.models: logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): @@ -69,8 +69,8 @@ algo_model_exec.algo_model_exec = None self.load_model(algo_model_exec) - def load_new_model(self, model_id): - algo_model = self.model_service.get_model_by_id(model_id) + async def load_new_model(self, model_id): + algo_model = await self.model_service.get_model_by_id(model_id) if algo_model: algo_model_exec = AlgoModelExec( algo_model_id=algo_model.id, diff --git a/algo/scene_runner.py b/algo/scene_runner.py index 2fb2b2b..e5c067e 100644 --- a/algo/scene_runner.py +++ b/algo/scene_runner.py @@ -1,3 +1,4 @@ +import asyncio import concurrent import copy import uuid @@ -7,6 +8,7 @@ from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.string_utils import get_class, camel_to_snake +from db.database import get_db from entity.device import Device from scene_handler.base_scene_handler import BaseSceneHandler from services.device_scene_relation_service import DeviceSceneRelationService @@ -38,16 +40,16 @@ # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): - self.load_and_start_devices() + await self.load_and_start_devices() - def load_and_start_devices(self): + async def load_and_start_devices(self): """从数据库读取设备列表并启动线程""" - devices = self.device_service.get_device_list() + devices = await self.device_service.get_device_list() devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] for device in devices: - self.start_device_thread(device) + asyncio.create_task(self.start_device_thread(device)) - def start_device_thread(self, device: Device): + async def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -59,18 +61,20 @@ return thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' - scene = self.relation_service.get_device_scene(device.id) - if scene: - self.device_scene_relations[device.id] = scene - try: - handle_task_name = scene.scene_handle_task - handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) - handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) - future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) - self.device_tasks[device.id] = handler_instance - logger.info(f'start thread {thread_id}, device info: {device}') - except Exception as e: - logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + async for db in get_db(): + relation_service = DeviceSceneRelationService(db) + scene = await relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") def stop_device_thread(self, device_id): """停止指定设备的检测线程,并确保其成功停止""" @@ -104,20 +108,20 @@ else: logger.warning(f"No task exists for device {device_id} in device_tasks.") - def restart_device_thread(self, device_id): + async def restart_device_thread(self, device_id): try: self.stop_device_thread(device_id) - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + asyncio.create_task(self.start_device_thread(device)) except Exception as e: logger.error(e) - def on_device_change(self, device_id, change_type): + async def on_device_change(self, device_id, change_type): logger.info(f"on device change, device {device_id} {change_type}") """设备变化回调""" if change_type == NotifyChangeType.DEVICE_CREATE: - device = self.device_service.get_device(device_id) - self.start_device_thread(device) + device = await self.device_service.get_device(device_id) + await self.start_device_thread(device) elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: @@ -126,9 +130,9 @@ old_device = self.device_tasks[device_id].device new_device = self.device_service.get_device(device_id) - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_scene_change(self, scene_id, change_type): + async def on_scene_change(self, scene_id, change_type): """模型变化回调""" if change_type == NotifyChangeType.SCENE_UPDATE: @@ -137,10 +141,10 @@ if relation_info.scene_id == scene_id ] for device_id in devices_to_reload: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) - def on_relation_change(self, device_id, change_type): + async def on_relation_change(self, device_id, change_type): """设备模型关系变化回调""" # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: - self.restart_device_thread(device_id) + await self.restart_device_thread(device_id) diff --git a/app_instance.py b/app_instance.py index f9d1ad9..33f5e80 100644 --- a/app_instance.py +++ b/app_instance.py @@ -56,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -66,9 +66,7 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - # await scene_runner.start() - - + await scene_runner.start() yield # 允许请求处理 diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py index 7f3be9f..ef9a356 100644 --- a/common/global_thread_pool.py +++ b/common/global_thread_pool.py @@ -1,3 +1,4 @@ +import asyncio import traceback import uuid from concurrent.futures import ThreadPoolExecutor @@ -11,20 +12,37 @@ return str(uuid.uuid4()) +def wrapper(func, *args, **kwargs): + return func(*args, **kwargs) + + class GlobalThreadPool: _instance = None _lock = threading.Lock() - def __new__(cls, *args, **kwargs): - with cls._lock: - if cls._instance is None: - # 第一次创建实例时调用父类的 __new__ 来创建实例 - cls._instance = super(GlobalThreadPool, cls).__new__(cls) - # 在此进行一次性的初始化,比如线程池的创建 - max_workers = kwargs.get('max_workers', 10) - cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) - cls._instance.task_map = {} # 初始化任务映射 - return cls._instance + def __new__(cls, max_workers=5): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + cls._instance._initialize(max_workers) + return cls._instance + + def _initialize(self, max_workers): + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.loop = asyncio.get_running_loop() # 获取当前的事件循环 + self.task_map = {} + + # def __new__(cls, *args, **kwargs): + # with cls._lock: + # if cls._instance is None: + # # 第一次创建实例时调用父类的 __new__ 来创建实例 + # cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # # 在此进行一次性的初始化,比如线程池的创建 + # max_workers = kwargs.get('max_workers', 10) + # cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + # cls._instance.task_map = {} # 初始化任务映射 + # return cls._instance def submit_task(self, fn, *args, thread_id=None, **kwargs): """提交任务到线程池,并记录线程 ID""" @@ -32,7 +50,9 @@ thread_id = generate_thread_id() if self.check_task_is_running(thread_id): raise ValueError(f"线程 ID {thread_id} 已存在") - future = self.executor.submit(fn, *args, **kwargs) + # future = self.executor.submit(fn, *args, **kwargs) + future = self.loop.run_in_executor(None, wrapper, fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 future.add_done_callback(lambda f: self._handle_exception(f, thread_id)) return thread_id @@ -101,4 +121,4 @@ result = future.result() except Exception as e: logger.error(f"Task for thread {thread_id} raised an exception: {e}") - logger.error(f"Traceback: {traceback.format_exc()}") \ No newline at end of file + logger.error(f"Traceback: {traceback.format_exc()}") diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8e0ca4b..8797f0d 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -18,7 +18,7 @@ verbose=False, stream=True) results = list(results_generator) # 确保生成器转换为列表 result = results[0] - logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") + # logger.debug(f"model {self.model.algo_model_info.name} result: {len(result)}") return result def post_process(self, frame, model_result, annotator): diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 1a468ae..929e39a 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -27,36 +27,36 @@ self.db = db async def add_frame(self, device_id, frame_data) -> DeviceFrame: - async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': - async def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = f"{uuid.uuid4()}.jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) - _, encoded_image = cv2.imencode('.jpeg', frame_data) - image_data = encoded_image.tobytes() + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 使用 aiofiles 进行异步写入 - async with aiofiles.open(save_path, 'wb') as f: - await f.write(image_data) + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - return save_path + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) - # 异步保存图片文件 - file_path = await save_frame_file() + return save_path - # 创建并保存到数据库中 - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - await self.db.commit() - await self.db.refresh(device_frame) - return device_frame + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame async def get_frame_page(self, device_name: Optional[str] = None, @@ -79,6 +79,7 @@ statement = statement.where(DeviceFrame.time >= frame_start_time) if frame_start_time: statement = statement.where(DeviceFrame.time <= frame_end_time) + statement = statement.order_by(DeviceFrame.time.desc()) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) diff --git a/services/scene_service.py b/services/scene_service.py index dc4c67b..3fbdf0a 100644 --- a/services/scene_service.py +++ b/services/scene_service.py @@ -204,7 +204,7 @@ DeviceSceneRelation.scene_id == scene_id, ) ) - result = await self.db.exec(statement) + result = await self.db.execute(statement) rows = result.all() return len(rows) > 0 diff --git a/tcp/tcp_manager.py b/tcp/tcp_manager.py index 094aa6a..f7ed12e 100644 --- a/tcp/tcp_manager.py +++ b/tcp/tcp_manager.py @@ -80,10 +80,10 @@ -if __name__ == '__main__': - async for db in get_db(): - global_config = GlobalConfig() - await global_config.init_config() - device_service = DeviceService(db) - tcp_manager = TcpManager(device_service) - asyncio.run(tcp_manager.start()) +# if __name__ == '__main__': + # async for db in get_db(): + # global_config = GlobalConfig() + # await global_config.init_config() + # device_service = DeviceService(db) + # tcp_manager = TcpManager(device_service) + # asyncio.run(tcp_manager.start())