# Source: safe-algo-pro / algo / algo_runner.py
# Last commit: zhangyingjie, 14 Oct — "完善模型检测流程" (improve the model detection workflow)
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
from common.consts import NotifyChangeType
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
from services.device_model_relation_service import DeviceModelRelationService
from services.device_service import DeviceService
from services.model_service import ModelService
from common.global_logger import logger

class AlgoRunner:
    """Manage one detection task per device and keep it in sync with DB changes.

    On start the runner loads all models, then submits a ``DeviceDetectionTask``
    to the global thread pool for every device that has an input stream URL and
    at least one enabled model binding. Change callbacks registered on the
    device, model and device-model-relation services restart affected tasks.
    """

    def __init__(self):
        # NOTE(review): the services keep ``db`` after this ``with`` block exits
        # and use it later from the change callbacks. Whether leaving the block
        # closes the session depends on get_db(); confirm the session stays usable.
        with next(get_db()) as db:
            self.device_service = DeviceService(db)
            self.model_service = ModelService(db)
            self.relation_service = DeviceModelRelationService(db)
            self.model_manager = ModelManager(db)

        self.thread_pool = GlobalThreadPool()
        # device id -> DeviceDetectionTask submitted for that device
        self.device_tasks = {}
        # device id -> enabled relations (device/model bindings) cached when the
        # task was started; used to find devices affected by a model update
        self.device_model_relations = {}

        # Register callbacks for device, model and binding changes.
        self.device_service.register_change_callback(self.on_device_change)
        self.model_service.register_change_callback(self.on_model_change)
        self.relation_service.register_change_callback(self.on_relation_change)

    def start(self):
        """Call once at program start: load models, then start one task per device."""
        logger.info("Starting AlgoRunner...")
        self.model_manager.load_models()
        self.load_and_start_devices()

    def load_and_start_devices(self):
        """Read the device list from the database and start a task for each device."""
        devices = self.device_service.get_device_list()
        for device in devices:
            self.start_device_thread(device)

    @staticmethod
    def get_device_thread_id(device_id):
        """Return the thread-pool task id used for ``device_id`` (``'device_<id>'``)."""
        return f'device_{device_id}'

    def start_device_thread(self, device: "Device"):
        """Start the detection task for a single device.

        Only logs and returns when the device has no input stream URL, when a
        task with the same thread id is already running, or when no enabled
        model is bound to the device.
        """
        if not device.input_stream_url:
            logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
            return

        thread_id = self.get_device_thread_id(device.id)

        if self.thread_pool.check_task_is_running(thread_id=thread_id):
            logger.info(f"设备 {device.code} 已经在运行中")
            return

        # Only enabled bindings count. Cache them so on_model_change can find
        # the devices that use an updated model.
        relations = [r for r in self.relation_service.get_device_models(device.id) if r.is_use]
        self.device_model_relations[device.id] = relations

        if not relations:
            logger.info(f"设备 {device.code} 未绑定模型,无法启动检测")
            return

        model_exec_list = self._collect_model_execs(relations)
        device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list)
        self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id)
        self.device_tasks[device.id] = device_detection_task

    def _collect_model_execs(self, relations):
        """Return the loaded models referenced by ``relations`` that have an executable.

        A binding whose model is missing from ``model_manager.models`` is skipped
        with a warning instead of raising ``KeyError``. Models whose
        ``algo_model_exec`` is ``None`` are skipped silently.
        """
        models = self.model_manager.models
        model_exec_list = []
        for relation in relations:
            model = models.get(relation.algo_model_id)
            if model is None:
                logger.warning(f"Model {relation.algo_model_id} is not loaded, skipping it")
                continue
            if model.algo_model_exec is not None:
                model_exec_list.append(model)
        return model_exec_list

    def stop_device_thread(self, device_id):
        """Signal the device's detection task to stop and drop its bookkeeping.

        Unknown ``device_id`` values are ignored. Whether ``stop_detection_task()``
        blocks until the worker exits is not visible from this class.
        """
        # Forget cached bindings so later model updates no longer target this device.
        self.device_model_relations.pop(device_id, None)
        task = self.device_tasks.get(device_id)
        if task is None:
            return
        task.stop_detection_task()
        del self.device_tasks[device_id]

    # TODO: stop_device_thread does not wait on the thread-pool future, so an
    # immediate restart may still see the old task as running. A previous
    # (removed) variant waited on the future with a 10 s timeout.

    def _start_device_by_id(self, device_id):
        """Look up a device by id and start its task.

        Skips with a warning when ``get_device`` returns ``None``, which is
        presumably the device-not-found case; confirm against DeviceService.
        """
        device = self.device_service.get_device(device_id)
        if device is None:
            logger.warning(f"Device {device_id} not found, detection thread not started")
            return
        self.start_device_thread(device)

    def restart_device_thread(self, device_id):
        """Stop the device's task, then start a fresh one from the current DB state."""
        self.stop_device_thread(device_id)
        # TODO: there is no grace period between stop and start. If the old task
        # is still reported running under the same thread id, start_device_thread
        # logs "已经在运行中" and returns without starting a new task. Confirm
        # GlobalThreadPool.check_task_is_running semantics.
        self._start_device_by_id(device_id)

    def on_device_change(self, device_id, change_type):
        """Device change callback: start, stop or restart the device's task."""
        if change_type == NotifyChangeType.DEVICE_CREATE:
            self._start_device_by_id(device_id)
        elif change_type == NotifyChangeType.DEVICE_DELETE:
            self.stop_device_thread(device_id)
        elif change_type == NotifyChangeType.DEVICE_UPDATE:
            # TODO: if only non-stream attributes changed, a restart may be unnecessary.
            self.restart_device_thread(device_id)

    @staticmethod
    def _devices_using_model(device_model_relations, model_id):
        """Return ids of devices whose cached enabled bindings reference ``model_id``."""
        return [
            device_id
            for device_id, relations in device_model_relations.items()
            if any(r.algo_model_id == model_id for r in relations)
        ]

    def on_model_change(self, model_id, change_type):
        """Model change callback: reload the model and restart devices that use it."""
        if change_type == NotifyChangeType.MODEL_UPDATE:
            self.model_manager.reload_model(model_id)
            # Build the list before restarting, because restarts mutate
            # device_model_relations.
            devices_to_reload = self._devices_using_model(self.device_model_relations, model_id)
            for device_id in devices_to_reload:
                self.restart_device_thread(device_id)

    def on_relation_change(self, device_id, change_type):
        """Device/model binding change callback: restart the affected device's task."""
        # TODO: check whether the bindings actually changed to avoid needless restarts.
        if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
            self.restart_device_thread(device_id)