Newer
Older
safe-algo-pro / algo / device_detection_task.py
import asyncio
import json
from copy import deepcopy
from dataclasses import dataclass
import importlib
from datetime import datetime, timedelta
from threading import Event
from typing import List, Dict

from algo.model_manager import AlgoModelExec
from algo.stream_loader import OpenCVStreamLoad
from common.device_status_manager import DeviceStatusManager
from common.display_frame_manager import DisplayFrameManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
from common.http_utils import send_request
from common.image_plotting import Annotator
from common.string_utils import camel_to_snake, get_class, default_serializer
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
from services.device_frame_service import DeviceFrameService
from services.frame_analysis_result_service import FrameAnalysisResultService
from services.global_config import GlobalConfig


@dataclass
class DetectionResult:
    """A single detected object, as stored per model and per frame.

    Instances are built from the per-box dictionaries that model handlers
    return, and are later read attribute-by-attribute when the analysis
    results are persisted.
    """

    object_class_id: int
    object_class_name: str
    confidence: float
    location: str

    @classmethod
    def from_dict(cls, data: Dict) -> 'DetectionResult':
        """Build an instance from a mapping that holds the four field keys.

        Args:
            data: Mapping with the keys ``object_class_id``,
                ``object_class_name``, ``confidence`` and ``location``.
                Extra keys are ignored.

        Returns:
            A new instance of ``cls``, so subclasses get their own type back.

        Raises:
            KeyError: If any of the four required keys is missing.
        """
        # Use ``cls`` rather than the hard-coded class name so that the
        # alternate constructor also works for subclasses.
        return cls(
            object_class_id=data['object_class_id'],
            object_class_name=data['object_class_name'],
            confidence=data['confidence'],
            location=data['location'],
        )


