from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
from common.consts import NotifyChangeType
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
from services.device_model_relation_service import DeviceModelRelationService
from services.device_service import DeviceService
from services.model_service import ModelService
from common.global_logger import logger


class AlgoRunner:
    """Manage one DeviceDetectionTask per device on the shared GlobalThreadPool.

    ``start()`` loads the models and launches a task for every device; the
    registered change callbacks then stop/restart the affected device tasks
    when devices, models, or device-model relations change.
    """

    def __init__(self):
        # NOTE(review): the services keep using ``db`` after this ``with`` block
        # exits; this relies on the session still being usable afterwards —
        # confirm against get_db() / the session implementation.
        with next(get_db()) as db:
            self.device_service = DeviceService(db)
            self.model_service = ModelService(db)
            self.relation_service = DeviceModelRelationService(db)
            self.model_manager = ModelManager(db)

        self.thread_pool = GlobalThreadPool()
        # device_id -> DeviceDetectionTask submitted for that device
        self.device_tasks = {}
        # device_id -> list of in-use relations the device's task was started with
        self.device_model_relations = {}

        # Register callbacks for device, model and device-model relation changes.
        self.device_service.register_change_callback(self.on_device_change)
        self.model_service.register_change_callback(self.on_model_change)
        self.relation_service.register_change_callback(self.on_relation_change)

    def start(self):
        """Load all models, then start a detection task for every device.

        Intended to be called once at program startup.
        """
        logger.info("Starting AlgoRunner...")
        self.model_manager.load_models()
        self.load_and_start_devices()

    def load_and_start_devices(self):
        """Read the device list from the database and start a task for each device."""
        devices = self.device_service.get_device_list()
        for device in devices:
            self.start_device_thread(device)

    @staticmethod
    def get_device_thread_id(device_id):
        """Return the thread-pool task id used for ``device_id``."""
        return f'device_{device_id}'

    def start_device_thread(self, device: Device):
        """Start the detection task for a single device.

        Logs and returns without starting anything when the device is missing
        (``None``), has no input stream URL, already has a running task, or has
        no in-use model bindings.
        """
        if device is None:
            # The device can disappear between a notification and this call.
            logger.warning('Device not found, cannot start detection thread')
            return
        if not device.input_stream_url:
            logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
            return

        thread_id = self.get_device_thread_id(device.id)
        # NOTE(review): immediately after stop_device_thread() the previous task
        # may still be winding down; if the pool still reports it as running,
        # a restart is silently skipped here — confirm stop/pool semantics.
        if self.thread_pool.check_task_is_running(thread_id=thread_id):
            logger.info(f"设备 {device.code} 已经在运行中")
            return

        # Fetch the models bound to this device; only in-use bindings count.
        relations = self.relation_service.get_device_models(device.id)
        relations = [r for r in relations if r.is_use]
        # Remember the bindings; on_model_change uses them to find affected devices.
        self.device_model_relations[device.id] = relations
        if not relations:
            logger.info(f"设备 {device.code} 未绑定模型,无法启动检测")
            return

        models = self.model_manager.models
        model_exec_list = []
        for relation in relations:
            model = models[relation.algo_model_id]
            # Only models that have an executable are handed to the task.
            if model.algo_model_exec is not None:
                model_exec_list.append(model)

        device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list)
        self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id)
        self.device_tasks[device.id] = device_detection_task

    def stop_device_thread(self, device_id):
        """Signal the detection task of ``device_id`` to stop and forget the device.

        This only calls ``stop_detection_task()``; it does not wait for the
        worker thread to finish. Unknown device ids are ignored.
        """
        if device_id in self.device_tasks:
            self.device_tasks[device_id].stop_detection_task()
            del self.device_tasks[device_id]
        # Drop cached bindings so on_model_change no longer restarts a device
        # that was stopped (e.g. because it was deleted).
        self.device_model_relations.pop(device_id, None)

    def restart_device_thread(self, device_id):
        """Stop the device's task, reload the device, and start it again."""
        self.stop_device_thread(device_id)
        # TODO: is a shutdown grace period needed before restarting? See the
        # NOTE on the running-task check in start_device_thread.
        device = self.device_service.get_device(device_id)
        self.start_device_thread(device)

    def on_device_change(self, device_id, change_type):
        """Device change callback: start, stop, or restart the device's task."""
        if change_type == NotifyChangeType.DEVICE_CREATE:
            device = self.device_service.get_device(device_id)
            self.start_device_thread(device)
        elif change_type == NotifyChangeType.DEVICE_DELETE:
            self.stop_device_thread(device_id)
        elif change_type == NotifyChangeType.DEVICE_UPDATE:
            # TODO: if only device attributes changed (not the video stream or
            # the models), consider whether a full restart is really needed.
            self.restart_device_thread(device_id)

    def on_model_change(self, model_id, change_type):
        """Model change callback: on update, reload the model and restart devices using it."""
        if change_type == NotifyChangeType.MODEL_UPDATE:
            self.model_manager.reload_model(model_id)
            # Each value is the list of relations a device was started with.
            # Materialize the ids first: restarting mutates device_model_relations.
            devices_to_reload = [
                device_id
                for device_id, relations in self.device_model_relations.items()
                if any(r.algo_model_id == model_id for r in relations)
            ]
            for device_id in devices_to_reload:
                self.restart_device_thread(device_id)

    def on_relation_change(self, device_id, change_type):
        """Device-model relation change callback: restart the device's task."""
        # TODO: check whether the bindings actually changed to avoid needless restarts.
        if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
            self.restart_device_thread(device_id)