Newer
Older
safe-algo-pro / algo / device_detection_task.py
zhangyingjie on 4 Mar 8 KB 部署版本
import asyncio
import json
from copy import deepcopy
from dataclasses import dataclass
import importlib
from datetime import datetime, timedelta
from threading import Event
from typing import List, Dict

from algo.model_manager import AlgoModelExec
from algo.stream_loader import OpenCVStreamLoad
from common.device_status_manager import DeviceStatusManager
from common.display_frame_manager import DisplayFrameManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
from common.http_utils import send_request
from common.image_plotting import Annotator
from common.string_utils import camel_to_snake, get_class, default_serializer
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
from services.device_frame_service import DeviceFrameService
from services.frame_analysis_result_service import FrameAnalysisResultService
from services.global_config import GlobalConfig


@dataclass
class DetectionResult:
    """One detected object, as converted from a model handler's result dict."""

    object_class_id: int
    object_class_name: str
    confidence: float
    location: str

    @classmethod
    def from_dict(cls, data: Dict) -> 'DetectionResult':
        """Build an instance from a handler result mapping.

        Uses ``cls`` so that subclasses get instances of their own type.

        Args:
            data: mapping containing the keys ``object_class_id``,
                ``object_class_name``, ``confidence`` and ``location``;
                any other keys are ignored.

        Raises:
            KeyError: if one of the required keys is missing.
        """
        return cls(
            object_class_id=data['object_class_id'],
            object_class_name=data['object_class_name'],
            confidence=data['confidence'],
            location=data['location'],
        )


