import asyncio
import concurrent.futures
import copy
import inspect
import uuid
from typing import Dict, List, Set

from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
from services.device_model_relation_service import DeviceModelRelationService
from services.device_service import DeviceService
from services.model_service import ModelService
from common.global_logger import logger


class AlgoRunner:
    """Start, stop and restart one detection task per device.

    Each eligible device gets a ``DeviceDetectionTask`` that is submitted to
    the shared ``GlobalThreadPool``.  The ``on_*_change`` coroutines react to
    device / model / relation change notifications by (re)starting or
    stopping the affected tasks.
    """

    # Seconds to wait for a detection task to finish after asking it to stop.
    STOP_TIMEOUT_SECONDS = 30

    def __init__(self,
                 device_service: DeviceService,
                 model_service: ModelService,
                 relation_service: DeviceModelRelationService):
        """Wire up the services and create the empty bookkeeping tables.

        Must be called while an asyncio event loop is running, because the
        running loop is captured via ``asyncio.get_running_loop()`` (which
        raises ``RuntimeError`` otherwise) and later handed to every
        detection task.
        """
        self.device_service = device_service
        self.model_service = model_service
        self.relation_service = relation_service
        self.model_manager = ModelManager(model_service)
        self.thread_pool = GlobalThreadPool()

        # device id -> detection task currently registered for that device
        self.device_tasks: Dict[int, DeviceDetectionTask] = {}
        # thread id -> future returned by the thread pool for that task
        self.task_futures: Dict[str, concurrent.futures.Future] = {}
        # device id -> list of enabled device/model relations used at start
        self.device_model_relations: Dict[int, List] = {}
        # Strong references to fire-and-forget asyncio tasks; the event loop
        # itself only keeps weak references to tasks.
        self._background_tasks: Set[asyncio.Task] = set()

        self.main_loop = asyncio.get_running_loop()

        # Change-callback registration is currently disabled:
        # self.device_service.register_change_callback(self.on_device_change)
        # self.model_service.register_change_callback(self.on_model_change)
        # self.relation_service.register_change_callback(self.on_relation_change)

    def _spawn_device_start(self, device) -> asyncio.Task:
        """Schedule ``start_device_thread(device)`` as a tracked background task."""
        task = asyncio.create_task(self.start_device_thread(device))
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Drop the finished background task and log its failure, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f"Device start task failed: {exc!r}")

    def _forget_device(self, device_id, thread_id) -> None:
        """Remove every bookkeeping entry kept for a stopped device task."""
        self.device_tasks.pop(device_id, None)
        self.task_futures.pop(thread_id, None)
        self.device_model_relations.pop(device_id, None)

    async def start(self):
        """Load the models, then start detection for every known device.

        Intended to be called once at program start-up.
        """
        logger.info("Starting AlgoRunner...")
        await self.model_manager.load_models()
        await self.load_and_start_devices()

    async def load_and_start_devices(self):
        """Fetch the device list and schedule a start for each device.

        The individual starts run as background tasks; this coroutine
        returns as soon as they have been scheduled.
        """
        devices = await self.device_service.get_device_list()
        for device in devices:
            self._spawn_device_start(device)

    async def start_device_thread(self, device: Device):
        """Start the detection task for a single device.

        Nothing is started when the device is missing, has no input stream
        URL, is not in ``DEVICE_MODE.ALGO`` mode, or has no enabled model
        relation.  The device is deep-copied first, so the task works on its
        own copy of the object.
        """
        if device is None:
            logger.warning("start_device_thread called without a device, skipping.")
            return

        device = copy.deepcopy(device)
        if not device.input_stream_url:
            logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
            return
        if device.mode != DEVICE_MODE.ALGO:
            return

        thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
        logger.info(f'start thread {thread_id}, device info: {device}')

        # NOTE(review): the session yielded by get_db() is handed to the
        # detection task, which runs in a pool thread and presumably outlives
        # this loop body - confirm the session lifetime against get_db().
        async for db in get_db():
            # Look up the models bound to this device, keeping enabled ones.
            relation_service = DeviceModelRelationService(db)
            relations = await relation_service.get_device_models(device.id)
            relations = [r for r in relations if r.is_use]
            self.device_model_relations[device.id] = relations
            if not relations:
                logger.info(f"设备 {device.code} 未绑定模型,无法启动检测")
                return

            # Load any bound model that the manager does not hold yet.
            for relation in relations:
                if relation.algo_model_id not in self.model_manager.models:
                    await self.model_manager.load_new_model(relation.algo_model_id)

            # Keep only models that are loaded and have an executable.
            model_exec_list = []
            for relation in relations:
                model = self.model_manager.models.get(relation.algo_model_id)
                if model is None:
                    logger.warning(
                        f"Model {relation.algo_model_id} is not loaded, "
                        f"skipping it for device {device.id}."
                    )
                    continue
                if model.algo_model_exec is not None:
                    model_exec_list.append(model)

            device_detection_task = DeviceDetectionTask(
                device=device,
                model_exec_list=model_exec_list,
                thread_id=thread_id,
                db_session=db,
                main_loop=self.main_loop,
            )
            future = self.thread_pool.submit_task(device_detection_task.run,
                                                  thread_id=thread_id)
            self.device_tasks[device.id] = device_detection_task
            self.task_futures[thread_id] = future

    def stop_device_thread(self, device_id):
        """Stop the detection task of ``device_id`` and wait for it to end.

        Waits up to ``STOP_TIMEOUT_SECONDS`` for the task's future.  The
        bookkeeping entries for the device are removed whether or not the
        task stopped in time.

        NOTE(review): this is a blocking wait; when called from a coroutine
        it blocks the event loop for up to the timeout.
        """
        task = self.device_tasks.get(device_id)
        if task is None:
            logger.warning(f"No task exists for device {device_id} in device_tasks.")
            return

        logger.info(f'stop device {device_id} thread')
        thread_id = task.thread_id
        future = self.task_futures.get(thread_id)
        if future is None:
            logger.warning(f"No task found for {thread_id} .")
            return

        if future.done():
            logger.info(f"Task {thread_id} has already stopped.")
            self._forget_device(device_id, thread_id)
            return

        # The task is still running: ask it to stop, then wait for the future.
        task.stop_detection_task()
        try:
            future.result(timeout=self.STOP_TIMEOUT_SECONDS)
            logger.info(f"Task {thread_id} stopped successfully.")
        except concurrent.futures.TimeoutError:
            logger.error(f"Task {thread_id} did not stop within the timeout.")
        except Exception as e:
            logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
        finally:
            # Always unregister, even if the task failed to stop in time.
            self._forget_device(device_id, thread_id)

    async def restart_device_thread(self, device_id):
        """Stop the device's task, re-read the device and start it again.

        The start is scheduled as a background task.  If the device can no
        longer be found, only the stop is performed.
        """
        self.stop_device_thread(device_id)
        device = await self.device_service.get_device(device_id)
        if device is None:
            logger.warning(f"Device {device_id} not found, skip restart.")
            return
        self._spawn_device_start(device)

    async def on_device_change(self, device_id, change_type):
        """Device change callback: start, stop or restart the device's task."""
        logger.info(f"on device change, device {device_id} {change_type}")
        if change_type == NotifyChangeType.DEVICE_CREATE:
            device = await self.device_service.get_device(device_id)
            await self.start_device_thread(device)
        elif change_type == NotifyChangeType.DEVICE_DELETE:
            self.stop_device_thread(device_id)
        elif change_type == NotifyChangeType.DEVICE_UPDATE:
            # TODO: skip the restart when only attributes unrelated to the
            # video stream or the models changed.
            # Restarting also covers devices that had no running task before
            # the update; start_device_thread decides whether to run them.
            await self.restart_device_thread(device_id)

    async def on_model_change(self, model_id, change_type):
        """Model change callback: reload the model and restart its devices."""
        if change_type != NotifyChangeType.MODEL_UPDATE:
            return

        # reload_model may be a plain function or a coroutine function;
        # await the result only when it is awaitable.
        reload_result = self.model_manager.reload_model(model_id)
        if inspect.isawaitable(reload_result):
            await reload_result

        # device_model_relations maps a device id to a *list* of relations.
        # Snapshot the ids first: restarting mutates the mapping.
        devices_to_reload = [
            device_id
            for device_id, relations in self.device_model_relations.items()
            if any(r.algo_model_id == model_id for r in relations)
        ]
        for device_id in devices_to_reload:
            await self.restart_device_thread(device_id)

    async def on_relation_change(self, device_id, change_type):
        """Device/model relation change callback: restart the device's task."""
        # TODO: check whether the bindings really changed to avoid
        # unnecessary restarts.
        if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
            await self.restart_device_thread(device_id)