diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/apis/base.py b/apis/base.py index a7c4ede..b442a71 100644 --- a/apis/base.py +++ b/apis/base.py @@ -24,3 +24,5 @@ def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): return StandardResponse(data=data, code=code, message=message, success=success) + + diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/apis/base.py b/apis/base.py index a7c4ede..b442a71 100644 --- a/apis/base.py +++ b/apis/base.py @@ -24,3 +24,5 @@ def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): return StandardResponse(data=data, code=code, message=message, success=success) + + diff --git a/apis/frame.py b/apis/frame.py new file mode 100644 index 0000000..aa366c8 --- /dev/null +++ b/apis/frame.py @@ -0,0 +1,75 @@ +import io +import os +from io import BytesIO +from typing import List, Optional + +import cv2 +from fastapi import APIRouter, Depends, Query, HTTPException, Request +from sqlmodel import Session +from fastapi.responses import StreamingResponse, FileResponse + +from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.base import parse_datetime +from entity.device_frame import DeviceFrame +from services.device_frame_service import DeviceFrameService + +router = APIRouter() + + +@router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) +def get_frame_page( + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[str] = None, + frame_end_time: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceFrameService(db) + + # 获取分页后的设备列表和总数 + frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) + + return standard_response( + data=PageResponse(total=total, items=frames) + ) + + +# 路由:使用 OpenCV 生成内存图像并返回字节流 +@router.get("/frame_image/") +def get_frame_image(frame_id, db: Session = Depends(get_db)): + try: + service = DeviceFrameService(db) + frame = service.get_frame_annotator(frame_id) + if frame is None: + return standard_error_response(message="Frame does not exist") + + _, img_encoded = cv2.imencode('.jpg', frame) + img_bytes = img_encoded.tobytes() # 转换为字节格式 + + # 定义生成器分块读取内存中的图片数据 + def iterfile(): + chunk_size = 1024 * 64 # 64KB 每块 + file_like = io.BytesIO(img_bytes) # 使用 BytesIO 将字节数据视为文件对象 + while chunk := file_like.read(chunk_size): + yield chunk + + return StreamingResponse(iterfile(), media_type="image/jpeg") + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") + + +@router.get("/frame_test") +def get_frame_test(request: Request): + file_path = "test.jpg" + return FileResponse(file_path) + # def iterfile(): + # with open(file_path, 'rb') as file: + # while chunk := file.read(1024): + # yield chunk + # + # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/apis/base.py b/apis/base.py index a7c4ede..b442a71 100644 --- a/apis/base.py +++ b/apis/base.py @@ -24,3 +24,5 @@ def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): return StandardResponse(data=data, code=code, message=message, success=success) + + diff --git a/apis/frame.py b/apis/frame.py new file mode 100644 index 0000000..aa366c8 --- /dev/null +++ b/apis/frame.py @@ -0,0 +1,75 @@ +import io +import os +from io import BytesIO +from typing import List, Optional + +import cv2 +from fastapi import APIRouter, Depends, Query, HTTPException, Request +from sqlmodel import Session +from fastapi.responses import StreamingResponse, FileResponse + +from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.base import parse_datetime +from entity.device_frame import DeviceFrame +from services.device_frame_service import DeviceFrameService + +router = APIRouter() + + +@router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) +def get_frame_page( + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[str] = None, + frame_end_time: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceFrameService(db) + + # 获取分页后的设备列表和总数 + frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) + + return standard_response( + data=PageResponse(total=total, items=frames) + ) + + +# 路由:使用 OpenCV 生成内存图像并返回字节流 +@router.get("/frame_image/") +def get_frame_image(frame_id, db: Session = Depends(get_db)): + try: + service = DeviceFrameService(db) + frame = service.get_frame_annotator(frame_id) + if frame is None: + return standard_error_response(message="Frame does not exist") + + _, img_encoded = cv2.imencode('.jpg', frame) + img_bytes = img_encoded.tobytes() # 转换为字节格式 + + # 定义生成器分块读取内存中的图片数据 + def iterfile(): + chunk_size = 1024 * 64 # 64KB 每块 + file_like = io.BytesIO(img_bytes) # 使用 BytesIO 将字节数据视为文件对象 + while chunk := file_like.read(chunk_size): + yield chunk + + return StreamingResponse(iterfile(), media_type="image/jpeg") + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") + + +@router.get("/frame_test") +def get_frame_test(request: Request): + file_path = "test.jpg" + return FileResponse(file_path) + # def iterfile(): + # with open(file_path, 'rb') as file: + # while chunk := file.read(1024): + # yield chunk + # + # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/router.py b/apis/router.py index 22b5ed6..de08805 100644 --- a/apis/router.py +++ b/apis/router.py @@ -3,6 +3,7 @@ from .device import router as devices_router from .model import router as models_router from .device_model_realtion import router as device_model_relation_router +from .frame import router as frame_router # 创建一个全局的 router @@ -12,4 +13,4 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) - +router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/apis/base.py b/apis/base.py index a7c4ede..b442a71 100644 --- a/apis/base.py +++ b/apis/base.py @@ -24,3 +24,5 @@ def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): return StandardResponse(data=data, code=code, message=message, success=success) + + diff --git a/apis/frame.py b/apis/frame.py new file mode 100644 index 0000000..aa366c8 --- /dev/null +++ b/apis/frame.py @@ -0,0 +1,75 @@ +import io +import os +from io import BytesIO +from typing import List, Optional + +import cv2 +from fastapi import APIRouter, Depends, Query, HTTPException, Request +from sqlmodel import Session +from fastapi.responses import StreamingResponse, FileResponse + +from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.base import parse_datetime +from entity.device_frame import DeviceFrame +from services.device_frame_service import DeviceFrameService + +router = APIRouter() + + +@router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) +def get_frame_page( + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[str] = None, + frame_end_time: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceFrameService(db) + + # 获取分页后的设备列表和总数 + frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) + + return standard_response( + data=PageResponse(total=total, items=frames) + ) + + +# 路由:使用 OpenCV 生成内存图像并返回字节流 +@router.get("/frame_image/") +def get_frame_image(frame_id, db: Session = Depends(get_db)): + try: + service = DeviceFrameService(db) + frame = service.get_frame_annotator(frame_id) + if frame is None: + return standard_error_response(message="Frame does not exist") + + _, img_encoded = cv2.imencode('.jpg', frame) + img_bytes = img_encoded.tobytes() # 转换为字节格式 + + # 定义生成器分块读取内存中的图片数据 + def iterfile(): + chunk_size = 1024 * 64 # 64KB 每块 + file_like = io.BytesIO(img_bytes) # 使用 BytesIO 将字节数据视为文件对象 + while chunk := file_like.read(chunk_size): + yield chunk + + return StreamingResponse(iterfile(), media_type="image/jpeg") + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") + + +@router.get("/frame_test") +def get_frame_test(request: Request): + file_path = "test.jpg" + return FileResponse(file_path) + # def iterfile(): + # with open(file_path, 'rb') as file: + # while chunk := file.read(1024): + # yield chunk + # + # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/router.py b/apis/router.py index 22b5ed6..de08805 100644 --- a/apis/router.py +++ b/apis/router.py @@ -3,6 +3,7 @@ from .device import router as devices_router from .model import router as models_router from .device_model_realtion import router as device_model_relation_router +from .frame import router as frame_router # 创建一个全局的 router @@ -12,4 +13,4 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) - +router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) diff --git a/common/image_plotting.py b/common/image_plotting.py new file mode 100644 index 0000000..818ab95 --- /dev/null +++ b/common/image_plotting.py @@ -0,0 +1,115 @@ +import numpy as np +import cv2 + + +class Colors: + """ + Ultralytics default color palette https://ultralytics.com/. + + This class provides methods to work with the Ultralytics color palette, including converting hex color codes to + RGB values. + + Attributes: + palette (list of tuple): List of RGB color values. + n (int): The number of colors in the palette. + pose_palette (np.ndarray): A specific color palette array with dtype np.uint8. + """ + + def __init__(self): + """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" + hexs = ( + "FF3838", + "FF9D97", + "FF701F", + "FFB21D", + "CFD231", + "48F90A", + "92CC17", + "3DDB86", + "1A9334", + "00D4BB", + "2C99A8", + "00C2FF", + "344593", + "6473FF", + "0018EC", + "8438FF", + "520085", + "CB38FF", + "FF95C8", + "FF37C7", + ) + self.palette = [self.hex2rgb(f"#{c}") for c in hexs] + self.n = len(self.palette) + self.pose_palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ], + dtype=np.uint8, + ) + + def __call__(self, i, bgr=False): + """Converts hex color codes to RGB values.""" + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): + """Converts hex color codes to RGB values (i.e. default PIL order).""" + return tuple(int(h[1 + i: 1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() + + +class Annotator: + + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): + self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." + self.im = im if im.flags.writeable else im.copy() + self.tf = max(self.lw - 1, 1) # font thickness + self.sf = self.lw / 3 # font scale + + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): + """Add one xyxy box to image with label.""" + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if label: + w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + outside = p1[1] - h >= 3 + p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + self.im, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + self.sf, + txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) + + def result(self): + """Return annotated image as array.""" + return np.asarray(self.im) \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/apis/base.py b/apis/base.py index a7c4ede..b442a71 100644 --- a/apis/base.py +++ b/apis/base.py @@ -24,3 +24,5 @@ def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): return StandardResponse(data=data, code=code, message=message, success=success) + + diff --git a/apis/frame.py b/apis/frame.py new file mode 100644 index 0000000..aa366c8 --- /dev/null +++ b/apis/frame.py @@ -0,0 +1,75 @@ +import io +import os +from io import BytesIO +from typing import List, Optional + +import cv2 +from fastapi import APIRouter, Depends, Query, HTTPException, Request +from sqlmodel import Session +from fastapi.responses import StreamingResponse, FileResponse + +from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.base import parse_datetime +from entity.device_frame import DeviceFrame +from services.device_frame_service import DeviceFrameService + +router = APIRouter() + + +@router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) +def get_frame_page( + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[str] = None, + frame_end_time: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceFrameService(db) + + # 获取分页后的设备列表和总数 + frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) + + return standard_response( + data=PageResponse(total=total, items=frames) + ) + + +# 路由:使用 OpenCV 生成内存图像并返回字节流 +@router.get("/frame_image/") +def get_frame_image(frame_id, db: Session = Depends(get_db)): + try: + service = DeviceFrameService(db) + frame = service.get_frame_annotator(frame_id) + if frame is None: + return standard_error_response(message="Frame does not exist") + + _, img_encoded = cv2.imencode('.jpg', frame) + img_bytes = img_encoded.tobytes() # 转换为字节格式 + + # 定义生成器分块读取内存中的图片数据 + def iterfile(): + chunk_size = 1024 * 64 # 64KB 每块 + file_like = io.BytesIO(img_bytes) # 使用 BytesIO 将字节数据视为文件对象 + while chunk := file_like.read(chunk_size): + yield chunk + + return StreamingResponse(iterfile(), media_type="image/jpeg") + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") + + +@router.get("/frame_test") +def get_frame_test(request: Request): + file_path = "test.jpg" + return FileResponse(file_path) + # def iterfile(): + # with open(file_path, 'rb') as file: + # while chunk := file.read(1024): + # yield chunk + # + # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/router.py b/apis/router.py index 22b5ed6..de08805 100644 --- a/apis/router.py +++ b/apis/router.py @@ -3,6 +3,7 @@ from .device import router as devices_router from .model import router as models_router from .device_model_realtion import router as device_model_relation_router +from .frame import router as frame_router # 创建一个全局的 router @@ -12,4 +13,4 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) - +router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) diff --git a/common/image_plotting.py b/common/image_plotting.py new file mode 100644 index 0000000..818ab95 --- /dev/null +++ b/common/image_plotting.py @@ -0,0 +1,115 @@ +import numpy as np +import cv2 + + +class Colors: + """ + Ultralytics default color palette https://ultralytics.com/. + + This class provides methods to work with the Ultralytics color palette, including converting hex color codes to + RGB values. + + Attributes: + palette (list of tuple): List of RGB color values. + n (int): The number of colors in the palette. + pose_palette (np.ndarray): A specific color palette array with dtype np.uint8. + """ + + def __init__(self): + """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" + hexs = ( + "FF3838", + "FF9D97", + "FF701F", + "FFB21D", + "CFD231", + "48F90A", + "92CC17", + "3DDB86", + "1A9334", + "00D4BB", + "2C99A8", + "00C2FF", + "344593", + "6473FF", + "0018EC", + "8438FF", + "520085", + "CB38FF", + "FF95C8", + "FF37C7", + ) + self.palette = [self.hex2rgb(f"#{c}") for c in hexs] + self.n = len(self.palette) + self.pose_palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ], + dtype=np.uint8, + ) + + def __call__(self, i, bgr=False): + """Converts hex color codes to RGB values.""" + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): + """Converts hex color codes to RGB values (i.e. default PIL order).""" + return tuple(int(h[1 + i: 1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() + + +class Annotator: + + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): + self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." + self.im = im if im.flags.writeable else im.copy() + self.tf = max(self.lw - 1, 1) # font thickness + self.sf = self.lw / 3 # font scale + + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): + """Add one xyxy box to image with label.""" + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if label: + w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + outside = p1[1] - h >= 3 + p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + self.im, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + self.sf, + txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) + + def result(self): + """Return annotated image as array.""" + return np.asarray(self.im) \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 765b724..9973e56 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/apis/base.py b/apis/base.py index a7c4ede..b442a71 100644 --- a/apis/base.py +++ b/apis/base.py @@ -24,3 +24,5 @@ def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): return StandardResponse(data=data, code=code, message=message, success=success) + + diff --git a/apis/frame.py b/apis/frame.py new file mode 100644 index 0000000..aa366c8 --- /dev/null +++ b/apis/frame.py @@ -0,0 +1,75 @@ +import io +import os +from io import BytesIO +from typing import List, Optional + +import cv2 +from fastapi import APIRouter, Depends, Query, HTTPException, Request +from sqlmodel import Session +from fastapi.responses import StreamingResponse, FileResponse + +from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.base import parse_datetime +from entity.device_frame import DeviceFrame +from services.device_frame_service import DeviceFrameService + +router = APIRouter() + + +@router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) +def get_frame_page( + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[str] = None, + frame_end_time: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceFrameService(db) + + # 获取分页后的设备列表和总数 + frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) + + return standard_response( + data=PageResponse(total=total, items=frames) + ) + + +# 路由:使用 OpenCV 生成内存图像并返回字节流 +@router.get("/frame_image/") +def get_frame_image(frame_id, db: Session = Depends(get_db)): + try: + service = DeviceFrameService(db) + frame = service.get_frame_annotator(frame_id) + if frame is None: + return standard_error_response(message="Frame does not exist") + + _, img_encoded = cv2.imencode('.jpg', frame) + img_bytes = img_encoded.tobytes() # 转换为字节格式 + + # 定义生成器分块读取内存中的图片数据 + def iterfile(): + chunk_size = 1024 * 64 # 64KB 每块 + file_like = io.BytesIO(img_bytes) # 使用 BytesIO 将字节数据视为文件对象 + while chunk := file_like.read(chunk_size): + yield chunk + + return StreamingResponse(iterfile(), media_type="image/jpeg") + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") + + +@router.get("/frame_test") +def get_frame_test(request: Request): + file_path = "test.jpg" + return FileResponse(file_path) + # def iterfile(): + # with open(file_path, 'rb') as file: + # while chunk := file.read(1024): + # yield chunk + # + # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/router.py b/apis/router.py index 22b5ed6..de08805 100644 --- a/apis/router.py +++ b/apis/router.py @@ -3,6 +3,7 @@ from .device import router as devices_router from .model import router as models_router from .device_model_realtion import router as device_model_relation_router +from .frame import router as frame_router # 创建一个全局的 router @@ -12,4 +13,4 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) - +router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) diff --git a/common/image_plotting.py b/common/image_plotting.py new file mode 100644 index 0000000..818ab95 --- /dev/null +++ b/common/image_plotting.py @@ -0,0 +1,115 @@ +import numpy as np +import cv2 + + +class Colors: + """ + Ultralytics default color palette https://ultralytics.com/. + + This class provides methods to work with the Ultralytics color palette, including converting hex color codes to + RGB values. + + Attributes: + palette (list of tuple): List of RGB color values. + n (int): The number of colors in the palette. + pose_palette (np.ndarray): A specific color palette array with dtype np.uint8. + """ + + def __init__(self): + """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" + hexs = ( + "FF3838", + "FF9D97", + "FF701F", + "FFB21D", + "CFD231", + "48F90A", + "92CC17", + "3DDB86", + "1A9334", + "00D4BB", + "2C99A8", + "00C2FF", + "344593", + "6473FF", + "0018EC", + "8438FF", + "520085", + "CB38FF", + "FF95C8", + "FF37C7", + ) + self.palette = [self.hex2rgb(f"#{c}") for c in hexs] + self.n = len(self.palette) + self.pose_palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ], + dtype=np.uint8, + ) + + def __call__(self, i, bgr=False): + """Converts hex color codes to RGB values.""" + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): + """Converts hex color codes to RGB values (i.e. default PIL order).""" + return tuple(int(h[1 + i: 1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() + + +class Annotator: + + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): + self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." + self.im = im if im.flags.writeable else im.copy() + self.tf = max(self.lw - 1, 1) # font thickness + self.sf = self.lw / 3 # font scale + + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): + """Add one xyxy box to image with label.""" + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if label: + w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + outside = p1[1] - h >= 3 + p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + self.im, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + self.sf, + txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) + + def result(self): + """Return annotated image as array.""" + return np.asarray(self.im) \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 765b724..9973e56 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/base.py b/entity/base.py index 6aa862c..7e40119 100644 --- a/entity/base.py +++ b/entity/base.py @@ -2,6 +2,8 @@ from typing import Optional from datetime import datetime +from common.biz_exception import BizException + class TimestampMixin(SQLModel): create_time: Optional[datetime] = Field(default_factory=datetime.now) @@ -11,3 +13,15 @@ json_encoders = { datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') } + + +# 自定义函数将字符串转换为 datetime +def parse_datetime(date_str: str): + try: + if date_str: + # 尝试解析前端传来的字符串格式时间 + return datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S") + else: + return None + except ValueError: + raise BizException(message=f"Incorrect datetime format, should be YYYY-MM-DD HH:MM:SS") diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/apis/base.py b/apis/base.py index a7c4ede..b442a71 100644 --- a/apis/base.py +++ b/apis/base.py @@ -24,3 +24,5 @@ def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): return StandardResponse(data=data, code=code, message=message, success=success) + + diff --git a/apis/frame.py b/apis/frame.py new file mode 100644 index 0000000..aa366c8 --- /dev/null +++ b/apis/frame.py @@ -0,0 +1,75 @@ +import io +import os +from io import BytesIO +from typing import List, Optional + +import cv2 +from fastapi import APIRouter, Depends, Query, HTTPException, Request +from sqlmodel import Session +from fastapi.responses import StreamingResponse, FileResponse + +from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.base import parse_datetime +from entity.device_frame import DeviceFrame +from services.device_frame_service import DeviceFrameService + +router = APIRouter() + + +@router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) +def get_frame_page( + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[str] = None, + frame_end_time: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceFrameService(db) + + # 获取分页后的设备列表和总数 + frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) + + return standard_response( + data=PageResponse(total=total, items=frames) + ) + + +# 路由:使用 OpenCV 生成内存图像并返回字节流 +@router.get("/frame_image/") +def get_frame_image(frame_id, db: Session = Depends(get_db)): + try: + service = DeviceFrameService(db) + frame = service.get_frame_annotator(frame_id) + if frame is None: + return standard_error_response(message="Frame does not exist") + + _, img_encoded = cv2.imencode('.jpg', frame) + img_bytes = img_encoded.tobytes() # 转换为字节格式 + + # 定义生成器分块读取内存中的图片数据 + def iterfile(): + chunk_size = 1024 * 64 # 64KB 每块 + file_like = io.BytesIO(img_bytes) # 使用 BytesIO 将字节数据视为文件对象 + while chunk := file_like.read(chunk_size): + yield chunk + + return StreamingResponse(iterfile(), media_type="image/jpeg") + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") + + +@router.get("/frame_test") +def get_frame_test(request: Request): + file_path = "test.jpg" + return FileResponse(file_path) + # def iterfile(): + # with open(file_path, 'rb') as file: + # while chunk := file.read(1024): + # yield chunk + # + # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/router.py b/apis/router.py index 22b5ed6..de08805 100644 --- a/apis/router.py +++ b/apis/router.py @@ -3,6 +3,7 @@ from .device import router as devices_router from .model import router as models_router from .device_model_realtion import router as device_model_relation_router +from .frame import router as frame_router # 创建一个全局的 router @@ -12,4 +13,4 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) - +router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) diff --git a/common/image_plotting.py b/common/image_plotting.py new file mode 100644 index 0000000..818ab95 --- /dev/null +++ b/common/image_plotting.py @@ -0,0 +1,115 @@ +import numpy as np +import cv2 + + +class Colors: + """ + Ultralytics default color palette https://ultralytics.com/. + + This class provides methods to work with the Ultralytics color palette, including converting hex color codes to + RGB values. + + Attributes: + palette (list of tuple): List of RGB color values. + n (int): The number of colors in the palette. + pose_palette (np.ndarray): A specific color palette array with dtype np.uint8. + """ + + def __init__(self): + """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" + hexs = ( + "FF3838", + "FF9D97", + "FF701F", + "FFB21D", + "CFD231", + "48F90A", + "92CC17", + "3DDB86", + "1A9334", + "00D4BB", + "2C99A8", + "00C2FF", + "344593", + "6473FF", + "0018EC", + "8438FF", + "520085", + "CB38FF", + "FF95C8", + "FF37C7", + ) + self.palette = [self.hex2rgb(f"#{c}") for c in hexs] + self.n = len(self.palette) + self.pose_palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ], + dtype=np.uint8, + ) + + def __call__(self, i, bgr=False): + """Converts hex color codes to RGB values.""" + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): + """Converts hex color codes to RGB values (i.e. default PIL order).""" + return tuple(int(h[1 + i: 1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() + + +class Annotator: + + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): + self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." + self.im = im if im.flags.writeable else im.copy() + self.tf = max(self.lw - 1, 1) # font thickness + self.sf = self.lw / 3 # font scale + + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): + """Add one xyxy box to image with label.""" + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if label: + w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + outside = p1[1] - h >= 3 + p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + self.im, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + self.sf, + txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) + + def result(self): + """Return annotated image as array.""" + return np.asarray(self.im) \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 765b724..9973e56 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/base.py b/entity/base.py index 6aa862c..7e40119 100644 --- a/entity/base.py +++ b/entity/base.py @@ -2,6 +2,8 @@ from typing import Optional from datetime import datetime +from common.biz_exception import BizException + class TimestampMixin(SQLModel): create_time: Optional[datetime] = Field(default_factory=datetime.now) @@ -11,3 +13,15 @@ json_encoders = { datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') } + + +# 自定义函数将字符串转换为 datetime +def parse_datetime(date_str: str): + try: + if date_str: + # 尝试解析前端传来的字符串格式时间 + return datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S") + else: + return None + except ValueError: + raise BizException(message=f"Incorrect datetime format, should be YYYY-MM-DD HH:MM:SS") diff --git a/main.py b/main.py index 1121638..61cbe9e 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -28,7 +29,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): # 应用启动时的初始化 - algo_runner.start() + await algo_runner.start() + # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start())) # 允许请求处理 yield # 应用关闭时的清理逻辑 diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/apis/base.py b/apis/base.py index a7c4ede..b442a71 100644 --- a/apis/base.py +++ b/apis/base.py @@ -24,3 +24,5 @@ def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): return StandardResponse(data=data, code=code, message=message, success=success) + + diff --git a/apis/frame.py b/apis/frame.py new file mode 100644 index 0000000..aa366c8 --- /dev/null +++ b/apis/frame.py @@ -0,0 +1,75 @@ +import io +import os +from io import BytesIO +from typing import List, Optional + +import cv2 +from fastapi import APIRouter, Depends, Query, HTTPException, Request +from sqlmodel import Session +from fastapi.responses import StreamingResponse, FileResponse + +from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.base import parse_datetime +from entity.device_frame import DeviceFrame +from services.device_frame_service import DeviceFrameService + +router = APIRouter() + + +@router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) +def get_frame_page( + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[str] = None, + frame_end_time: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceFrameService(db) + + # 获取分页后的设备列表和总数 + frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) + + return standard_response( + data=PageResponse(total=total, items=frames) + ) + + +# 路由:使用 OpenCV 生成内存图像并返回字节流 +@router.get("/frame_image/") +def get_frame_image(frame_id, db: Session = Depends(get_db)): + try: + service = DeviceFrameService(db) + frame = service.get_frame_annotator(frame_id) + if frame is None: + return standard_error_response(message="Frame does not exist") + + _, img_encoded = cv2.imencode('.jpg', frame) + img_bytes = img_encoded.tobytes() # 转换为字节格式 + + # 定义生成器分块读取内存中的图片数据 + def iterfile(): + chunk_size = 1024 * 64 # 64KB 每块 + file_like = io.BytesIO(img_bytes) # 使用 BytesIO 将字节数据视为文件对象 + while chunk := file_like.read(chunk_size): + yield chunk + + return StreamingResponse(iterfile(), media_type="image/jpeg") + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") + + +@router.get("/frame_test") +def get_frame_test(request: Request): + file_path = "test.jpg" + return FileResponse(file_path) + # def iterfile(): + # with open(file_path, 'rb') as file: + # while chunk := file.read(1024): + # yield chunk + # + # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/router.py b/apis/router.py index 22b5ed6..de08805 100644 --- a/apis/router.py +++ b/apis/router.py @@ -3,6 +3,7 @@ from .device import router as devices_router from .model import router as models_router from .device_model_realtion import router as device_model_relation_router +from .frame import router as frame_router # 创建一个全局的 router @@ -12,4 +13,4 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) - +router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) diff --git a/common/image_plotting.py b/common/image_plotting.py new file mode 100644 index 0000000..818ab95 --- /dev/null +++ b/common/image_plotting.py @@ -0,0 +1,115 @@ +import numpy as np +import cv2 + + +class Colors: + """ + Ultralytics default color palette https://ultralytics.com/. + + This class provides methods to work with the Ultralytics color palette, including converting hex color codes to + RGB values. + + Attributes: + palette (list of tuple): List of RGB color values. + n (int): The number of colors in the palette. + pose_palette (np.ndarray): A specific color palette array with dtype np.uint8. + """ + + def __init__(self): + """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" + hexs = ( + "FF3838", + "FF9D97", + "FF701F", + "FFB21D", + "CFD231", + "48F90A", + "92CC17", + "3DDB86", + "1A9334", + "00D4BB", + "2C99A8", + "00C2FF", + "344593", + "6473FF", + "0018EC", + "8438FF", + "520085", + "CB38FF", + "FF95C8", + "FF37C7", + ) + self.palette = [self.hex2rgb(f"#{c}") for c in hexs] + self.n = len(self.palette) + self.pose_palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ], + dtype=np.uint8, + ) + + def __call__(self, i, bgr=False): + """Converts hex color codes to RGB values.""" + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): + """Converts hex color codes to RGB values (i.e. default PIL order).""" + return tuple(int(h[1 + i: 1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() + + +class Annotator: + + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): + self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." + self.im = im if im.flags.writeable else im.copy() + self.tf = max(self.lw - 1, 1) # font thickness + self.sf = self.lw / 3 # font scale + + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): + """Add one xyxy box to image with label.""" + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if label: + w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + outside = p1[1] - h >= 3 + p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + self.im, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + self.sf, + txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) + + def result(self): + """Return annotated image as array.""" + return np.asarray(self.im) \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 765b724..9973e56 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/base.py b/entity/base.py index 6aa862c..7e40119 100644 --- a/entity/base.py +++ b/entity/base.py @@ -2,6 +2,8 @@ from typing import Optional from datetime import datetime +from common.biz_exception import BizException + class TimestampMixin(SQLModel): create_time: Optional[datetime] = Field(default_factory=datetime.now) @@ -11,3 +13,15 @@ json_encoders = { datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') } + + +# 自定义函数将字符串转换为 datetime +def parse_datetime(date_str: str): + try: + if date_str: + # 尝试解析前端传来的字符串格式时间 + return datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S") + else: + return None + except ValueError: + raise BizException(message=f"Incorrect datetime format, should be YYYY-MM-DD HH:MM:SS") diff --git a/main.py b/main.py index 1121638..61cbe9e 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -28,7 +29,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): # 应用启动时的初始化 - algo_runner.start() + await algo_runner.start() + # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start())) # 允许请求处理 yield # 应用关闭时的清理逻辑 diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 6b9c308..5dead7d 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -1,12 +1,22 @@ import os import uuid +from copy import deepcopy from datetime import datetime - +from typing import Sequence, Optional, Tuple from sqlmodel import Session + +from sqlalchemy import func + +from common.image_plotting import Annotator, colors +from entity.device import Device from entity.device_frame import DeviceFrame import cv2 +from sqlmodel import Session, select, delete + +from services.frame_analysis_result_service import FrameAnalysisResultService + class DeviceFrameService: @@ -34,3 +44,68 @@ self.db.commit() self.db.refresh(device_frame) return device_frame + + def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: + statement = ( + select(DeviceFrame, Device) + .join(Device, DeviceFrame.device_id == Device.id) + ) + + if device_name: + statement = statement.where(Device.name.like(f"%{device_name}%")) + if device_code: + statement = statement.where(Device.code.like(f"%{device_code}%")) + if frame_start_time: + statement = statement.where(DeviceFrame.time >= frame_start_time) + if frame_start_time: + statement = statement.where(DeviceFrame.time <= frame_end_time) + + # 查询总记录数 + total_statement = select(func.count()).select_from(statement.subquery()) + total = self.db.exec(total_statement).one() + + # 添加分页限制 + statement = statement.offset(offset).limit(limit) + + # 执行查询并返回结果 + results = self.db.exec(statement) + results = results.all() + frames = [frame for frame, device in results] + return frames, total # 返回分页数据和总数 + + def get_frame(self, frame_id: int): + return self.db.get(DeviceFrame, frame_id) + + def get_frame_annotator(self, frame_id: int): + device_frame = self.get_frame(frame_id) + if device_frame: + frame_image = cv2.imread(device_frame.frame_path) + frame_analysis_result_service = FrameAnalysisResultService(self.db) + results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + if results: + annotator = Annotator(deepcopy(frame_image)) + height, width = frame_image.shape[:2] + + for result in results: + # 将归一化的坐标恢复成实际的像素坐标 + xyxyn = [float(coord) for coord in result.location.split(",")] + x_min = int(xyxyn[0] * width) + y_min = int(xyxyn[1] * height) + x_max = int(xyxyn[2] * width) + y_max = int(xyxyn[3] * height) + + # 恢复后的实际坐标 + box = [x_min, y_min, x_max, y_max] + annotator.box_label(box, label=f'{result.object_class_name} {result.confidence:.2f}', + color=colors(result.object_class_id)) + return annotator.result() + else: + return frame_image + return None diff --git a/algo/algo_runner.py b/algo/algo_runner.py index a97aee9..d59d93a 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -1,3 +1,4 @@ +import concurrent.futures import copy import uuid from typing import Dict @@ -31,7 +32,7 @@ self.model_service.register_change_callback(self.on_model_change) self.relation_service.register_change_callback(self.on_relation_change) - def start(self): + async def start(self): logger.info("Starting AlgoRunner...") """在程序启动时调用,读取设备和模型,启动检测线程""" self.model_manager.load_models() @@ -102,12 +103,12 @@ self.device_tasks[device_id].stop_detection_task() try: # 设置超时时间等待任务停止(例如10秒) - result = future.result(timeout=10) + result = future.result(timeout=30) logger.info(f"Task {thread_id} stopped successfully.") - except TimeoutError: + except concurrent.futures.TimeoutError as te: logger.error(f"Task {thread_id} did not stop within the timeout.") except Exception as e: - logger.error(f"Task {thread_id} encountered an error while stopping: {e}") + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") finally: # 确保无论任务是否停止,都将其从任务列表中移除 del self.device_tasks[device_id] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index 5e1c75d..4ce760c 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -95,6 +95,10 @@ self.frame_analysis_result_service.add_frame_analysis_results(frame_results) def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() for frame in self.stream_loader: if self.__stop_event.is_set(): break # 如果触发了停止事件,则退出循环 diff --git a/algo/stream_loader.py b/algo/stream_loader.py index 9a1199e..d5cdef1 100644 --- a/algo/stream_loader.py +++ b/algo/stream_loader.py @@ -23,17 +23,20 @@ self.__stop_event = Event() # 增加 stop_event 作为停止线程的标志 self.thread_pool = GlobalThreadPool() - self.thread_pool.submit_task(self.update) + # self.thread = Thread(target=self.update, daemon=True) # self.thread.start() def init_cap(self): + if self.init: + return self.cap = self.get_connect() if self.cap: self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS)) _, self.frame = self.cap.read() + self.thread_pool.submit_task(self.update) self.init = True def create_capture(self): @@ -58,7 +61,6 @@ logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped") return None # 退出循环,返回 None 或者其他你认为合适的值 logger.info(f"{self.url} try to connect...") - print('cap') cap = self.create_capture() if cap is None or not cap.isOpened(): logger.info(f"{self.url} connect failed, retry after {self.retry_interval} second...") @@ -71,9 +73,8 @@ vid_n = 0 log_n = 0 while not self.__stop_event.is_set(): - print('update') - if not self.init: - self.init_cap() + # if not self.init: + # self.init_cap() if self.cap is None: continue vid_n += 1 diff --git a/apis/base.py b/apis/base.py index a7c4ede..b442a71 100644 --- a/apis/base.py +++ b/apis/base.py @@ -24,3 +24,5 @@ def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): return StandardResponse(data=data, code=code, message=message, success=success) + + diff --git a/apis/frame.py b/apis/frame.py new file mode 100644 index 0000000..aa366c8 --- /dev/null +++ b/apis/frame.py @@ -0,0 +1,75 @@ +import io +import os +from io import BytesIO +from typing import List, Optional + +import cv2 +from fastapi import APIRouter, Depends, Query, HTTPException, Request +from sqlmodel import Session +from fastapi.responses import StreamingResponse, FileResponse + +from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.base import parse_datetime +from entity.device_frame import DeviceFrame +from services.device_frame_service import DeviceFrameService + +router = APIRouter() + + +@router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) +def get_frame_page( + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[str] = None, + frame_end_time: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceFrameService(db) + + # 获取分页后的设备列表和总数 + frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) + + return standard_response( + data=PageResponse(total=total, items=frames) + ) + + +# 路由:使用 OpenCV 生成内存图像并返回字节流 +@router.get("/frame_image/") +def get_frame_image(frame_id, db: Session = Depends(get_db)): + try: + service = DeviceFrameService(db) + frame = service.get_frame_annotator(frame_id) + if frame is None: + return standard_error_response(message="Frame does not exist") + + _, img_encoded = cv2.imencode('.jpg', frame) + img_bytes = img_encoded.tobytes() # 转换为字节格式 + + # 定义生成器分块读取内存中的图片数据 + def iterfile(): + chunk_size = 1024 * 64 # 64KB 每块 + file_like = io.BytesIO(img_bytes) # 使用 BytesIO 将字节数据视为文件对象 + while chunk := file_like.read(chunk_size): + yield chunk + + return StreamingResponse(iterfile(), media_type="image/jpeg") + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") + + +@router.get("/frame_test") +def get_frame_test(request: Request): + file_path = "test.jpg" + return FileResponse(file_path) + # def iterfile(): + # with open(file_path, 'rb') as file: + # while chunk := file.read(1024): + # yield chunk + # + # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/router.py b/apis/router.py index 22b5ed6..de08805 100644 --- a/apis/router.py +++ b/apis/router.py @@ -3,6 +3,7 @@ from .device import router as devices_router from .model import router as models_router from .device_model_realtion import router as device_model_relation_router +from .frame import router as frame_router # 创建一个全局的 router @@ -12,4 +13,4 @@ router.include_router(devices_router, prefix="/device", tags=["Devices"]) router.include_router(models_router, prefix="/model", tags=["Models"]) router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) - +router.include_router(frame_router, prefix="/frame", tags=["DeviceFrame"]) diff --git a/common/image_plotting.py b/common/image_plotting.py new file mode 100644 index 0000000..818ab95 --- /dev/null +++ b/common/image_plotting.py @@ -0,0 +1,115 @@ +import numpy as np +import cv2 + + +class Colors: + """ + Ultralytics default color palette https://ultralytics.com/. + + This class provides methods to work with the Ultralytics color palette, including converting hex color codes to + RGB values. + + Attributes: + palette (list of tuple): List of RGB color values. + n (int): The number of colors in the palette. + pose_palette (np.ndarray): A specific color palette array with dtype np.uint8. + """ + + def __init__(self): + """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" + hexs = ( + "FF3838", + "FF9D97", + "FF701F", + "FFB21D", + "CFD231", + "48F90A", + "92CC17", + "3DDB86", + "1A9334", + "00D4BB", + "2C99A8", + "00C2FF", + "344593", + "6473FF", + "0018EC", + "8438FF", + "520085", + "CB38FF", + "FF95C8", + "FF37C7", + ) + self.palette = [self.hex2rgb(f"#{c}") for c in hexs] + self.n = len(self.palette) + self.pose_palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ], + dtype=np.uint8, + ) + + def __call__(self, i, bgr=False): + """Converts hex color codes to RGB values.""" + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): + """Converts hex color codes to RGB values (i.e. default PIL order).""" + return tuple(int(h[1 + i: 1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() + + +class Annotator: + + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): + self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." + self.im = im if im.flags.writeable else im.copy() + self.tf = max(self.lw - 1, 1) # font thickness + self.sf = self.lw / 3 # font scale + + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): + """Add one xyxy box to image with label.""" + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if label: + w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + outside = p1[1] - h >= 3 + p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + self.im, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + 0, + self.sf, + txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) + + def result(self): + """Return annotated image as array.""" + return np.asarray(self.im) \ No newline at end of file diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index 765b724..9973e56 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/base.py b/entity/base.py index 6aa862c..7e40119 100644 --- a/entity/base.py +++ b/entity/base.py @@ -2,6 +2,8 @@ from typing import Optional from datetime import datetime +from common.biz_exception import BizException + class TimestampMixin(SQLModel): create_time: Optional[datetime] = Field(default_factory=datetime.now) @@ -11,3 +13,15 @@ json_encoders = { datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') } + + +# 自定义函数将字符串转换为 datetime +def parse_datetime(date_str: str): + try: + if date_str: + # 尝试解析前端传来的字符串格式时间 + return datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S") + else: + return None + except ValueError: + raise BizException(message=f"Incorrect datetime format, should be YYYY-MM-DD HH:MM:SS") diff --git a/main.py b/main.py index 1121638..61cbe9e 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException @@ -28,7 +29,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): # 应用启动时的初始化 - algo_runner.start() + await algo_runner.start() + # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start())) # 允许请求处理 yield # 应用关闭时的清理逻辑 diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 6b9c308..5dead7d 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -1,12 +1,22 @@ import os import uuid +from copy import deepcopy from datetime import datetime - +from typing import Sequence, Optional, Tuple from sqlmodel import Session + +from sqlalchemy import func + +from common.image_plotting import Annotator, colors +from entity.device import Device from entity.device_frame import DeviceFrame import cv2 +from sqlmodel import Session, select, delete + +from services.frame_analysis_result_service import FrameAnalysisResultService + class DeviceFrameService: @@ -34,3 +44,68 @@ self.db.commit() self.db.refresh(device_frame) return device_frame + + def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: + statement = ( + select(DeviceFrame, Device) + .join(Device, DeviceFrame.device_id == Device.id) + ) + + if device_name: + statement = statement.where(Device.name.like(f"%{device_name}%")) + if device_code: + statement = statement.where(Device.code.like(f"%{device_code}%")) + if frame_start_time: + statement = statement.where(DeviceFrame.time >= frame_start_time) + if frame_start_time: + statement = statement.where(DeviceFrame.time <= frame_end_time) + + # 查询总记录数 + total_statement = select(func.count()).select_from(statement.subquery()) + total = self.db.exec(total_statement).one() + + # 添加分页限制 + statement = statement.offset(offset).limit(limit) + + # 执行查询并返回结果 + results = self.db.exec(statement) + results = results.all() + frames = [frame for frame, device in results] + return frames, total # 返回分页数据和总数 + + def get_frame(self, frame_id: int): + return self.db.get(DeviceFrame, frame_id) + + def get_frame_annotator(self, frame_id: int): + device_frame = self.get_frame(frame_id) + if device_frame: + frame_image = cv2.imread(device_frame.frame_path) + frame_analysis_result_service = FrameAnalysisResultService(self.db) + results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + if results: + annotator = Annotator(deepcopy(frame_image)) + height, width = frame_image.shape[:2] + + for result in results: + # 将归一化的坐标恢复成实际的像素坐标 + xyxyn = [float(coord) for coord in result.location.split(",")] + x_min = int(xyxyn[0] * width) + y_min = int(xyxyn[1] * height) + x_max = int(xyxyn[2] * width) + y_max = int(xyxyn[3] * height) + + # 恢复后的实际坐标 + box = [x_min, y_min, x_max, y_max] + annotator.box_label(box, label=f'{result.object_class_name} {result.confidence:.2f}', + color=colors(result.object_class_id)) + return annotator.result() + else: + return frame_image + return None diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index 3884fa3..dbd272d 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,6 +1,6 @@ from typing import List -from sqlmodel import Session +from sqlmodel import Session, select from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult @@ -16,3 +16,8 @@ for result in new_results: self.db.refresh(result) return new_results + + def get_results_by_frame(self, frame_id): + statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) + results = self.db.exec(statement) + return results.all()