import copy
import uuid
from concurrent.futures import TimeoutError as FutureTimeoutError
from typing import Dict

from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
from common.consts import NotifyChangeType
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
from services.device_model_relation_service import DeviceModelRelationService
from services.device_service import DeviceService
from services.model_service import ModelService
from common.global_logger import logger


class AlgoRunner:
    """Start, stop and restart one detection task per device.

    Devices, models and device/model bindings are read through the services
    created in ``__init__``. Change callbacks registered with those services
    start, stop or restart the affected device tasks.
    """

    def __init__(self):
        # NOTE(review): the services are used later (start(), callbacks) after
        # this ``with`` block has exited; this presumably relies on the object
        # yielded by get_db() staying usable after __exit__ -- confirm.
        with next(get_db()) as db:
            self.device_service = DeviceService(db)
            self.model_service = ModelService(db)
            self.relation_service = DeviceModelRelationService(db)
            self.model_manager = ModelManager(db)

        self.thread_pool = GlobalThreadPool()
        # device id -> detection task submitted for that device
        self.device_tasks: Dict[int, DeviceDetectionTask] = {}
        # device id -> list of in-use device/model relations the task was started with
        self.device_model_relations: Dict[int, list] = {}

        # Register callbacks for device, model and binding changes.
        self.device_service.register_change_callback(self.on_device_change)
        self.model_service.register_change_callback(self.on_model_change)
        self.relation_service.register_change_callback(self.on_relation_change)

    def start(self):
        """Load all models, then start a detection task for every device.

        Intended to be called once at program start-up.
        """
        logger.info("Starting AlgoRunner...")
        self.model_manager.load_models()
        self.load_and_start_devices()

    def load_and_start_devices(self):
        """Read the device list from the database and start a task for each device."""
        devices = self.device_service.get_device_list()
        for device in devices:
            self.start_device_thread(device)

    def start_device_thread(self, device: Device):
        """Start a detection task for a single device.

        The device is deep-copied before being handed to the task. Nothing is
        started when the device is ``None`` (e.g. it was deleted before a
        restart), has no input stream URL, or has no in-use model bindings.

        Args:
            device: the device to start detection for.
        """
        if device is None:
            logger.warning('start_device_thread called without a device, skipping')
            return
        device = copy.deepcopy(device)
        if not device.input_stream_url:
            logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
            return

        # The uuid suffix gives every submission a distinct thread id, even for
        # the same device.
        thread_id = f'device_{device.id}_{uuid.uuid4()}'
        logger.info(f'start thread {thread_id}, device info: {device}')

        # Models bound to this device that are switched on.
        relations = self.relation_service.get_device_models(device.id)
        relations = [r for r in relations if r.is_use]
        self.device_model_relations[device.id] = relations
        if not relations:
            logger.info(f"设备 {device.code} 未绑定模型,无法启动检测")
            return

        for r in relations:
            if r.algo_model_id not in self.model_manager.models:
                self.model_manager.load_new_model(r.algo_model_id)

        # Skip models that are still missing (load failed) or have no executable,
        # instead of raising KeyError.
        loaded_models = self.model_manager.models
        model_exec_list = [
            loaded_models[r.algo_model_id]
            for r in relations
            if r.algo_model_id in loaded_models
            and loaded_models[r.algo_model_id].algo_model_exec is not None
        ]

        device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list,
                                                    thread_id=thread_id)
        self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id)
        self.device_tasks[device.id] = device_detection_task

    def stop_device_thread(self, device_id, timeout: float = 10):
        """Stop a device's detection task and wait for it to finish.

        The device's cached relations are dropped first, so model-change
        callbacks no longer target it. The task entry is removed from
        ``device_tasks`` once it has stopped, timed out, or failed while
        stopping.

        Args:
            device_id: id of the device whose task should be stopped.
            timeout: seconds to wait for a running task to finish.
        """
        self.device_model_relations.pop(device_id, None)

        task = self.device_tasks.get(device_id)
        if task is None:
            logger.warning(f"No task exists for device {device_id} in device_tasks.")
            return

        logger.info(f'stop device {device_id} thread')
        thread_id = task.thread_id
        future = self.thread_pool.get_task_future(thread_id=thread_id)
        if not future:
            logger.warning(f"No task found for {thread_id} .")
            return

        if future.done():
            logger.info(f"Task {thread_id} has already stopped.")
            self.device_tasks.pop(device_id, None)
            return

        # The task is still running: ask it to stop, then wait for it.
        task.stop_detection_task()
        try:
            future.result(timeout=timeout)
            logger.info(f"Task {thread_id} stopped successfully.")
        except FutureTimeoutError:
            # Future.result raises concurrent.futures.TimeoutError, which is not
            # the builtin TimeoutError before Python 3.11.
            logger.error(f"Task {thread_id} did not stop within the timeout.")
        except Exception as e:
            logger.error(f"Task {thread_id} encountered an error while stopping: {e}")
        finally:
            # Remove the entry whether or not the task actually stopped.
            self.device_tasks.pop(device_id, None)

    def restart_device_thread(self, device_id):
        """Stop the device's task (if any) and start a new one from fresh database data."""
        self.stop_device_thread(device_id)
        device = self.device_service.get_device(device_id)
        self.start_device_thread(device)

    def on_device_change(self, device_id, change_type):
        """Device change callback: start, stop or restart the device's task."""
        logger.info(f"on device change, device {device_id} {change_type}")
        if change_type == NotifyChangeType.DEVICE_CREATE:
            device = self.device_service.get_device(device_id)
            self.start_device_thread(device)
        elif change_type == NotifyChangeType.DEVICE_DELETE:
            self.stop_device_thread(device_id)
        elif change_type == NotifyChangeType.DEVICE_UPDATE:
            # Always restart. The previous task may not exist (e.g. the device had
            # no stream URL before), so it cannot be used to diff old vs. new.
            # TODO: skip the restart when only fields irrelevant to detection changed.
            self.restart_device_thread(device_id)

    def on_model_change(self, model_id, change_type):
        """Model change callback: reload an updated model and restart devices that use it."""
        if change_type == NotifyChangeType.MODEL_UPDATE:
            self.model_manager.reload_model(model_id)
            # Each value is a list of relations; match on their algo_model_id.
            # Materialize the list first: restarts mutate device_model_relations.
            devices_to_reload = [
                device_id
                for device_id, relations in self.device_model_relations.items()
                if any(r.algo_model_id == model_id for r in relations)
            ]
            for device_id in devices_to_reload:
                self.restart_device_thread(device_id)

    def on_relation_change(self, device_id, change_type):
        """Device/model binding change callback: restart the affected device."""
        # TODO: check whether the binding actually changed to avoid unnecessary restarts.
        if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
            self.restart_device_thread(device_id)