# Source: safe-algo-pro / algo / algo_runner.py
# Last commit by zhangyingjie on 12 Oct (4 KB): "first commit"
import threading

from algo.model_manager import ModelManager
from common.consts import NotifyChangeType
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
from services.device_model_relation_service import DeviceModelRelationService
from services.device_service import DeviceService
from services.model_service import ModelService


class AlgoRunner:
    """Run one detection task per device and react to configuration changes.

    ``start()`` loads all models through ``ModelManager`` and submits a detection
    task to the global thread pool for every device that has an input stream URL
    and at least one bound model. Change callbacks registered with the device,
    model and device-model relation services start, stop or restart those tasks.
    """

    def __init__(self):
        self.db = get_db()
        self.device_service = DeviceService(self.db)
        self.model_service = ModelService(self.db)
        self.relation_service = DeviceModelRelationService(self.db)
        self.model_manager = ModelManager(self.db)
        self.thread_pool = GlobalThreadPool()
        # device_id -> handle returned by thread_pool.submit_task for that device
        self.threads = {}
        # device_id -> threading.Event; setting it asks the detection loop to exit
        self.stop_events = {}
        # device_id -> relations returned by relation_service.get_device_models
        self.device_model_relations = {}

        # Register change callbacks for devices, models and device-model relations.
        # All three services are wired through the same registration method.
        self.device_service.register_change_callback(self.on_device_change)
        self.model_service.register_change_callback(self.on_model_change)
        self.relation_service.register_change_callback(self.on_relation_change)

    def start(self):
        """Call at program startup: load models, then start device detection tasks."""
        self.model_manager.load_models()
        self.load_and_start_devices()

    def load_and_start_devices(self):
        """Read the device list and start a task for each device with a stream URL."""
        devices = self.device_service.get_device_list()
        devices = [device for device in devices if device.input_stream_url]
        for device in devices:
            self.start_device_thread(device)

    def run_detection(self, device: Device):
        """Detection loop for a single device (TODO: still a skeleton).

        Reads frames from the device's video stream and runs every model held by
        ``model_manager`` on each frame. The loop ends on the first error, or when
        the stop event created by ``start_device_thread`` is set by
        ``stop_device_thread``. The stream is always closed on exit.
        """
        # Look the stop flag up first to keep the window for a racing stop small.
        # It is None only if this task was not submitted via start_device_thread,
        # in which case the loop runs until an error, as before.
        stop_event = self.stop_events.get(device.device_id)
        print(f"设备 {device.device_id} 的检测线程启动")
        video_stream = self.device_service.get_video_stream(device.device_id)

        try:
            while stop_event is None or not stop_event.is_set():
                try:
                    frame = video_stream.read_frame()
                    for model in self.model_manager.models:
                        # Run the detection model on the current frame.
                        result = self.model_service.run_inference(model, frame)
                        # Post-processing hook (e.g. annotating video, raising alarms).
                        # NOTE(review): process_result is not defined in this class as
                        # shown; confirm where it is meant to come from.
                        self.process_result(result, device.device_id)
                except Exception as e:
                    print(f"设备 {device.device_id} 处理时出错: {e}")
                    break
        finally:
            # Release the stream even if something escapes the per-frame handler.
            video_stream.close()
        print(f"设备 {device.device_id} 的检测线程结束")

    def start_device_thread(self, device: Device):
        """Submit a detection task for ``device`` unless it cannot or need not run.

        Skipped when the device is missing, has no input stream URL, is already
        tracked in ``self.threads``, or has no bound models. The bound relations
        are cached in ``self.device_model_relations`` so that model updates can
        find the affected devices.
        """
        if device is None:
            # get_device() may return nothing, e.g. when the device was removed.
            print('设备不存在,无法启动检测线程')
            return

        if not device.input_stream_url:
            print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
            return

        if device.device_id in self.threads:
            print(f"设备 {device.device_id} 已经在运行中")
            return

        # Fetch the models bound to this device.
        relations = self.relation_service.get_device_models(device.device_id)
        self.device_model_relations[device.device_id] = relations

        if not relations:
            print(f"设备 {device.code} 未绑定模型,无法启动检测")
            return

        # Create the stop flag before submitting so the task finds it when it starts.
        self.stop_events[device.device_id] = threading.Event()
        # NOTE(review): this block mixes device.device_id, device.id and device.code;
        # confirm which attribute is the canonical identifier on Device.
        self.threads[device.device_id] = self.thread_pool.submit_task(
            self.run_detection, device, thread_id=f'device_{device.id}'
        )

    def stop_device_thread(self, device_id):
        """Stop the detection task of ``device_id`` if one is being tracked.

        Sets the device's stop event so a running loop exits after its current
        frame. It also calls ``cancel()`` on the stored task handle, which (assuming a
        ``concurrent.futures.Future``-like handle) prevents a task that has not
        started yet from running at all.
        """
        stop_event = self.stop_events.pop(device_id, None)
        if stop_event is not None:
            stop_event.set()

        future = self.threads.pop(device_id, None)
        if future is None:
            return
        future.cancel()
        print(f"设备 {device_id} 的检测线程已停止")

    def restart_device_thread(self, device_id):
        """Stop the device's task, reload the device, and start a fresh task."""
        self.stop_device_thread(device_id)
        # TODO: the previous loop may still be finishing its current frame when the
        # new task starts; decide whether to wait for it to shut down first.
        device = self.device_service.get_device(device_id)
        self.start_device_thread(device)

    def on_device_change(self, device_id, change_type):
        """Device change callback: start on create, stop on delete, restart on update."""
        if change_type == NotifyChangeType.DEVICE_CREATE:
            device = self.device_service.get_device(device_id)
            self.start_device_thread(device)
        elif change_type == NotifyChangeType.DEVICE_DELETE:
            self.stop_device_thread(device_id)
        elif change_type == NotifyChangeType.DEVICE_UPDATE:
            self.restart_device_thread(device_id)

    def on_model_change(self, model_id, change_type):
        """Model change callback: reload the model and restart devices bound to it."""
        if change_type == NotifyChangeType.MODEL_UPDATE:
            self.model_manager.reload_model(model_id)

            # Each cached value is the collection returned by get_device_models, so
            # check whether any relation in it refers to the updated model.
            devices_to_reload = [
                device_id
                for device_id, relations in self.device_model_relations.items()
                if any(relation.model_id == model_id for relation in relations or ())
            ]
            # The list is materialized first because restarting rewrites
            # self.device_model_relations.
            for device_id in devices_to_reload:
                self.restart_device_thread(device_id)

    def on_relation_change(self, device_id, change_type):
        """Device-model relation change callback: restart the affected device."""
        if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
            self.restart_device_thread(device_id)