class DeviceDetectionTask:
    """Detection loop for a single device.

    ``run`` reads frame batches from the device stream, runs every configured
    model handler on them, hands the results to the global thread pool for
    persistence / pushing, and forwards annotated frames to the display
    manager. ``stop_detection_task`` sets a ``threading.Event`` that ``run``
    polls, so the loop is intended to be stopped from another thread.
    """

    # Minimum span of one FPS measurement window (one log line per window).
    _FPS_LOG_INTERVAL = timedelta(seconds=10)

    def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop):
        """
        Args:
            device: device to analyse; its ``id``, ``code``, ``input_stream_url``
                and ``image_save_interval`` attributes are read by this class.
            model_exec_list: models to run, in order, on every frame batch.
            thread_id: identifier of the owning worker thread (logged and passed
                to the stream loader).
            db_session: session used to build the frame / analysis-result services.
            main_loop: asyncio event loop on which the services' coroutines are
                scheduled with ``asyncio.run_coroutine_threadsafe``.
        """
        self.device = device
        self.model_exec_list = model_exec_list
        self.__stop_event = Event()  # set by stop_detection_task(); polled by run()
        self.frame_ts = None  # when a frame was last persisted
        self.push_ts = None  # when results were last pushed to the external URL
        self.frames_detected = 0  # frames processed in the current FPS window
        self.fps_ts = None  # start of the current FPS window
        self.thread_id = thread_id

        self.device_frame_service = DeviceFrameService(db_session)
        self.frame_analysis_result_service = FrameAnalysisResultService(db_session)

        self.thread_pool = GlobalThreadPool()
        self.device_status_manager = DeviceStatusManager()
        self.display_frame_manager = DisplayFrameManager()
        self.main_loop = main_loop

        self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code,
                                              device_thread_id=thread_id)

    def stop_detection_task(self):
        """Request ``run`` to exit and stop the stream loader's reading thread."""
        logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}')
        self.__stop_event.set()
        self.stream_loader.stop()

    def check_frame_interval(self):
        """Return True if a frame should be persisted now, recording the time if so.

        A negative ``device.image_save_interval`` disables saving. Otherwise the
        first call returns True, and later calls return True only once strictly
        more than ``image_save_interval`` seconds have passed since the last True.
        """
        if self.device.image_save_interval < 0:
            return False
        now = datetime.now()
        if self.frame_ts is None or (now - self.frame_ts).total_seconds() > self.device.image_save_interval:
            self.frame_ts = now
            return True
        return False

    def save_frame_results(self, frames, result_maps):
        """Persist due frames with their detections and schedule a result push.

        Submitted to the global thread pool by ``run``. The frame row is written
        on ``self.main_loop`` and this thread blocks for it, because its id is
        needed for the analysis-result rows; those rows are then written without
        waiting (failures are logged).

        Args:
            frames: frames of one batch.
            result_maps: per-frame ``{algo_model_id: [DetectionResult, ...]}``
                dicts aligned by index with ``frames``. A missing entry (list
                shorter than ``frames``) is treated as "no detections".
        """
        for idx, frame in enumerate(frames):
            results_map = result_maps[idx] if idx < len(result_maps) else {}

            # Throttled by image_save_interval: skip the rest of this batch.
            if not self.check_frame_interval():
                return

            future = asyncio.run_coroutine_threadsafe(
                self.device_frame_service.add_frame(self.device.id, frame), self.main_loop
            )
            device_frame = future.result()
            frame_id = device_frame.id

            logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}')
            frame_results = [
                FrameAnalysisResultCreate(
                    device_id=self.device.id,
                    frame_id=frame_id,
                    algo_model_id=model_exec_id,
                    object_class_id=r.object_class_id,
                    object_class_name=r.object_class_name,
                    confidence=round(r.confidence, 4),
                    location=r.location,
                )
                for model_exec_id, results in results_map.items()
                for r in results
            ]
            save_future = asyncio.run_coroutine_threadsafe(
                self.frame_analysis_result_service.add_frame_analysis_results(frame_results),
                self.main_loop
            )
            # Nobody waits on this future; make sure a failed write is not silent.
            save_future.add_done_callback(self._log_save_failure)
            self.thread_pool.submit_task(self.push_frame_results, frame_results)

    def _log_save_failure(self, future):
        """Done-callback: log the exception of an analysis-result save, if any."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.error(f'failed to save frame analysis results for device {self.device.id}: {exc!r}')

    def push_frame_results(self, frame_results):
        """POST ``frame_results`` as JSON to the configured push URL, rate-limited.

        Does nothing when no push config / URL is set. A push happens only if
        none has been made yet or strictly more than ``push_interval`` seconds
        have passed since the last one; results arriving in between are dropped.
        """
        # NOTE(review): push_ts is read/written from thread-pool tasks without a
        # lock, so two near-simultaneous calls could both push.
        push_config = GlobalConfig().get_algo_result_push_config()
        if not (push_config and push_config.push_url):
            return
        now = datetime.now()
        last_ts = self.push_ts
        if last_ts is None or (now - last_ts).total_seconds() > push_config.push_interval:
            send_request(
                push_config.push_url,
                json.dumps([r.dict() for r in frame_results], default=default_serializer)
            )
            self.push_ts = now

    def log_fps(self, frame_count):
        """Count processed frames and periodically log the detection frame rate.

        The first call only opens a measurement window. Afterwards, once at least
        ``_FPS_LOG_INTERVAL`` has elapsed, the rate is logged as frames counted
        divided by the actual elapsed seconds and a new window starts.
        """
        now = datetime.now()
        if self.fps_ts is None:
            # Frames processed before the window opened are not part of it.
            self.fps_ts = now
            self.frames_detected = 0
            return
        self.frames_detected += frame_count
        elapsed = now - self.fps_ts
        if elapsed >= self._FPS_LOG_INTERVAL:
            fps = round(self.frames_detected / elapsed.total_seconds(), 2)
            self.frames_detected = 0
            logger.info(f"FPS (detect) for device {self.device.code}: {fps}")
            self.fps_ts = now

    def run(self):
        """Detection loop; returns when stop is requested or the stream loader ends.

        For each non-empty frame batch: mark the device as active, run every
        model handler in order (each may return replacement frames), collect
        detections per frame, hand saving to the thread pool, publish the
        annotated frames and update the FPS counter.
        """
        while not self.stream_loader.init:
            if self.__stop_event.is_set():
                break
            self.stream_loader.init_cap()
        if self.__stop_event.is_set():
            # Stopped while the stream was still being opened: never start reading.
            return

        for frames in self.stream_loader:
            if self.__stop_event.is_set():
                break
            if frames is None or len(frames) == 0:
                continue

            self.device_status_manager.set_status(device_id=self.device.id)

            # Handlers draw on copies so the raw frames stay unannotated.
            annotators = [
                Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人")
                for frame in frames
            ]
            # One {algo_model_id: [DetectionResult, ...]} map per frame, so every
            # frame has an entry even if no model reports on it.
            result_maps = [{} for _ in frames]

            for model_exec in self.model_exec_list:
                handle_task_name = model_exec.algo_model_info.handle_task
                handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
                handler_instance = handler_cls(model_exec)
                frames, results = handler_instance.run(frames, annotators)

                for frame_idx, frame_boxes in enumerate(results):
                    if len(result_maps) <= frame_idx:
                        result_maps.append({})
                    result_maps[frame_idx].setdefault(model_exec.algo_model_id, []).extend(
                        DetectionResult.from_dict(box) for box in frame_boxes
                    )

            self.thread_pool.submit_task(self.save_frame_results, frames, result_maps)
            for annotator in annotators:
                self.display_frame_manager.add_frame(self.device.id, annotator.result())

            self.log_fps(len(frames))