Newer
Older
safe-algo-pro / services / device_frame_service.py
zhangyingjie on 18 Oct 4 KB 增加识别结果查询接口
import os
import uuid
from copy import deepcopy
from datetime import datetime
from typing import Sequence, Optional, Tuple
from sqlmodel import Session

from sqlalchemy import func

from common.image_plotting import Annotator, colors
from entity.device import Device
from entity.device_frame import DeviceFrame

import cv2

from sqlmodel import Session, select, delete

from services.frame_analysis_result_service import FrameAnalysisResultService


class DeviceFrameService:
    """Persistence and query service for captured device frames.

    Frame images are written as JPEG files under ``./storage/frames`` and
    indexed by ``DeviceFrame`` rows in the database.
    """

    def __init__(self, db: Session):
        # Database session used for every read and write in this service.
        self.db = db

    def add_frame(self, device_id, frame_data) -> DeviceFrame:
        """Save a frame image to disk and record it in the database.

        The image is written to ``./storage/frames/<YYYY-MM-DD>/<uuid4>.jpeg``.

        Args:
            device_id: ID of the device that produced the frame.
            frame_data: Image passed to ``cv2.imwrite`` (presumably a numpy
                image array - TODO confirm against callers).

        Returns:
            The committed and refreshed ``DeviceFrame`` row.

        Raises:
            OSError: If OpenCV fails to write the image file.
        """

        def save_frame_file():
            # Group files into one directory per day.
            current_date = datetime.now().strftime('%Y-%m-%d')
            # A random UUID file name makes name collisions practically impossible.
            file_name = str(uuid.uuid4()) + ".jpeg"
            save_path = os.path.join('./storage/frames', current_date, file_name)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            # cv2.imwrite signals failure through its return value instead of
            # raising; without this check the DB row would reference a file
            # that does not exist.
            if not cv2.imwrite(save_path, frame_data):
                raise OSError(f"Failed to write frame image to {save_path}")
            return save_path

        file_path = save_frame_file()
        device_frame = DeviceFrame(device_id=device_id, frame_path=file_path)
        try:
            self.db.add(device_frame)
            self.db.commit()
        except Exception:
            # The row was not saved, so remove the image written above rather
            # than leaving an orphaned file; then propagate the original error.
            try:
                os.remove(file_path)
            except OSError:
                pass  # Best-effort cleanup; the DB error is the one that matters.
            raise
        self.db.refresh(device_frame)
        return device_frame

    def get_frame_page(self,
                       device_name: Optional[str] = None,
                       device_code: Optional[str] = None,
                       frame_start_time: Optional[datetime] = None,
                       frame_end_time: Optional[datetime] = None,
                       offset: int = 0,
                       limit: int = 10
                       ) -> Tuple[Sequence[DeviceFrame], int]:
        """Return one page of frames matching the filters, plus the total match count.

        Args:
            device_name: Substring matched against ``Device.name`` with SQL ``LIKE``.
            device_code: Substring matched against ``Device.code`` with SQL ``LIKE``.
            frame_start_time: Inclusive lower bound on ``DeviceFrame.time``.
            frame_end_time: Inclusive upper bound on ``DeviceFrame.time``.
            offset: Number of matching rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            ``(frames, total)``: ``frames`` is the requested page, and ``total``
            is the number of rows matching the filters, ignoring ``offset`` and ``limit``.
        """
        statement = (
            select(DeviceFrame, Device)
            .join(Device, DeviceFrame.device_id == Device.id)
        )

        if device_name:
            statement = statement.where(Device.name.like(f"%{device_name}%"))
        if device_code:
            statement = statement.where(Device.code.like(f"%{device_code}%"))
        # Each time bound is applied independently, so either end of the
        # range may be omitted.
        if frame_start_time is not None:
            statement = statement.where(DeviceFrame.time >= frame_start_time)
        if frame_end_time is not None:
            statement = statement.where(DeviceFrame.time <= frame_end_time)

        # Count all matching rows before pagination is applied.
        total_statement = select(func.count()).select_from(statement.subquery())
        total = self.db.exec(total_statement).one()

        # OFFSET/LIMIT without ORDER BY gives nondeterministic pages; order by
        # primary key so consecutive pages neither overlap nor skip rows.
        statement = statement.order_by(DeviceFrame.id).offset(offset).limit(limit)

        rows = self.db.exec(statement).all()
        frames = [frame for frame, _ in rows]
        return frames, total

    def get_frame(self, frame_id: int):
        """Return the ``DeviceFrame`` with primary key ``frame_id``, or ``None``."""
        return self.db.get(DeviceFrame, frame_id)

    def get_frame_annotator(self, frame_id: int):
        """Return the frame image with its analysis results drawn on it.

        Each result's ``location`` is a comma-separated list of normalized
        ``x_min,y_min,x_max,y_max`` values. These are scaled to pixel
        coordinates and drawn as a labelled box.

        Returns:
            The annotated image if the frame has analysis results, the raw
            image if it has none, or ``None`` if the frame record does not
            exist or its image file cannot be read.
        """
        device_frame = self.get_frame(frame_id)
        if device_frame is None:
            return None

        frame_image = cv2.imread(device_frame.frame_path)
        if frame_image is None:
            # cv2.imread returns None (rather than raising) for a missing or
            # unreadable file; without this check frame_image.shape would fail.
            return None

        frame_analysis_result_service = FrameAnalysisResultService(self.db)
        results = frame_analysis_result_service.get_results_by_frame(device_frame.id)
        if not results:
            return frame_image

        # Draw on a copy so frame_image itself is not modified
        # (presumably Annotator draws in place - TODO confirm).
        annotator = Annotator(deepcopy(frame_image))
        height, width = frame_image.shape[:2]

        for result in results:
            # Scale the normalized coordinates back to pixel coordinates.
            xyxyn = [float(coord) for coord in result.location.split(",")]
            box = [
                int(xyxyn[0] * width),
                int(xyxyn[1] * height),
                int(xyxyn[2] * width),
                int(xyxyn[3] * height),
            ]
            annotator.box_label(box, label=f'{result.object_class_name} {result.confidence:.2f}',
                                color=colors(result.object_class_id))
        return annotator.result()