import asyncio
import json
from copy import deepcopy
from dataclasses import dataclass
import importlib
from datetime import datetime, timedelta
from threading import Event
from typing import List, Dict

from algo.model_manager import AlgoModelExec
from algo.stream_loader import OpenCVStreamLoad
from common.device_status_manager import DeviceStatusManager
from common.display_frame_manager import DisplayFrameManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
from common.http_utils import send_request
from common.image_plotting import Annotator
from common.string_utils import camel_to_snake, get_class, default_serializer
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
from services.device_frame_service import DeviceFrameService
from services.frame_analysis_result_service import FrameAnalysisResultService
from services.global_config import GlobalConfig


@dataclass
class DetectionResult:
    """A single detected object produced by a model handler."""

    object_class_id: int
    object_class_name: str
    confidence: float
    location: str

    @classmethod
    def from_dict(cls, data: Dict) -> 'DetectionResult':
        """Build a DetectionResult from a handler's box dict.

        Raises KeyError if any of the four expected keys is missing.
        """
        # Use ``cls`` (not the hard-coded class name) so subclasses construct themselves.
        return cls(
            object_class_id=data['object_class_id'],
            object_class_name=data['object_class_name'],
            confidence=data['confidence'],
            location=data['location'],
        )


class DeviceDetectionTask:
    """Pull frames from one device's stream, run the configured models, and dispatch the results.

    Results are saved and pushed on the global thread pool. Annotated frames go
    to the display manager. Service coroutines are scheduled onto ``main_loop``.
    """

    def __init__(self, device: Device, model_exec_list: List[AlgoModelExec], thread_id, db_session, main_loop):
        """
        :param device: device whose input stream is analysed.
        :param model_exec_list: models run, in order, on every frame batch.
        :param thread_id: identifier of the thread running this task (used for logging / stream loader).
        :param db_session: session handed to the frame and frame-result services.
        :param main_loop: asyncio event loop on which the services' coroutines are scheduled.
        """
        self.device = device
        self.model_exec_list = model_exec_list
        self.__stop_event = Event()  # set by stop_detection_task() to end run()
        self.frame_ts = None  # time the last frame was accepted for saving
        self.push_ts = None  # time results were last pushed
        self.frames_detected = 0  # frames counted in the current FPS window
        self.fps_ts = None  # start of the current FPS window
        self.thread_id = thread_id
        self.device_frame_service = DeviceFrameService(db_session)
        self.frame_analysis_result_service = FrameAnalysisResultService(db_session)
        self.thread_pool = GlobalThreadPool()
        self.device_status_manager = DeviceStatusManager()
        self.display_frame_manager = DisplayFrameManager()
        self.main_loop = main_loop
        self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code,
                                              device_thread_id=thread_id)

    def stop_detection_task(self):
        """Signal run() to stop and stop the stream loader."""
        logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}')
        self.__stop_event.set()
        self.stream_loader.stop()  # stop the video-stream loading thread

    def check_frame_interval(self):
        """Decide whether the current frame should be persisted.

        Returns False when ``device.image_save_interval`` is negative (saving
        disabled). Otherwise returns True on the first call, and afterwards only
        when more than ``image_save_interval`` seconds have passed since the
        last accepted frame; the timestamp is updated whenever True is returned.
        """
        if self.device.image_save_interval < 0:
            return False
        now = datetime.now()
        if self.frame_ts is None or (now - self.frame_ts).total_seconds() > self.device.image_save_interval:
            self.frame_ts = now
            return True
        return False

    def _log_save_failure(self, future):
        """Done-callback: log an exception raised by a fire-and-forget result-save coroutine."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.error(f'save frame analysis results failed for device {self.device.id}: {exc}')

    def save_frame_results(self, frames, result_maps):
        """Persist frames (rate-limited by check_frame_interval) and their detections, then push them.

        :param frames: frames of one batch.
        :param result_maps: ``result_maps[i]`` maps algo_model_id -> list[DetectionResult]
            for ``frames[i]``; may be shorter than ``frames`` (missing entries mean no detections).

        Runs on a worker thread: it blocks on the frame insert scheduled on
        ``main_loop`` because the frame id is needed for the result rows.
        """
        for idx, frame in enumerate(frames):
            # Stop at the first frame rejected by the save-interval throttle.
            if not self.check_frame_interval():
                return
            # Guard against fewer result maps than frames (e.g. no models configured).
            results_map = result_maps[idx] if idx < len(result_maps) else {}

            future = asyncio.run_coroutine_threadsafe(
                self.device_frame_service.add_frame(self.device.id, frame), self.main_loop
            )
            device_frame = future.result()
            frame_id = device_frame.id
            logger.info(f'save frame for device {self.device.id}, frame_id: {frame_id}')

            frame_results = [
                FrameAnalysisResultCreate(
                    device_id=self.device.id,
                    frame_id=frame_id,
                    algo_model_id=model_exec_id,
                    object_class_id=r.object_class_id,
                    object_class_name=r.object_class_name,
                    confidence=round(r.confidence, 4),
                    location=r.location,
                )
                for model_exec_id, results in results_map.items()
                for r in results
            ]
            # Not awaited; attach a callback so failures are at least logged.
            save_future = asyncio.run_coroutine_threadsafe(
                self.frame_analysis_result_service.add_frame_analysis_results(frame_results), self.main_loop
            )
            save_future.add_done_callback(self._log_save_failure)
            self.thread_pool.submit_task(self.push_frame_results, frame_results)

    def push_frame_results(self, frame_results):
        """POST results to the configured push URL, at most once per ``push_interval`` seconds.

        Does nothing when no push config / push_url is set.
        NOTE(review): may run concurrently on several pool threads; ``push_ts`` is not locked.
        """
        global_config = GlobalConfig()
        push_config = global_config.get_algo_result_push_config()
        if not (push_config and push_config.push_url):
            return
        last_ts = self.push_ts
        current_time = datetime.now()
        # Only push when the push interval has elapsed since the last push.
        if last_ts is None or (current_time - last_ts).total_seconds() > push_config.push_interval:
            send_request(
                push_config.push_url,
                json.dumps([r.dict() for r in frame_results], default=default_serializer)
            )
            self.push_ts = current_time  # update push timestamp

    def log_fps(self, frame_count, interval_seconds=10.0):
        """Accumulate detected frames and log the measured FPS once per ``interval_seconds``.

        The rate is frames counted divided by the real elapsed time of the
        window (not a fixed constant). The first call only opens the window.
        """
        current_time = datetime.now()
        if self.fps_ts is None:
            # First batch: start the measurement window; there is no rate to report yet.
            self.fps_ts = current_time
            self.frames_detected = 0
            return
        self.frames_detected += frame_count
        elapsed = (current_time - self.fps_ts).total_seconds()
        if elapsed >= interval_seconds and elapsed > 0:
            fps = self.frames_detected / elapsed
            self.frames_detected = 0
            logger.info(f"FPS (detect) for device {self.device.code}: {fps}")
            self.fps_ts = current_time

    def _detect_frames(self, frames, handler_cls_cache):
        """Run every configured model over one frame batch.

        :param frames: frames of the batch; each handler may return a replacement list.
        :param handler_cls_cache: dict handle_task name -> handler class, filled lazily.
        :return: ``(frames, annotators, result_maps)`` where ``result_maps[i]`` maps
            algo_model_id -> list[DetectionResult] for frame i.
        """
        # One annotator per frame, drawing on a deep copy so the frame itself stays untouched.
        annotators = [
            Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人")
            for frame in frames
        ]
        result_maps = []
        for model_exec in self.model_exec_list:
            handle_task_name = model_exec.algo_model_info.handle_task
            handler_cls = handler_cls_cache.get(handle_task_name)
            if handler_cls is None:
                # Resolve once per run; assumes get_class is a pure module/attribute lookup.
                handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
                handler_cls_cache[handle_task_name] = handler_cls
            handler_instance = handler_cls(model_exec)
            frames, results = handler_instance.run(frames, annotators)
            # results[i] is the list of box dicts for frame i; group them per frame, per model.
            for frame_idx, frame_boxes in enumerate(results):
                if len(result_maps) <= frame_idx:
                    result_maps.append({})
                result_maps[frame_idx].setdefault(model_exec.algo_model_id, []).extend(
                    DetectionResult.from_dict(box) for box in frame_boxes
                )
        return frames, annotators, result_maps

    def run(self):
        """Open the stream, then detect, save, display and measure every frame batch until stopped."""
        # Keep trying to open the capture until it succeeds or a stop is requested.
        # NOTE(review): if init_cap() returns immediately on failure this retries with
        # no back-off — confirm init_cap() waits internally.
        while not self.stream_loader.init:
            if self.__stop_event.is_set():
                return  # stopped before the stream opened; do not touch the loader
            self.stream_loader.init_cap()

        handler_cls_cache = {}
        for frames in self.stream_loader:
            if self.__stop_event.is_set():
                break  # stop requested
            if frames is None or len(frames) == 0:
                continue
            self.device_status_manager.set_status(device_id=self.device.id)

            frames, annotators, result_maps = self._detect_frames(frames, handler_cls_cache)

            # Persisting and pushing results happens off the detection thread.
            self.thread_pool.submit_task(self.save_frame_results, frames, result_maps)
            for annotator in annotators:
                self.display_frame_manager.add_frame(self.device.id, annotator.result())
            self.log_fps(len(frames))