# Source: safe-algo-pro / algo / algo_runner.py
# Last commit: zhangyingjie on 4 Mar, 7 KB, "部署版本" (deployment version)
import asyncio
import concurrent.futures
import copy
import uuid
from typing import Dict

from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
from services.device_model_relation_service import DeviceModelRelationService
from services.device_service import DeviceService
from services.model_service import ModelService
from common.global_logger import logger


class AlgoRunner:
    """Start, stop and restart per-device detection tasks on the shared thread pool.

    Bookkeeping kept by the runner:

    * ``device_tasks``: device id -> the ``DeviceDetectionTask`` started for it.
    * ``task_futures``: thread id -> the future returned by the thread pool.
    * ``device_model_relations``: device id -> list of in-use model relations
      that were read when the device was last started.
    """

    def __init__(self,
                 device_service: DeviceService,
                 model_service: ModelService,
                 relation_service: DeviceModelRelationService):
        """Store the services and capture the running event loop.

        Must be called while an asyncio event loop is running, because
        ``asyncio.get_running_loop()`` raises ``RuntimeError`` otherwise.
        """
        self.device_service = device_service
        self.model_service = model_service
        self.relation_service = relation_service
        self.model_manager = ModelManager(model_service)

        self.thread_pool = GlobalThreadPool()
        self.device_tasks: Dict[int, DeviceDetectionTask] = {}  # device id -> detection task
        self.task_futures: Dict[str, concurrent.futures.Future] = {}  # thread id -> pool future
        self.device_model_relations = {}  # device id -> list of in-use relations
        # Strong references to fire-and-forget asyncio tasks. The event loop
        # only keeps weak references, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._background_tasks = set()

        self.main_loop = asyncio.get_running_loop()

        # Change callbacks for devices, models and relations are currently
        # not registered here:
        # self.device_service.register_change_callback(self.on_device_change)
        # self.model_service.register_change_callback(self.on_model_change)
        # self.relation_service.register_change_callback(self.on_relation_change)

    def _spawn(self, coro):
        """Schedule ``coro`` as a background task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task):
        """Drop the reference to a finished background task and log its failure, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f"Background task failed: {exc!r}")

    def _forget_device(self, device_id, thread_id):
        """Remove every bookkeeping entry kept for a device task."""
        self.device_tasks.pop(device_id, None)
        self.task_futures.pop(thread_id, None)
        self.device_model_relations.pop(device_id, None)

    async def start(self):
        """Load the models, then start detection for the known devices.

        Intended to be called once at program startup.
        """
        logger.info("Starting AlgoRunner...")
        await self.model_manager.load_models()
        await self.load_and_start_devices()

    async def load_and_start_devices(self):
        """Read the device list and start one detection task per device.

        Each start runs as its own background task, so this method returns
        without waiting for the devices to be started.
        """
        devices = await self.device_service.get_device_list()
        for device in devices:
            self._spawn(self.start_device_thread(device))

    async def start_device_thread(self, device: Device):
        """Start the detection task for a single device.

        Nothing is started when the device is missing, has no input stream
        URL, is not in ``DEVICE_MODE.ALGO``, already has a running task, or
        has no in-use model bound to it.
        """
        if device is None:
            logger.warning("Cannot start detection thread: device not found.")
            return

        # Work on a private copy so the task does not share the caller's object.
        device = copy.deepcopy(device)

        if not device.input_stream_url:
            logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
            return

        if device.mode != DEVICE_MODE.ALGO:
            return

        # Thread ids are unique per start, so "already running" has to be
        # decided per device; otherwise the old task would be overwritten in
        # device_tasks and its thread leaked.
        running_task = self.device_tasks.get(device.id)
        if running_task is not None:
            running_future = self.task_futures.get(running_task.thread_id)
            if running_future is not None and not running_future.done():
                logger.info(f"设备 {device.code} 已经在运行中")
                return

        thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
        logger.info(f'start thread {thread_id}, device info: {device}')

        # NOTE(review): the session yielded by get_db() is handed to the
        # detection task, which keeps using it after this loop ends. Confirm
        # that get_db() does not close the session when the iteration finishes.
        async for db in get_db():
            # Read the models bound to the device; only in-use relations count.
            relation_service = DeviceModelRelationService(db)
            relations = await relation_service.get_device_models(device.id)
            relations = [r for r in relations if r.is_use]
            self.device_model_relations[device.id] = relations

            if not relations:
                logger.info(f"设备 {device.code} 未绑定模型,无法启动检测")
                return

            for r in relations:
                if r.algo_model_id not in self.model_manager.models:
                    await self.model_manager.load_new_model(r.algo_model_id)

            # Keep only models that are loaded and have an executable; a model
            # that failed to load is skipped instead of raising KeyError.
            models = self.model_manager.models
            model_exec_list = []
            for r in relations:
                if r.algo_model_id not in models:
                    logger.warning(f"Model {r.algo_model_id} is not loaded, skipped for device {device.id}.")
                    continue
                model = models[r.algo_model_id]
                if model.algo_model_exec is not None:
                    model_exec_list.append(model)

            device_detection_task = DeviceDetectionTask(device=device, model_exec_list=model_exec_list,
                                                        thread_id=thread_id, db_session=db, main_loop=self.main_loop)
            future = self.thread_pool.submit_task(device_detection_task.run, thread_id=thread_id)
            self.device_tasks[device.id] = device_detection_task
            self.task_futures[thread_id] = future

    def stop_device_thread(self, device_id):
        """Stop the detection task of a device and wait for it to finish.

        Blocks the calling thread for up to 30 seconds while waiting. The
        bookkeeping entries of the device are removed whether or not the
        task stopped in time.
        """
        task = self.device_tasks.get(device_id)
        if task is None:
            logger.warning(f"No task exists for device {device_id} in device_tasks.")
            return

        logger.info(f'stop device {device_id} thread')
        thread_id = task.thread_id
        future = self.task_futures.get(thread_id)

        if future is not None and not future.done():
            # The task is still running: ask it to stop, then wait for it.
            task.stop_detection_task()
            try:
                # Wait at most 30 seconds for the task to finish.
                future.result(timeout=30)
                logger.info(f"Task {thread_id} stopped successfully.")
            except concurrent.futures.TimeoutError:
                logger.error(f"Task {thread_id} did not stop within the timeout.")
            except Exception as e:
                logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
            finally:
                # Forget the task even if it did not stop, so it can be restarted.
                self._forget_device(device_id, thread_id)
            return

        if future is None:
            logger.warning(f"No task found for {thread_id} .")
        else:
            logger.info(f"Task {thread_id} has already stopped.")
        self._forget_device(device_id, thread_id)

    async def restart_device_thread(self, device_id):
        """Stop the device's task, re-read the device and start it again.

        NOTE(review): ``stop_device_thread`` waits synchronously (up to 30
        seconds), so the event loop is blocked for that time when this
        coroutine runs on it.
        """
        self.stop_device_thread(device_id)
        device = await self.device_service.get_device(device_id)
        if device is None:
            logger.warning(f"Device {device_id} not found, skip restart.")
            return
        self._spawn(self.start_device_thread(device))

    async def on_device_change(self, device_id, change_type):
        """Device change callback: start, stop or restart the device's task."""
        logger.info(f"on device change, device {device_id} {change_type}")
        if change_type == NotifyChangeType.DEVICE_CREATE:
            device = await self.device_service.get_device(device_id)
            await self.start_device_thread(device)
        elif change_type == NotifyChangeType.DEVICE_DELETE:
            self.stop_device_thread(device_id)
        elif change_type == NotifyChangeType.DEVICE_UPDATE:
            # TODO: if only device attributes changed (not the video stream or
            # the models), consider whether a restart is needed at all.
            await self.restart_device_thread(device_id)

    async def on_model_change(self, model_id, change_type):
        """Model change callback: reload the model and restart the devices using it."""
        if change_type != NotifyChangeType.MODEL_UPDATE:
            return

        reloaded = self.model_manager.reload_model(model_id)
        # Await only when reload_model turns out to be a coroutine function.
        if asyncio.iscoroutine(reloaded):
            await reloaded

        # device_model_relations maps a device id to a list of relations, so a
        # device is affected when any of its relations points at the model.
        # The list is built up front because restarting mutates the mapping.
        devices_to_reload = [
            device_id for device_id, relations in self.device_model_relations.items()
            if any(r.algo_model_id == model_id for r in relations)
        ]
        for device_id in devices_to_reload:
            await self.restart_device_thread(device_id)

    async def on_relation_change(self, device_id, change_type):
        """Device-model relation change callback: restart the device's task."""
        # TODO: check whether the bindings really changed, to avoid needless restarts.
        if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
            await self.restart_device_thread(device_id)