# safe-algo-pro / algo / scene_runner.py
import asyncio
import concurrent
import copy
import uuid
from typing import Dict

from common.consts import DEVICE_MODE, NotifyChangeType
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
from common.string_utils import get_class, camel_to_snake
from db.database import get_db
from entity.device import Device
from scene_handler.base_scene_handler import BaseSceneHandler
from services.device_scene_relation_service import DeviceSceneRelationService
from services.device_service import DeviceService
from services.scene_service import SceneService
from tcp.tcp_manager import TcpManager


class SceneRunner:
    """Start, stop and restart per-device scene handlers on the shared thread pool.

    For every device in ``DEVICE_MODE.SCENE`` the runner looks up the scene
    bound to the device, instantiates the handler class named by
    ``scene.scene_handle_task`` and submits its ``run`` method to
    ``GlobalThreadPool``. The ``on_*_change`` coroutines restart or stop
    handlers when devices, scenes or device/scene relations change.
    """

    def __init__(self,
                 device_service: DeviceService,
                 scene_service: SceneService,
                 relation_service: DeviceSceneRelationService,
                 tcp_manager: TcpManager,
                 main_loop):
        """Store the collaborators and initialise the bookkeeping tables.

        Args:
            device_service: Used to load devices from the database.
            scene_service: Stored for the (currently disabled) scene callbacks.
            relation_service: Stored for the (currently disabled) relation callbacks.
            tcp_manager: Passed through to every scene handler.
            main_loop: Passed through to every scene handler.
        """
        self.device_service = device_service
        self.scene_service = scene_service
        self.relation_service = relation_service
        self.tcp_manager = tcp_manager
        self.main_loop = main_loop

        self.thread_pool = GlobalThreadPool()
        # device id -> handler instance currently submitted to the thread pool
        self.device_tasks: Dict[int, BaseSceneHandler] = {}
        # device id -> scene object returned by get_device_scene()
        self.device_scene_relations = {}
        # Strong references to fire-and-forget asyncio tasks. The event loop
        # itself only keeps weak references, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._background_tasks = set()

        # Change-callback registration is currently disabled:
        # self.device_service.register_change_callback(self.on_device_change)
        # self.scene_service.register_change_callback(self.on_scene_change)
        # self.relation_service.register_change_callback(self.on_relation_change)

    def _spawn(self, coro):
        """Schedule *coro* as a task and hold a reference to it until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task):
        """Release a finished background task and log any exception it raised."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error(f"Background device task failed: {error!r}")

    async def start(self):
        """Start handlers for every scene-mode device."""
        await self.load_and_start_devices()

    async def load_and_start_devices(self):
        """Read the device list from the database and start each scene-mode device."""
        devices = await self.device_service.get_device_list()
        for device in devices:
            if device.mode == DEVICE_MODE.SCENE:
                self._spawn(self.start_device_thread(device))

    async def start_device_thread(self, device: Device):
        """Start the detection thread for a single device.

        Nothing is started when the device is missing, has no input stream
        URL, is not in scene mode, or has no scene bound to it.

        Args:
            device: The device to start; a deep copy is handed to the handler.
        """
        if device is None:
            logger.warning("Cannot start detection thread: device not found.")
            return

        # The handler receives its own copy of the device object.
        device = copy.deepcopy(device)

        if not device.input_stream_url:
            logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
            return

        if device.mode != DEVICE_MODE.SCENE:
            return

        thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
        async for db in get_db():
            relation_service = DeviceSceneRelationService(db)
            scene = await relation_service.get_device_scene(device.id)
            if not scene:
                # No scene bound any more: forget a relation recorded earlier so
                # later scene updates do not restart this device needlessly.
                self.device_scene_relations.pop(device.id, None)
                continue

            self.device_scene_relations[device.id] = scene
            try:
                handle_task_name = scene.scene_handle_task
                handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
                handler_instance = handler_cls(device, thread_id, self.tcp_manager, self.main_loop,
                                               scene.range_points)
                self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
                self.device_tasks[device.id] = handler_instance
                logger.info(f'start thread {thread_id}, device info: {device}')
            except Exception as e:
                logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")

    def stop_device_thread(self, device_id, timeout=30):
        """Stop the detection thread of a device and wait for it to finish.

        NOTE: this call blocks the calling thread for up to ``timeout``
        seconds; when invoked from a coroutine it blocks the event loop for
        that long.

        Args:
            device_id: Id of the device whose handler should be stopped.
            timeout: Maximum number of seconds to wait for the handler to stop.
        """
        handler = self.device_tasks.get(device_id)
        if handler is None:
            logger.warning(f"No task exists for device {device_id} in device_tasks.")
            return

        logger.info(f'stop device {device_id} thread')
        thread_id = handler.thread_id
        future = self.thread_pool.get_task_future(thread_id=thread_id)

        if not future:
            # The handler stays registered in device_tasks in this case.
            logger.warning(f"No task found for {thread_id} .")
            return

        if future.done():
            logger.info(f"Task {thread_id} has already stopped.")
            self.device_tasks.pop(device_id, None)
            return

        # The task is still running: ask the handler to stop, then wait for it.
        handler.stop_task()
        try:
            future.result(timeout=timeout)
            logger.info(f"Task {thread_id} stopped successfully.")
        except concurrent.futures.TimeoutError:
            logger.error(f"Task {thread_id} did not stop within the timeout.")
        except Exception as e:
            logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
        finally:
            # Remove the entry whether or not the task actually stopped.
            self.device_tasks.pop(device_id, None)

    async def restart_device_thread(self, device_id):
        """Stop the handler of *device_id* and start a fresh one from current DB state.

        Errors are logged rather than propagated.
        """
        try:
            self.stop_device_thread(device_id)
            device = await self.device_service.get_device(device_id)
            if device is None:
                logger.warning(f"Device {device_id} not found, skip restart.")
                self.device_scene_relations.pop(device_id, None)
                return
            self._spawn(self.start_device_thread(device))
        except Exception as e:
            logger.exception(f"Failed to restart device {device_id}: {e}")

    async def on_device_change(self, device_id, change_type):
        """Device change callback: start, stop or restart the device's handler."""
        logger.info(f"on device change, device {device_id} {change_type}")
        if change_type == NotifyChangeType.DEVICE_CREATE:
            device = await self.device_service.get_device(device_id)
            await self.start_device_thread(device)
        elif change_type == NotifyChangeType.DEVICE_DELETE:
            self.stop_device_thread(device_id)
            # A deleted device must not be restarted by later scene updates.
            self.device_scene_relations.pop(device_id, None)
        elif change_type == NotifyChangeType.DEVICE_UPDATE:
            # TODO: if only plain device attributes changed (not the video
            # stream or the scene), a restart may not be necessary.
            await self.restart_device_thread(device_id)

    async def on_scene_change(self, scene_id, change_type):
        """Scene change callback: restart every device bound to the updated scene."""
        if change_type != NotifyChangeType.SCENE_UPDATE:
            return

        # Snapshot the ids first: restarting mutates device_scene_relations.
        devices_to_reload = [
            device_id for device_id, relation_info in self.device_scene_relations.items()
            if relation_info.scene_id == scene_id
        ]
        for device_id in devices_to_reload:
            await self.restart_device_thread(device_id)

    async def on_relation_change(self, device_id, change_type):
        """Device/scene relation change callback: restart the affected device."""
        # TODO: check whether the binding really changed to avoid needless restarts.
        if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
            await self.restart_device_thread(device_id)