diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/common/global_logger.py b/common/global_logger.py index c975ea8..47e41b7 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -13,7 +13,6 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler handler = TimedRotatingFileHandler( diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/common/global_logger.py b/common/global_logger.py index c975ea8..47e41b7 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -13,7 +13,6 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler handler = TimedRotatingFileHandler( diff --git a/common/image_plotting.py b/common/image_plotting.py index 818ab95..5a2d942 100644 --- a/common/image_plotting.py +++ b/common/image_plotting.py @@ -1,6 +1,8 @@ +from collections.abc import Sequence + import numpy as np import cv2 - +from PIL import Image, ImageDraw, ImageFont class Colors: """ @@ -85,30 +87,38 @@ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) - assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." - self.im = im if im.flags.writeable else im.copy() + + self.im = Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + font = 'static/font/Arial.Unicode.ttf' + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + if not hasattr(self.font, 'getsize'): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] + self.tf = max(self.lw - 1, 1) # font thickness self.sf = self.lw / 3 # font scale def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): - """Add one xyxy box to image with label.""" - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) - cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if not isinstance(box, Sequence): + box = box.tolist() + if rotated: + p1 = box[0] + # NOTE: PIL-version polygon needs tuple type. + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box if label: - w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height - outside = p1[1] - h >= 3 - p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 - cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText( - self.im, - label, - (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, + w, h = self.font.getsize(label) # text width, height + outside = p1[1] - h >= 0 # label fits outside box + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) def result(self): """Return annotated image as array.""" diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/common/global_logger.py b/common/global_logger.py index c975ea8..47e41b7 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -13,7 +13,6 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler handler = TimedRotatingFileHandler( diff --git a/common/image_plotting.py b/common/image_plotting.py index 818ab95..5a2d942 100644 --- a/common/image_plotting.py +++ b/common/image_plotting.py @@ -1,6 +1,8 @@ +from collections.abc import Sequence + import numpy as np import cv2 - +from PIL import Image, ImageDraw, ImageFont class Colors: """ @@ -85,30 +87,38 @@ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) - assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." - self.im = im if im.flags.writeable else im.copy() + + self.im = Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + font = 'static/font/Arial.Unicode.ttf' + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + if not hasattr(self.font, 'getsize'): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] + self.tf = max(self.lw - 1, 1) # font thickness self.sf = self.lw / 3 # font scale def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): - """Add one xyxy box to image with label.""" - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) - cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if not isinstance(box, Sequence): + box = box.tolist() + if rotated: + p1 = box[0] + # NOTE: PIL-version polygon needs tuple type. + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box if label: - w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height - outside = p1[1] - h >= 3 - p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 - cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText( - self.im, - label, - (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, + w, h = self.font.getsize(label) # text width, height + outside = p1[1] - h >= 0 # label fits outside box + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) def result(self): """Return annotated image as array.""" diff --git a/main.py b/main.py index e830e86..8da2462 100644 --- a/main.py +++ b/main.py @@ -42,7 +42,7 @@ # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers - uvicorn_logger.setLevel(logging.DEBUG) + uvicorn_logger.setLevel(logging.INFO) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/common/global_logger.py b/common/global_logger.py index c975ea8..47e41b7 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -13,7 +13,6 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler handler = TimedRotatingFileHandler( diff --git a/common/image_plotting.py b/common/image_plotting.py index 818ab95..5a2d942 100644 --- a/common/image_plotting.py +++ b/common/image_plotting.py @@ -1,6 +1,8 @@ +from collections.abc import Sequence + import numpy as np import cv2 - +from PIL import Image, ImageDraw, ImageFont class Colors: """ @@ -85,30 +87,38 @@ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) - assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." - self.im = im if im.flags.writeable else im.copy() + + self.im = Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + font = 'static/font/Arial.Unicode.ttf' + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + if not hasattr(self.font, 'getsize'): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] + self.tf = max(self.lw - 1, 1) # font thickness self.sf = self.lw / 3 # font scale def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): - """Add one xyxy box to image with label.""" - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) - cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if not isinstance(box, Sequence): + box = box.tolist() + if rotated: + p1 = box[0] + # NOTE: PIL-version polygon needs tuple type. + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box if label: - w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height - outside = p1[1] - h >= 3 - p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 - cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText( - self.im, - label, - (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, + w, h = self.font.getsize(label) # text width, height + outside = p1[1] - h >= 0 # label fits outside box + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) def result(self): """Return annotated image as array.""" diff --git a/main.py b/main.py index e830e86..8da2462 100644 --- a/main.py +++ b/main.py @@ -42,7 +42,7 @@ # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers - uvicorn_logger.setLevel(logging.DEBUG) + uvicorn_logger.setLevel(logging.INFO) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8797f0d..46e0d53 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -1,5 +1,6 @@ from algo.model_manager import AlgoModelExec from common.global_logger import logger +from common.image_plotting import colors class BaseModelHandler: @@ -35,8 +36,8 @@ if annotator is not None: for s_box in model_result.boxes: annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{int(s_box.cls)} {float(s_box.conf):.2f}", - color=(255, 0, 0), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), rotated=False) return results diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/common/global_logger.py b/common/global_logger.py index c975ea8..47e41b7 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -13,7 +13,6 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler handler = TimedRotatingFileHandler( diff --git a/common/image_plotting.py b/common/image_plotting.py index 818ab95..5a2d942 100644 --- a/common/image_plotting.py +++ b/common/image_plotting.py @@ -1,6 +1,8 @@ +from collections.abc import Sequence + import numpy as np import cv2 - +from PIL import Image, ImageDraw, ImageFont class Colors: """ @@ -85,30 +87,38 @@ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) - assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." - self.im = im if im.flags.writeable else im.copy() + + self.im = Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + font = 'static/font/Arial.Unicode.ttf' + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + if not hasattr(self.font, 'getsize'): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] + self.tf = max(self.lw - 1, 1) # font thickness self.sf = self.lw / 3 # font scale def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): - """Add one xyxy box to image with label.""" - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) - cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if not isinstance(box, Sequence): + box = box.tolist() + if rotated: + p1 = box[0] + # NOTE: PIL-version polygon needs tuple type. + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box if label: - w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height - outside = p1[1] - h >= 3 - p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 - cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText( - self.im, - label, - (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, + w, h = self.font.getsize(label) # text width, height + outside = p1[1] - h >= 0 # label fits outside box + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) def result(self): """Return annotated image as array.""" diff --git a/main.py b/main.py index e830e86..8da2462 100644 --- a/main.py +++ b/main.py @@ -42,7 +42,7 @@ # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers - uvicorn_logger.setLevel(logging.DEBUG) + uvicorn_logger.setLevel(logging.INFO) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8797f0d..46e0d53 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -1,5 +1,6 @@ from algo.model_manager import AlgoModelExec from common.global_logger import logger +from common.image_plotting import colors class BaseModelHandler: @@ -35,8 +36,8 @@ if annotator is not None: for s_box in model_result.boxes: annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{int(s_box.cls)} {float(s_box.conf):.2f}", - color=(255, 0, 0), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), rotated=False) return results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py new file mode 100644 index 0000000..e6023d0 --- /dev/null +++ b/model_handler/coco_engine_model_handler.py @@ -0,0 +1,90 @@ +from algo.model_manager import AlgoModelExec +from model_handler.base_model_handler import BaseModelHandler + + +class CocoEngineModelHandler(BaseModelHandler): + + def __init__(self, model: AlgoModelExec): + super().__init__(model) + self.model_names = { + 0: 'person', + 1: 'bicycle', + 2: 'car', + 3: 'motorcycle', + 4: 'airplane', + 5: 'bus', + 6: 'train', + 7: 'truck', + 8: 'boat', + 9: 'traffic light', + 10: 'fire hydrant', + 11: 'stop sign', + 12: 'parking meter', + 13: 'bench', + 14: 'bird', + 15: 'cat', + 16: 'dog', + 17: 'horse', + 18: 'sheep', + 19: 'cow', + 20: 'elephant', + 21: 'bear', + 22: 'zebra', + 23: 'giraffe', + 24: 'backpack', + 25: 'umbrella', + 26: 'handbag', + 27: 'tie', + 28: 'suitcase', + 29: 'frisbee', + 30: 'skis', + 31: 'snowboard', + 32: 'sports ball', + 33: 'kite', + 34: 'baseball bat', + 35: 'baseball glove', + 36: 'skateboard', + 37: 'surfboard', + 38: 'tennis racket', + 39: 'bottle', + 40: 'wine glass', + 41: 'cup', + 42: 'fork', + 43: 'knife', + 44: 'spoon', + 45: 'bowl', + 46: 'banana', + 47: 'apple', + 48: 'sandwich', + 49: 'orange', + 50: 'broccoli', + 51: 'carrot', + 52: 'hot dog', + 53: 'pizza', + 54: 'donut', + 55: 'cake', + 56: 'chair', + 57: 'couch', + 58: 'potted plant', + 59: 'bed', + 60: 'dining table', + 61: 'toilet', + 62: 'tv', + 63: 'laptop', + 64: 'mouse', + 65: 'remote', + 66: 'keyboard', + 67: 'cell phone', + 68: 'microwave', + 69: 'oven', + 70: 'toaster', + 71: 'sink', + 72: 'refrigerator', + 73: 'book', + 74: 'clock', + 75: 'vase', + 76: 'scissors', + 77: 'teddy bear', + 78: 'hair_drier', + 79: 'toothbrush', + } diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/common/global_logger.py b/common/global_logger.py index c975ea8..47e41b7 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -13,7 +13,6 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler handler = TimedRotatingFileHandler( diff --git a/common/image_plotting.py b/common/image_plotting.py index 818ab95..5a2d942 100644 --- a/common/image_plotting.py +++ b/common/image_plotting.py @@ -1,6 +1,8 @@ +from collections.abc import Sequence + import numpy as np import cv2 - +from PIL import Image, ImageDraw, ImageFont class Colors: """ @@ -85,30 +87,38 @@ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) - assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." - self.im = im if im.flags.writeable else im.copy() + + self.im = Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + font = 'static/font/Arial.Unicode.ttf' + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + if not hasattr(self.font, 'getsize'): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] + self.tf = max(self.lw - 1, 1) # font thickness self.sf = self.lw / 3 # font scale def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): - """Add one xyxy box to image with label.""" - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) - cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if not isinstance(box, Sequence): + box = box.tolist() + if rotated: + p1 = box[0] + # NOTE: PIL-version polygon needs tuple type. + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box if label: - w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height - outside = p1[1] - h >= 3 - p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 - cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText( - self.im, - label, - (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, + w, h = self.font.getsize(label) # text width, height + outside = p1[1] - h >= 0 # label fits outside box + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) def result(self): """Return annotated image as array.""" diff --git a/main.py b/main.py index e830e86..8da2462 100644 --- a/main.py +++ b/main.py @@ -42,7 +42,7 @@ # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers - uvicorn_logger.setLevel(logging.DEBUG) + uvicorn_logger.setLevel(logging.INFO) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8797f0d..46e0d53 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -1,5 +1,6 @@ from algo.model_manager import AlgoModelExec from common.global_logger import logger +from common.image_plotting import colors class BaseModelHandler: @@ -35,8 +36,8 @@ if annotator is not None: for s_box in model_result.boxes: annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{int(s_box.cls)} {float(s_box.conf):.2f}", - color=(255, 0, 0), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), rotated=False) return results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py new file mode 100644 index 0000000..e6023d0 --- /dev/null +++ b/model_handler/coco_engine_model_handler.py @@ -0,0 +1,90 @@ +from algo.model_manager import AlgoModelExec +from model_handler.base_model_handler import BaseModelHandler + + +class CocoEngineModelHandler(BaseModelHandler): + + def __init__(self, model: AlgoModelExec): + super().__init__(model) + self.model_names = { + 0: 'person', + 1: 'bicycle', + 2: 'car', + 3: 'motorcycle', + 4: 'airplane', + 5: 'bus', + 6: 'train', + 7: 'truck', + 8: 'boat', + 9: 'traffic light', + 10: 'fire hydrant', + 11: 'stop sign', + 12: 'parking meter', + 13: 'bench', + 14: 'bird', + 15: 'cat', + 16: 'dog', + 17: 'horse', + 18: 'sheep', + 19: 'cow', + 20: 'elephant', + 21: 'bear', + 22: 'zebra', + 23: 'giraffe', + 24: 'backpack', + 25: 'umbrella', + 26: 'handbag', + 27: 'tie', + 28: 'suitcase', + 29: 'frisbee', + 30: 'skis', + 31: 'snowboard', + 32: 'sports ball', + 33: 'kite', + 34: 'baseball bat', + 35: 'baseball glove', + 36: 'skateboard', + 37: 'surfboard', + 38: 'tennis racket', + 39: 'bottle', + 40: 'wine glass', + 41: 'cup', + 42: 'fork', + 43: 'knife', + 44: 'spoon', + 45: 'bowl', + 46: 'banana', + 47: 'apple', + 48: 'sandwich', + 49: 'orange', + 50: 'broccoli', + 51: 'carrot', + 52: 'hot dog', + 53: 'pizza', + 54: 'donut', + 55: 'cake', + 56: 'chair', + 57: 'couch', + 58: 'potted plant', + 59: 'bed', + 60: 'dining table', + 61: 'toilet', + 62: 'tv', + 63: 'laptop', + 64: 'mouse', + 65: 'remote', + 66: 'keyboard', + 67: 'cell phone', + 68: 'microwave', + 69: 'oven', + 70: 'toaster', + 71: 'sink', + 72: 'refrigerator', + 73: 'book', + 74: 'clock', + 75: 'vase', + 76: 'scissors', + 77: 'teddy bear', + 78: 'hair_drier', + 79: 'toothbrush', + } diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py new file mode 100644 index 0000000..d0fefba --- /dev/null +++ b/model_handler/labor_engine_model_handler.py @@ -0,0 +1,70 @@ +from algo.model_manager import AlgoModelExec +from model_handler.base_model_handler import BaseModelHandler + + +class LaborEngineModelHandler(BaseModelHandler): + + def __init__(self, model: AlgoModelExec): + super().__init__(model) + self.model_names = { + 0: '三脚架', + 1: '三通', + 2: '专用软管', + 3: '人', + 4: '作业信息公示牌', + 5: '切断阀', + 6: '危险告知牌', + 7: '压力测试仪', + 8: '压力表', + 9: '反光衣', + 10: '可燃气体报警控制器', + 11: '呼吸面罩', + 12: '喉箍', + 13: '四合一', + 14: '圆头水枪', + 15: '头', + 16: '安全告知牌', + 17: '安全带', + 18: '安全帽', + 19: '安全标识', + 20: '安全标识牌', + 21: '安全绳', + 22: '对讲机', + 23: '尖头水枪', + 24: '工服', + 25: '开关', + 26: '报警装置', + 27: '接头', + 28: '施工路牌', + 29: '气体检测仪', + 30: '水带', + 31: '水带_矩形', + 32: '流量计', + 33: '消火栓箱', + 34: '灭火器', + 35: '灶台', + 36: '灶眼', + 37: '照明设备', + 38: '熄火保护', + 39: '燃气管道', + 40: '燃气计量器具', + 41: '电线暴露', + 42: '电路图', + 43: '警戒线', + 44: '调压器', + 45: '调长器', + 46: '贴纸', + 47: '跨电线', + 48: '路锥', + 49: '过滤器', + 50: '配电箱内部', + 51: '配电箱外部', + 52: '长柄阀门', + 53: '闪光灯亮', + 54: '闪光灯灭', + 55: '阀门', + 56: '非专用软管', + 57: '风管', + 58: '鼓风机', + + } diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/common/global_logger.py b/common/global_logger.py index c975ea8..47e41b7 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -13,7 +13,6 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler handler = TimedRotatingFileHandler( diff --git a/common/image_plotting.py b/common/image_plotting.py index 818ab95..5a2d942 100644 --- a/common/image_plotting.py +++ b/common/image_plotting.py @@ -1,6 +1,8 @@ +from collections.abc import Sequence + import numpy as np import cv2 - +from PIL import Image, ImageDraw, ImageFont class Colors: """ @@ -85,30 +87,38 @@ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) - assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." - self.im = im if im.flags.writeable else im.copy() + + self.im = Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + font = 'static/font/Arial.Unicode.ttf' + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + if not hasattr(self.font, 'getsize'): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] + self.tf = max(self.lw - 1, 1) # font thickness self.sf = self.lw / 3 # font scale def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): - """Add one xyxy box to image with label.""" - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) - cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if not isinstance(box, Sequence): + box = box.tolist() + if rotated: + p1 = box[0] + # NOTE: PIL-version polygon needs tuple type. + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box if label: - w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height - outside = p1[1] - h >= 3 - p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 - cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText( - self.im, - label, - (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, + w, h = self.font.getsize(label) # text width, height + outside = p1[1] - h >= 0 # label fits outside box + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) def result(self): """Return annotated image as array.""" diff --git a/main.py b/main.py index e830e86..8da2462 100644 --- a/main.py +++ b/main.py @@ -42,7 +42,7 @@ # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers - uvicorn_logger.setLevel(logging.DEBUG) + uvicorn_logger.setLevel(logging.INFO) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8797f0d..46e0d53 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -1,5 +1,6 @@ from algo.model_manager import AlgoModelExec from common.global_logger import logger +from common.image_plotting import colors class BaseModelHandler: @@ -35,8 +36,8 @@ if annotator is not None: for s_box in model_result.boxes: annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{int(s_box.cls)} {float(s_box.conf):.2f}", - color=(255, 0, 0), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), rotated=False) return results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py new file mode 100644 index 0000000..e6023d0 --- /dev/null +++ b/model_handler/coco_engine_model_handler.py @@ -0,0 +1,90 @@ +from algo.model_manager import AlgoModelExec +from model_handler.base_model_handler import BaseModelHandler + + +class CocoEngineModelHandler(BaseModelHandler): + + def __init__(self, model: AlgoModelExec): + super().__init__(model) + self.model_names = { + 0: 'person', + 1: 'bicycle', + 2: 'car', + 3: 'motorcycle', + 4: 'airplane', + 5: 'bus', + 6: 'train', + 7: 'truck', + 8: 'boat', + 9: 'traffic light', + 10: 'fire hydrant', + 11: 'stop sign', + 12: 'parking meter', + 13: 'bench', + 14: 'bird', + 15: 'cat', + 16: 'dog', + 17: 'horse', + 18: 'sheep', + 19: 'cow', + 20: 'elephant', + 21: 'bear', + 22: 'zebra', + 23: 'giraffe', + 24: 'backpack', + 25: 'umbrella', + 26: 'handbag', + 27: 'tie', + 28: 'suitcase', + 29: 'frisbee', + 30: 'skis', + 31: 'snowboard', + 32: 'sports ball', + 33: 'kite', + 34: 'baseball bat', + 35: 'baseball glove', + 36: 'skateboard', + 37: 'surfboard', + 38: 'tennis racket', + 39: 'bottle', + 40: 'wine glass', + 41: 'cup', + 42: 'fork', + 43: 'knife', + 44: 'spoon', + 45: 'bowl', + 46: 'banana', + 47: 'apple', + 48: 'sandwich', + 49: 'orange', + 50: 'broccoli', + 51: 'carrot', + 52: 'hot dog', + 53: 'pizza', + 54: 'donut', + 55: 'cake', + 56: 'chair', + 57: 'couch', + 58: 'potted plant', + 59: 'bed', + 60: 'dining table', + 61: 'toilet', + 62: 'tv', + 63: 'laptop', + 64: 'mouse', + 65: 'remote', + 66: 'keyboard', + 67: 'cell phone', + 68: 'microwave', + 69: 'oven', + 70: 'toaster', + 71: 'sink', + 72: 'refrigerator', + 73: 'book', + 74: 'clock', + 75: 'vase', + 76: 'scissors', + 77: 'teddy bear', + 78: 'hair_drier', + 79: 'toothbrush', + } diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py new file mode 100644 index 0000000..d0fefba --- /dev/null +++ b/model_handler/labor_engine_model_handler.py @@ -0,0 +1,70 @@ +from algo.model_manager import AlgoModelExec +from model_handler.base_model_handler import BaseModelHandler + + +class LaborEngineModelHandler(BaseModelHandler): + + def __init__(self, model: AlgoModelExec): + super().__init__(model) + self.model_names = { + 0: '三脚架', + 1: '三通', + 2: '专用软管', + 3: '人', + 4: '作业信息公示牌', + 5: '切断阀', + 6: '危险告知牌', + 7: '压力测试仪', + 8: '压力表', + 9: '反光衣', + 10: '可燃气体报警控制器', + 11: '呼吸面罩', + 12: '喉箍', + 13: '四合一', + 14: '圆头水枪', + 15: '头', + 16: '安全告知牌', + 17: '安全带', + 18: '安全帽', + 19: '安全标识', + 20: '安全标识牌', + 21: '安全绳', + 22: '对讲机', + 23: '尖头水枪', + 24: '工服', + 25: '开关', + 26: '报警装置', + 27: '接头', + 28: '施工路牌', + 29: '气体检测仪', + 30: '水带', + 31: '水带_矩形', + 32: '流量计', + 33: '消火栓箱', + 34: '灭火器', + 35: '灶台', + 36: '灶眼', + 37: '照明设备', + 38: '熄火保护', + 39: '燃气管道', + 40: '燃气计量器具', + 41: '电线暴露', + 42: '电路图', + 43: '警戒线', + 44: '调压器', + 45: '调长器', + 46: '贴纸', + 47: '跨电线', + 48: '路锥', + 49: '过滤器', + 50: '配电箱内部', + 51: '配电箱外部', + 52: '长柄阀门', + 53: '闪光灯亮', + 54: '闪光灯灭', + 55: '阀门', + 56: '非专用软管', + 57: '风管', + 58: '鼓风机', + + } diff --git a/static/font/Arial.Unicode.ttf b/static/font/Arial.Unicode.ttf new file mode 100644 index 0000000..1537c5b --- /dev/null +++ b/static/font/Arial.Unicode.ttf Binary files differ diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/common/global_logger.py b/common/global_logger.py index c975ea8..47e41b7 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -13,7 +13,6 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler handler = TimedRotatingFileHandler( diff --git a/common/image_plotting.py b/common/image_plotting.py index 818ab95..5a2d942 100644 --- a/common/image_plotting.py +++ b/common/image_plotting.py @@ -1,6 +1,8 @@ +from collections.abc import Sequence + import numpy as np import cv2 - +from PIL import Image, ImageDraw, ImageFont class Colors: """ @@ -85,30 +87,38 @@ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) - assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." - self.im = im if im.flags.writeable else im.copy() + + self.im = Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + font = 'static/font/Arial.Unicode.ttf' + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + if not hasattr(self.font, 'getsize'): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] + self.tf = max(self.lw - 1, 1) # font thickness self.sf = self.lw / 3 # font scale def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): - """Add one xyxy box to image with label.""" - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) - cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if not isinstance(box, Sequence): + box = box.tolist() + if rotated: + p1 = box[0] + # NOTE: PIL-version polygon needs tuple type. + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box if label: - w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height - outside = p1[1] - h >= 3 - p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 - cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText( - self.im, - label, - (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, + w, h = self.font.getsize(label) # text width, height + outside = p1[1] - h >= 0 # label fits outside box + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) def result(self): """Return annotated image as array.""" diff --git a/main.py b/main.py index e830e86..8da2462 100644 --- a/main.py +++ b/main.py @@ -42,7 +42,7 @@ # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers - uvicorn_logger.setLevel(logging.DEBUG) + uvicorn_logger.setLevel(logging.INFO) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8797f0d..46e0d53 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -1,5 +1,6 @@ from algo.model_manager import AlgoModelExec from common.global_logger import logger +from common.image_plotting import colors class BaseModelHandler: @@ -35,8 +36,8 @@ if annotator is not None: for s_box in model_result.boxes: annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{int(s_box.cls)} {float(s_box.conf):.2f}", - color=(255, 0, 0), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), rotated=False) return results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py new file mode 100644 index 0000000..e6023d0 --- /dev/null +++ b/model_handler/coco_engine_model_handler.py @@ -0,0 +1,90 @@ +from algo.model_manager import AlgoModelExec +from model_handler.base_model_handler import BaseModelHandler + + +class CocoEngineModelHandler(BaseModelHandler): + + def __init__(self, model: AlgoModelExec): + super().__init__(model) + self.model_names = { + 0: 'person', + 1: 'bicycle', + 2: 'car', + 3: 'motorcycle', + 4: 'airplane', + 5: 'bus', + 6: 'train', + 7: 'truck', + 8: 'boat', + 9: 'traffic light', + 10: 'fire hydrant', + 11: 'stop sign', + 12: 'parking meter', + 13: 'bench', + 14: 'bird', + 15: 'cat', + 16: 'dog', + 17: 'horse', + 18: 'sheep', + 19: 'cow', + 20: 'elephant', + 21: 'bear', + 22: 'zebra', + 23: 'giraffe', + 24: 'backpack', + 25: 'umbrella', + 26: 'handbag', + 27: 'tie', + 28: 'suitcase', + 29: 'frisbee', + 30: 'skis', + 31: 'snowboard', + 32: 'sports ball', + 33: 'kite', + 34: 'baseball bat', + 35: 'baseball glove', + 36: 'skateboard', + 37: 'surfboard', + 38: 'tennis racket', + 39: 'bottle', + 40: 'wine glass', + 41: 'cup', + 42: 'fork', + 43: 'knife', + 44: 'spoon', + 45: 'bowl', + 46: 'banana', + 47: 'apple', + 48: 'sandwich', + 49: 'orange', + 50: 'broccoli', + 51: 'carrot', + 52: 'hot dog', + 53: 'pizza', + 54: 'donut', + 55: 'cake', + 56: 'chair', + 57: 'couch', + 58: 'potted plant', + 59: 'bed', + 60: 'dining table', + 61: 'toilet', + 62: 'tv', + 63: 'laptop', + 64: 'mouse', + 65: 'remote', + 66: 'keyboard', + 67: 'cell phone', + 68: 'microwave', + 69: 'oven', + 70: 'toaster', + 71: 'sink', + 72: 'refrigerator', + 73: 'book', + 74: 'clock', + 75: 'vase', + 76: 'scissors', + 77: 'teddy bear', + 78: 'hair_drier', + 79: 'toothbrush', + } diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py new file mode 100644 index 0000000..d0fefba --- /dev/null +++ b/model_handler/labor_engine_model_handler.py @@ -0,0 +1,70 @@ +from algo.model_manager import AlgoModelExec +from model_handler.base_model_handler import BaseModelHandler + + +class LaborEngineModelHandler(BaseModelHandler): + + def __init__(self, model: AlgoModelExec): + super().__init__(model) + self.model_names = { + 0: '三脚架', + 1: '三通', + 2: '专用软管', + 3: '人', + 4: '作业信息公示牌', + 5: '切断阀', + 6: '危险告知牌', + 7: '压力测试仪', + 8: '压力表', + 9: '反光衣', + 10: '可燃气体报警控制器', + 11: '呼吸面罩', + 12: '喉箍', + 13: '四合一', + 14: '圆头水枪', + 15: '头', + 16: '安全告知牌', + 17: '安全带', + 18: '安全帽', + 19: '安全标识', + 20: '安全标识牌', + 21: '安全绳', + 22: '对讲机', + 23: '尖头水枪', + 24: '工服', + 25: '开关', + 26: '报警装置', + 27: '接头', + 28: '施工路牌', + 29: '气体检测仪', + 30: '水带', + 31: '水带_矩形', + 32: '流量计', + 33: '消火栓箱', + 34: '灭火器', + 35: '灶台', + 36: '灶眼', + 37: '照明设备', + 38: '熄火保护', + 39: '燃气管道', + 40: '燃气计量器具', + 41: '电线暴露', + 42: '电路图', + 43: '警戒线', + 44: '调压器', + 45: '调长器', + 46: '贴纸', + 47: '跨电线', + 48: '路锥', + 49: '过滤器', + 50: '配电箱内部', + 51: '配电箱外部', + 52: '长柄阀门', + 53: '闪光灯亮', + 54: '闪光灯灭', + 55: '阀门', + 56: '非专用软管', + 57: '风管', + 58: '鼓风机', + + } diff --git a/static/font/Arial.Unicode.ttf b/static/font/Arial.Unicode.ttf new file mode 100644 index 0000000..1537c5b --- /dev/null +++ b/static/font/Arial.Unicode.ttf Binary files differ diff --git a/static/font/Arial.ttf b/static/font/Arial.ttf new file mode 100644 index 0000000..ab68fb1 --- /dev/null +++ b/static/font/Arial.ttf Binary files differ diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 8724c0d..6a86830 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -1,5 +1,6 @@ import asyncio import json +from copy import deepcopy from dataclasses import dataclass import importlib from datetime import datetime @@ -9,9 +10,11 @@ from algo.model_manager import AlgoModelExec from algo.stream_loader import OpenCVStreamLoad from common.device_status_manager import DeviceStatusManager +from common.display_frame_manager import DisplayFrameManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool from common.http_utils import send_request +from common.image_plotting import Annotator from common.string_utils import camel_to_snake, get_class, default_serializer from db.database import get_db from entity.device import Device @@ -52,6 +55,7 @@ self.thread_pool = GlobalThreadPool() self.device_status_manager = DeviceStatusManager() + self.display_frame_manager = DisplayFrameManager() self.main_loop = main_loop self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, @@ -134,12 +138,17 @@ self.device_status_manager.set_status(device_id=self.device.id) results_map = {} + annotator = Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人") for model_exec in self.model_exec_list: handle_task_name = model_exec.algo_model_info.handle_task handler_cls = get_class(f'model_handler.{camel_to_snake(handle_task_name)}', handle_task_name) handler_instance = handler_cls(model_exec) - frame, results = handler_instance.run(frame, None) + frame, results = handler_instance.run(frame, annotator) results_map[model_exec.algo_model_id] = [DetectionResult.from_dict(r) for r in results] # 结果处理 self.thread_pool.submit_task(self.save_frame_results, frame, results_map) + self.display_frame_manager.add_frame(self.device.id, annotator.result()) + # future = asyncio.run_coroutine_threadsafe( + # self.display_frame_manager.add_frame(self.device.id, annotator.result()), self.main_loop + # ) diff --git a/apis/control.py b/apis/control.py index 4c108c5..9c6d6fb 100644 --- a/apis/control.py +++ b/apis/control.py @@ -98,3 +98,11 @@ except Exception as e: traceback.print_exc() return standard_error_response(code=500, message=f"Failed to restart container: {e}") + +@router.get("/sync_test") +def sync_test(): + return standard_response() + +@router.get("/async_test") +async def async_test(): + return standard_response() diff --git a/apis/display.py b/apis/display.py new file mode 100644 index 0000000..64c998a --- /dev/null +++ b/apis/display.py @@ -0,0 +1,73 @@ +import asyncio +import threading + +import cv2 +from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect +from starlette.responses import StreamingResponse + +from common.display_frame_manager import DisplayFrameManager +from common.global_logger import logger + +router = APIRouter() +display_frame_manager = DisplayFrameManager() + + +async def generate_video_stream(device_id, request: Request): + while True: + if await request.is_disconnected(): + print("客户端已断开连接,停止视频流") + break + + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + # 将帧编码为 JPEG 格式 + _, buffer = cv2.imencode('.jpg', frame) + frame = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + frame + b'\r\n') + + await asyncio.sleep(0.03) # 控制帧的推送频率,减小CPU占用 + + +@router.get("/video") +async def video_stream(device_id: int, request: Request): + return StreamingResponse(generate_video_stream(device_id, request), + media_type="multipart/x-mixed-replace; boundary=frame") + + +async def send_video_stream(device_id, websocket: WebSocket): + await websocket.accept() # 接受 WebSocket 连接 + try: + while True: + # 检查并获取设备的最新帧 + frame = None + if display_frame_manager.has_device(device_id): + frame = display_frame_manager.get_latest_frame(device_id) + + # 如果有可用帧,则进行处理并发送 + if frame is not None: + # 将帧编码为 JPEG + _, buffer = cv2.imencode('.jpg', frame) + # 将图像数据编码为 Base64 字符串(或者可以直接发送字节流) + await websocket.send_bytes(buffer.tobytes()) + await asyncio.sleep(0) + else: + # 如果没有帧,可以选择发送空帧或跳过 + await asyncio.sleep(0.03) # 控制帧率,避免过度占用资源 + except WebSocketDisconnect: + logger.info("WebSocket 连接已断开") + except Exception as e: + logger.error(f"WebSocket 连接出错: {e}") + + +def run_in_thread(device_id, websocket: WebSocket): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(send_video_stream(device_id, websocket)) + loop.close() + + +@router.websocket("/ws/video/{device_id}") +async def video_stream(websocket: WebSocket, device_id: int): + await send_video_stream(device_id,websocket) diff --git a/apis/router.py b/apis/router.py index 18e1311..4f575f6 100644 --- a/apis/router.py +++ b/apis/router.py @@ -9,7 +9,7 @@ from .data_gas import router as gas_router from .control import router as control_router from .push_config import router as push_config_router - +from .display import router as display_router # 创建一个全局的 router router = APIRouter() @@ -18,10 +18,10 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) -router.include_router(scene_router,prefix="/scene", tags=["Scene"]) +router.include_router(scene_router, prefix="/scene", tags=["Scene"]) router.include_router(device_scene_relation_router, prefix="/device_scene_relation", tags=["DeviceSceneRelations"]) router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) router.include_router(gas_router, prefix="/gas", tags=["DataGas"]) -router.include_router(control_router,prefix="/control", tags=["Control"]) -router.include_router(push_config_router,prefix="/push", tags=["PushConfig"]) - +router.include_router(control_router, prefix="/control", tags=["Control"]) +router.include_router(push_config_router, prefix="/push", tags=["PushConfig"]) +router.include_router(display_router, prefix="/display", tags=["Display"]) diff --git a/app_instance.py b/app_instance.py index 4e24675..82a750d 100644 --- a/app_instance.py +++ b/app_instance.py @@ -49,7 +49,7 @@ tcp_manager = TcpManager(device_service=device_service) app.state.tcp_manager = tcp_manager - # await tcp_manager.start() + await tcp_manager.start() algo_runner = AlgoRunner( device_service=device_service, @@ -57,7 +57,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - # await algo_runner.start() + await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, diff --git a/common/display_frame_manager.py b/common/display_frame_manager.py new file mode 100644 index 0000000..d7f8354 --- /dev/null +++ b/common/display_frame_manager.py @@ -0,0 +1,46 @@ +import threading +from collections import deque + + +class DisplayFrameManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + """实现单例模式,确保只有一个实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: # 防止多个线程同时创建 + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, maxlen=10): + if not hasattr(self, '_initialized'): # 防止重复初始化 + self.device_queues = {} + self.maxlen = maxlen + self._initialized = True + + def add_frame(self, device_id, frame): + """向设备队列中添加帧,如果设备不存在,则自动创建队列""" + with self._lock: + if device_id not in self.device_queues: + self.device_queues[device_id] = deque(maxlen=self.maxlen) + if len(self.device_queues[device_id]) >= self.maxlen: + self.device_queues[device_id].pop() # 移除最旧的帧 + self.device_queues[device_id].appendleft(frame) # 添加新帧 + + def get_latest_frame(self, device_id): + """获取设备队列中的最新帧""" + with self._lock: + if device_id in self.device_queues and self.device_queues[device_id]: + return self.device_queues[device_id][0] + return None # 如果设备不存在或没有帧,返回 None + + def remove_device(self, device_id): + """移除设备及其队列""" + with self._lock: + if device_id in self.device_queues: + del self.device_queues[device_id] + + def has_device(self,device_id): + return device_id in self.device_queues and self.device_queues[device_id] diff --git a/common/global_logger.py b/common/global_logger.py index c975ea8..47e41b7 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -13,7 +13,6 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 -logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) # 创建一个TimedRotatingFileHandler handler = TimedRotatingFileHandler( diff --git a/common/image_plotting.py b/common/image_plotting.py index 818ab95..5a2d942 100644 --- a/common/image_plotting.py +++ b/common/image_plotting.py @@ -1,6 +1,8 @@ +from collections.abc import Sequence + import numpy as np import cv2 - +from PIL import Image, ImageDraw, ImageFont class Colors: """ @@ -85,30 +87,38 @@ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) - assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." - self.im = im if im.flags.writeable else im.copy() + + self.im = Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + font = 'static/font/Arial.Unicode.ttf' + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + if not hasattr(self.font, 'getsize'): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] + self.tf = max(self.lw - 1, 1) # font thickness self.sf = self.lw / 3 # font scale def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): - """Add one xyxy box to image with label.""" - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) - cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if not isinstance(box, Sequence): + box = box.tolist() + if rotated: + p1 = box[0] + # NOTE: PIL-version polygon needs tuple type. + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box if label: - w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height - outside = p1[1] - h >= 3 - p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 - cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText( - self.im, - label, - (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, + w, h = self.font.getsize(label) # text width, height + outside = p1[1] - h >= 0 # label fits outside box + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) def result(self): """Return annotated image as array.""" diff --git a/main.py b/main.py index e830e86..8da2462 100644 --- a/main.py +++ b/main.py @@ -42,7 +42,7 @@ # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers - uvicorn_logger.setLevel(logging.DEBUG) + uvicorn_logger.setLevel(logging.INFO) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/model_handler/base_model_handler.py b/model_handler/base_model_handler.py index 8797f0d..46e0d53 100644 --- a/model_handler/base_model_handler.py +++ b/model_handler/base_model_handler.py @@ -1,5 +1,6 @@ from algo.model_manager import AlgoModelExec from common.global_logger import logger +from common.image_plotting import colors class BaseModelHandler: @@ -35,8 +36,8 @@ if annotator is not None: for s_box in model_result.boxes: annotator.box_label(s_box.xyxy.cpu().squeeze(), - f"{int(s_box.cls)} {float(s_box.conf):.2f}", - color=(255, 0, 0), + f"{self.model_names[int(s_box.cls)]} {float(s_box.conf):.2f}", + color=colors(int(s_box.cls)), rotated=False) return results diff --git a/model_handler/coco_engine_model_handler.py b/model_handler/coco_engine_model_handler.py new file mode 100644 index 0000000..e6023d0 --- /dev/null +++ b/model_handler/coco_engine_model_handler.py @@ -0,0 +1,90 @@ +from algo.model_manager import AlgoModelExec +from model_handler.base_model_handler import BaseModelHandler + + +class CocoEngineModelHandler(BaseModelHandler): + + def __init__(self, model: AlgoModelExec): + super().__init__(model) + self.model_names = { + 0: 'person', + 1: 'bicycle', + 2: 'car', + 3: 'motorcycle', + 4: 'airplane', + 5: 'bus', + 6: 'train', + 7: 'truck', + 8: 'boat', + 9: 'traffic light', + 10: 'fire hydrant', + 11: 'stop sign', + 12: 'parking meter', + 13: 'bench', + 14: 'bird', + 15: 'cat', + 16: 'dog', + 17: 'horse', + 18: 'sheep', + 19: 'cow', + 20: 'elephant', + 21: 'bear', + 22: 'zebra', + 23: 'giraffe', + 24: 'backpack', + 25: 'umbrella', + 26: 'handbag', + 27: 'tie', + 28: 'suitcase', + 29: 'frisbee', + 30: 'skis', + 31: 'snowboard', + 32: 'sports ball', + 33: 'kite', + 34: 'baseball bat', + 35: 'baseball glove', + 36: 'skateboard', + 37: 'surfboard', + 38: 'tennis racket', + 39: 'bottle', + 40: 'wine glass', + 41: 'cup', + 42: 'fork', + 43: 'knife', + 44: 'spoon', + 45: 'bowl', + 46: 'banana', + 47: 'apple', + 48: 'sandwich', + 49: 'orange', + 50: 'broccoli', + 51: 'carrot', + 52: 'hot dog', + 53: 'pizza', + 54: 'donut', + 55: 'cake', + 56: 'chair', + 57: 'couch', + 58: 'potted plant', + 59: 'bed', + 60: 'dining table', + 61: 'toilet', + 62: 'tv', + 63: 'laptop', + 64: 'mouse', + 65: 'remote', + 66: 'keyboard', + 67: 'cell phone', + 68: 'microwave', + 69: 'oven', + 70: 'toaster', + 71: 'sink', + 72: 'refrigerator', + 73: 'book', + 74: 'clock', + 75: 'vase', + 76: 'scissors', + 77: 'teddy bear', + 78: 'hair_drier', + 79: 'toothbrush', + } diff --git a/model_handler/labor_engine_model_handler.py b/model_handler/labor_engine_model_handler.py new file mode 100644 index 0000000..d0fefba --- /dev/null +++ b/model_handler/labor_engine_model_handler.py @@ -0,0 +1,70 @@ +from algo.model_manager import AlgoModelExec +from model_handler.base_model_handler import BaseModelHandler + + +class LaborEngineModelHandler(BaseModelHandler): + + def __init__(self, model: AlgoModelExec): + super().__init__(model) + self.model_names = { + 0: '三脚架', + 1: '三通', + 2: '专用软管', + 3: '人', + 4: '作业信息公示牌', + 5: '切断阀', + 6: '危险告知牌', + 7: '压力测试仪', + 8: '压力表', + 9: '反光衣', + 10: '可燃气体报警控制器', + 11: '呼吸面罩', + 12: '喉箍', + 13: '四合一', + 14: '圆头水枪', + 15: '头', + 16: '安全告知牌', + 17: '安全带', + 18: '安全帽', + 19: '安全标识', + 20: '安全标识牌', + 21: '安全绳', + 22: '对讲机', + 23: '尖头水枪', + 24: '工服', + 25: '开关', + 26: '报警装置', + 27: '接头', + 28: '施工路牌', + 29: '气体检测仪', + 30: '水带', + 31: '水带_矩形', + 32: '流量计', + 33: '消火栓箱', + 34: '灭火器', + 35: '灶台', + 36: '灶眼', + 37: '照明设备', + 38: '熄火保护', + 39: '燃气管道', + 40: '燃气计量器具', + 41: '电线暴露', + 42: '电路图', + 43: '警戒线', + 44: '调压器', + 45: '调长器', + 46: '贴纸', + 47: '跨电线', + 48: '路锥', + 49: '过滤器', + 50: '配电箱内部', + 51: '配电箱外部', + 52: '长柄阀门', + 53: '闪光灯亮', + 54: '闪光灯灭', + 55: '阀门', + 56: '非专用软管', + 57: '风管', + 58: '鼓风机', + + } diff --git a/static/font/Arial.Unicode.ttf b/static/font/Arial.Unicode.ttf new file mode 100644 index 0000000..1537c5b --- /dev/null +++ b/static/font/Arial.Unicode.ttf Binary files differ diff --git a/static/font/Arial.ttf b/static/font/Arial.ttf new file mode 100644 index 0000000..ab68fb1 --- /dev/null +++ b/static/font/Arial.ttf Binary files differ diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py index 399cd66..29d2966 100644 --- a/tcp/tcp_client_connector.py +++ b/tcp/tcp_client_connector.py @@ -67,7 +67,9 @@ self.reconnect_interval = reconnect_interval # 重连间隔 self.timeout = timeout # 连接/发送超时时间 self.is_connected = False # 连接状态标志 - self.message_queue = deque() + self.is_reconnecting = False + self.message_queue = asyncio.Queue() #deque() + self.gas_task = None self.read_lock = asyncio.Lock() # 添加锁 self.push_ts_dict = {} @@ -82,7 +84,9 @@ ) self.is_connected = True logger.info(f"已连接到 {self.ip}:{self.port}") - asyncio.create_task(self.process_message_queue()) # Start processing message queue + + if self.gas_task is None: + self.gas_task = asyncio.create_task(self.process_message_queue()) # Start processing message queue # 一旦连接成功,开始发送查询指令 await self.start_gas_query() @@ -93,10 +97,15 @@ async def reconnect(self): """处理断线重连""" + if self.is_reconnecting: + logger.info("Reconnection is already in progress...") + return + self.is_reconnecting = True await self.disconnect() # 先断开现有连接 logger.info(f"Reconnecting to {self.ip}:{self.port} after {self.reconnect_interval} seconds") - await asyncio.sleep(self.reconnect_interval) # 等待n秒后重连 + # await asyncio.sleep(self.reconnect_interval) # 等待n秒后重连 await self.connect() + self.is_reconnecting = False async def disconnect(self): """断开设备连接,清理资源""" @@ -172,24 +181,35 @@ async def _send_message_with_retry(self, message: bytes, have_response): """Send a message with retries on failure""" - while self.is_connected: + retry_attempts = 3 # Maximum retry attempts + for _ in range(retry_attempts): + if not self.is_connected: + await self.reconnect() + if not self.is_connected: + logger.error("Reconnection failed") + continue # Skip this attempt if reconnection fails + try: - if self.writer is None: - raise ConnectionResetError("No active connection") + if self.writer is None or self.writer.is_closing(): + raise ConnectionResetError("No active connection or writer is closing") self.writer.write(message) await self.writer.drain() logger.info(f"Sent message to {self.ip}:{self.port}: {message}") if have_response: - async with self.read_lock: # 使用锁确保只有一个协程读取 + async with self.read_lock: # Ensure only one coroutine reads data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout) await self.parse_response(data) return # Exit loop on success - except (asyncio.TimeoutError, ConnectionResetError, asyncio.IncompleteReadError) as e: - logger.error(f"Failed to send message: {e}, retrying...") + + except (asyncio.TimeoutError, ConnectionResetError, asyncio.IncompleteReadError, RuntimeError, + BrokenPipeError, OSError, EOFError, ConnectionAbortedError, ConnectionRefusedError) as e: + logger.exception("Failed to send message") + self.is_connected = False # Mark connection as disconnected await self.reconnect() - break + + logger.error("Max retry attempts reached, message sending failed") # async def send_message(self, message: bytes, have_response=True): # """发送自定义消息的接口,供其他类调用"""