class DeviceDetectionTask:
    """Detection loop for a single device's video stream.

    ``run`` pulls batches of frames from an ``OpenCVStreamLoad``, passes them
    through every configured model handler, hands the annotated frames to the
    display frame manager and submits persistence of frames and results to the
    global thread pool.
    """

    # Length, in seconds, of the window over which the detection FPS is
    # averaged before it is logged.
    FPS_LOG_INTERVAL_SECONDS = 10

    def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop):
        """Set up services, managers and the stream loader for ``device``.

        Args:
            device: Device whose ``id``, ``code``, ``input_stream_url`` and
                ``image_save_interval`` are read by this task.
            model_exec_list: Model executors to run on every batch of frames.
            thread_id: Identifier passed to the stream loader and used in logs.
            db_session: Session handed to the frame and result services.
            main_loop: Event loop on which the service coroutines are
                scheduled with ``asyncio.run_coroutine_threadsafe``.
        """
        self.device = device
        self.model_exec_list = model_exec_list
        self.__stop_event = Event()  # Controls the running state of the detection loop.
        self.frame_ts = None  # Time of the last saved frame.
        self.push_ts = None  # Time of the last result push.
        self.frames_detected = 0  # Frames counted in the current FPS window.
        self.fps_ts = None  # Start of the current FPS window.
        self.thread_id = thread_id

        self.device_frame_service = DeviceFrameService(db_session)
        self.frame_analysis_result_service = FrameAnalysisResultService(db_session)

        self.thread_pool = GlobalThreadPool()
        self.device_status_manager = DeviceStatusManager()
        self.display_frame_manager = DisplayFrameManager()
        self.main_loop = main_loop

        self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code,
                                              device_thread_id=thread_id)

    def stop_detection_task(self):
        """Signal the detection loop to stop and stop the stream loader."""
        logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}')
        self.__stop_event.set()
        self.stream_loader.stop()  # Stop the video stream loading thread.

    def check_frame_interval(self):
        """Return True when a frame should be saved now.

        A negative ``device.image_save_interval`` disables saving. Otherwise
        the first call returns True, and later calls return True only when
        more than ``image_save_interval`` seconds have passed since the last
        True. Each True result restarts the interval.
        """
        if self.device.image_save_interval < 0:
            return False
        now = datetime.now()
        if self.frame_ts is None or (now - self.frame_ts).total_seconds() > self.device.image_save_interval:
            self.frame_ts = now
            return True
        return False

    def save_frame_results(self, frames, result_maps):
        """Persist frames and their detection results, then queue a push.

        Args:
            frames: Frames of one processed batch.
            result_maps: One dict per frame mapping a model id to a list of
                ``DetectionResult``; may be shorter than ``frames``.
        """
        for idx, frame in enumerate(frames):
            # ``run`` only fills result_maps inside its model loop, so the list
            # can be shorter than ``frames`` (for example when no model is
            # configured). Treat a missing entry as "no detections" instead of
            # raising IndexError inside the worker thread.
            results_map = result_maps[idx] if idx < len(result_maps) else {}

            # The save interval has not elapsed: stop processing this batch.
            if not self.check_frame_interval():
                return

            # Blocks this worker until the frame has been stored, because the
            # frame id is needed for the result rows below.
            future = asyncio.run_coroutine_threadsafe(
                self.device_frame_service.add_frame(self.device.id, frame), self.main_loop
            )
            device_frame = future.result()
            frame_id = device_frame.id

            logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}')
            frame_results = []
            for model_exec_id, results in results_map.items():
                for r in results:
                    frame_results.append(
                        FrameAnalysisResultCreate(
                            device_id=self.device.id,
                            frame_id=frame_id,
                            algo_model_id=model_exec_id,
                            object_class_id=r.object_class_id,
                            object_class_name=r.object_class_name,
                            confidence=round(r.confidence, 4),
                            location=r.location,
                        )
                    )
            # Not awaited on purpose; the callback makes sure a failure is
            # logged rather than silently dropped.
            save_future = asyncio.run_coroutine_threadsafe(
                self.frame_analysis_result_service.add_frame_analysis_results(frame_results),
                self.main_loop
            )
            save_future.add_done_callback(self._log_save_error)
            self.thread_pool.submit_task(self.push_frame_results, frame_results)

    def _log_save_error(self, future):
        """Log the exception, if any, raised while saving analysis results."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logger.error(f'save frame analysis results failed for device {self.device.id}: {error}')

    def push_frame_results(self, frame_results):
        """Send ``frame_results`` to the configured push URL, rate limited.

        Nothing is sent when no push configuration or push URL is available,
        or when the last push happened no more than ``push_interval`` seconds
        ago.
        """
        push_config = GlobalConfig().get_algo_result_push_config()
        if not (push_config and push_config.push_url):
            return

        last_ts = self.push_ts
        current_time = datetime.now()

        # Check whether the push interval has elapsed.
        if last_ts is None or (current_time - last_ts).total_seconds() > push_config.push_interval:
            send_request(
                push_config.push_url,
                json.dumps([r.dict() for r in frame_results], default=default_serializer)
            )
            self.push_ts = current_time  # Update the push timestamp.

    def log_fps(self, frame_count):
        """Accumulate processed frames and periodically log the detection FPS.

        The first call only starts the measurement window. Afterwards the FPS
        is logged once at least ``FPS_LOG_INTERVAL_SECONDS`` have elapsed and
        is computed from the time that actually elapsed.

        Args:
            frame_count: Number of frames processed since the previous call.
        """
        current_time = datetime.now()
        if self.fps_ts is None:
            # No window has elapsed yet, so there is nothing meaningful to
            # report; just start measuring from here.
            self.fps_ts = current_time
            self.frames_detected = 0
            return

        self.frames_detected += frame_count
        elapsed = (current_time - self.fps_ts).total_seconds()
        if elapsed >= self.FPS_LOG_INTERVAL_SECONDS:
            fps = round(self.frames_detected / elapsed, 2)
            self.frames_detected = 0
            logger.info(f"FPS (detect) for device {self.device.code}: {fps}")
            self.fps_ts = current_time

    def run(self):
        """Initialise the stream, then process batches until stopped."""
        while not self.stream_loader.init:
            if self.__stop_event.is_set():
                # Stop requested before the capture was ready: do not start
                # iterating an uninitialised stream loader.
                return
            self.stream_loader.init_cap()

        for frames in self.stream_loader:
            if self.__stop_event.is_set():
                break  # Stop requested: leave the loop.
            if frames is None or len(frames) == 0:
                continue

            self.device_status_manager.set_status(device_id=self.device.id)

            result_maps = []  # One result map per frame.
            # Each annotator draws on its own copy of the frame.
            annotators = [
                Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人")
                for frame in frames
            ]

            for model_exec in self.model_exec_list:
                handle_task_name = model_exec.algo_model_info.handle_task
                handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
                handler_instance = handler_cls(model_exec)
                frames, results = handler_instance.run(frames, annotators)

                # Store the detection results frame by frame.
                for frame_idx, boxes in enumerate(results):
                    if len(result_maps) <= frame_idx:  # First result for this frame.
                        result_maps.append({})
                    # Append this model's detections to the frame's result map.
                    result_maps[frame_idx].setdefault(model_exec.algo_model_id, []).extend(
                        DetectionResult.from_dict(box) for box in boxes
                    )

            # Result handling: persist asynchronously, display immediately.
            self.thread_pool.submit_task(self.save_frame_results, frames, result_maps)
            for annotator in annotators:
                self.display_frame_manager.add_frame(self.device.id, annotator.result())

            self.log_fps(len(frames))