# safe-algo-pro / algo / device_detection_task.py
# Last change: zhangyingjie, 17 Oct — "修改线程控制问题" (fix thread-control issue)
from dataclasses import dataclass
import importlib
from datetime import datetime
from threading import Event
from typing import List, Dict

from algo.model_manager import AlgoModelExec
from algo.stream_loader import OpenCVStreamLoad
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
from common.string_utils import camel_to_snake
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
from services.device_frame_service import DeviceFrameService
from services.frame_analysis_result_service import FrameAnalysisResultService


def get_class(module_name, class_name):
    """Import ``module_name`` dynamically and return its attribute ``class_name``.

    Raises:
        ImportError (incl. ModuleNotFoundError): the module cannot be imported.
        AttributeError: the module has no attribute named ``class_name``.
    """
    return getattr(importlib.import_module(module_name), class_name)


@dataclass
class DetectionResult:
    object_class_id: int
    object_class_name: str
    confidence: float
    location: str

    @classmethod
    def from_dict(cls, data: Dict) -> 'DetectionResult':
        return DetectionResult(
            object_class_id=data['object_class_id'],
            object_class_name=data['object_class_name'],
            confidence=data['confidence'],
            location=data['location']
        )


class DeviceDetectionTask:
    """Run every configured model over one device's video stream and store results.

    ``run()`` iterates frames from an ``OpenCVStreamLoad`` and passes each frame
    through the handler of every ``AlgoModelExec`` in ``model_exec_list``. At most
    once per ``device.image_save_interval`` seconds, it hands the frame and its
    detections to the global thread pool to be persisted.
    """

    def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id):
        self.device = device
        self.model_exec_list = model_exec_list
        self.__stop_event = Event()  # set by stop_detection_task() to make run() exit
        self.frame_ts = None  # time of the last frame accepted for saving
        self.thread_id = thread_id

        # NOTE(review): the services keep ``db`` after this ``with`` block exits
        # (i.e. after the session's context-manager exit, which typically closes
        # it). They are later used from thread-pool workers too. Confirm the
        # session type tolerates reuse after exit and access from several threads.
        with next(get_db()) as db:
            self.device_frame_service = DeviceFrameService(db)
            self.frame_analysis_result_service = FrameAnalysisResultService(db)

        self.thread_pool = GlobalThreadPool()

        self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code,
                                              device_thread_id=thread_id)

    def stop_detection_task(self):
        """Signal run() to stop and stop the stream loader's thread."""
        logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}')
        self.__stop_event.set()
        self.stream_loader.stop()  # stop the video-stream loading thread

    def check_frame_interval(self):
        """Return True if a frame may be saved now, and record that time.

        A negative ``device.image_save_interval`` disables saving. The first call
        always accepts. After that, a frame is accepted only when more than
        ``image_save_interval`` seconds have passed since the last accepted one.
        """
        if self.device.image_save_interval < 0:
            return False
        now = datetime.now()
        if self.frame_ts is None or (now - self.frame_ts).total_seconds() > self.device.image_save_interval:
            self.frame_ts = now
            return True
        return False

    def save_frame_results(self, frame, results_map):
        """Persist ``frame`` and its detections if the save interval allows it.

        ``results_map`` maps algo_model_id -> list of DetectionResult.
        """
        if not self.check_frame_interval():
            return
        self._persist_frame_results(frame, results_map)

    def _persist_frame_results(self, frame, results_map):
        """Store the frame, then one analysis-result row per detection (no interval check)."""
        device_frame = self.device_frame_service.add_frame(self.device.id, frame)
        frame_id = device_frame.id
        logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}')
        frame_results = []
        for model_exec_id, results in results_map.items():
            for r in results:
                frame_results.append(FrameAnalysisResultCreate(
                    device_id=self.device.id,
                    frame_id=frame_id,
                    algo_model_id=model_exec_id,
                    object_class_id=r.object_class_id,
                    object_class_name=r.object_class_name,
                    confidence=round(r.confidence, 4),
                    location=r.location,
                ))
        self.frame_analysis_result_service.add_frame_analysis_results(frame_results)

    def run(self):
        """Process frames until the stream ends or stop_detection_task() is called."""
        handler_classes = {}  # handle_task name -> handler class, resolved once per name
        for frame in self.stream_loader:
            if self.__stop_event.is_set():
                break  # stop requested: leave the loop
            if frame is None:
                continue

            results_map = {}
            for model_exec in self.model_exec_list:
                handle_task_name = model_exec.algo_model_info.handle_task
                if handle_task_name not in handler_classes:
                    handler_classes[handle_task_name] = get_class(
                        f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
                # A new handler instance per frame, as before. ``frame`` is rebound
                # to the frame each handler returns, so later models receive it.
                handler_instance = handler_classes[handle_task_name](model_exec)
                frame, results = handler_instance.run(frame, None)
                results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results]

            # Do the interval check here, on the single run() thread. This avoids
            # pool workers racing on ``frame_ts``, and avoids queueing a task for
            # every frame that would be discarded anyway.
            if self.check_frame_interval():
                self.thread_pool.submit_task(self._persist_frame_results, frame, results_map)