# safe-algo-pro / services / device_frame_service.py
import os
import uuid
from copy import deepcopy
from datetime import datetime
from typing import Sequence, Optional, Tuple

import aiofiles
import numpy as np
from sqlalchemy.ext.asyncio import AsyncSession

from sqlalchemy import func

from common.image_plotting import Annotator, colors
from entity.device import Device
from entity.device_frame import DeviceFrame

import cv2

from sqlmodel import select, delete

from services.frame_analysis_result_service import FrameAnalysisResultService


class DeviceFrameService:
    """Store and query captured device frames.

    Frame images are written as JPEG files under
    ``FRAME_STORAGE_ROOT/<YYYY-MM-DD>/<uuid>.jpeg``. A ``DeviceFrame`` row that
    references the file path is persisted through the injected async session.
    """

    # Root directory for stored frame images; one subdirectory per calendar day.
    FRAME_STORAGE_ROOT = './storage/frames'

    def __init__(self, db: AsyncSession):
        self.db = db

    async def _save_frame_file(self, frame_data) -> str:
        """JPEG-encode ``frame_data``, write it to a new unique file and return its path.

        Raises:
            ValueError: if OpenCV reports that encoding failed.
        """
        # frame_data is handed straight to cv2.imencode; presumably an OpenCV
        # (numpy) image — TODO confirm with callers.
        success, encoded_image = cv2.imencode('.jpeg', frame_data)
        if not success:
            raise ValueError("Failed to encode frame as JPEG")

        # Group files by the current date and use a random UUID as the file name.
        current_date = datetime.now().strftime('%Y-%m-%d')
        file_name = f"{uuid.uuid4()}.jpeg"
        save_path = os.path.join(self.FRAME_STORAGE_ROOT, current_date, file_name)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        async with aiofiles.open(save_path, 'wb') as f:
            await f.write(encoded_image.tobytes())
        return save_path

    async def add_frame(self, device_id, frame_data) -> DeviceFrame:
        """Save ``frame_data`` to disk and record it as a ``DeviceFrame`` of ``device_id``.

        Args:
            device_id: id of the owning device, stored on the new row.
            frame_data: the image to store (encoded with ``cv2.imencode``).

        Returns:
            The committed and refreshed ``DeviceFrame``.

        Raises:
            ValueError: if the frame cannot be JPEG-encoded.
            Any database error raised by the commit is re-raised after rollback.
        """
        file_path = await self._save_frame_file(frame_data)

        device_frame = DeviceFrame(device_id=device_id, frame_path=file_path)
        self.db.add(device_frame)
        try:
            await self.db.commit()
        except Exception:
            # Leave the session usable for the caller after a failed commit.
            await self.db.rollback()
            # Best-effort cleanup so a failed insert does not leave an orphaned image;
            # the original database error is what the caller needs to see.
            try:
                os.remove(file_path)
            except OSError:
                pass
            raise
        await self.db.refresh(device_frame)
        return device_frame

    async def get_frame_page(self,
                             device_name: Optional[str] = None,
                             device_code: Optional[str] = None,
                             frame_start_time: Optional[datetime] = None,
                             frame_end_time: Optional[datetime] = None,
                             offset: int = 0,
                             limit: int = 10
                             ) -> Tuple[Sequence[DeviceFrame], int]:
        """Return one page of frames matching the given filters, plus the total match count.

        Args:
            device_name: substring filter on the owning device's name (SQL LIKE).
            device_code: substring filter on the owning device's code (SQL LIKE).
            frame_start_time: only frames with ``time >= frame_start_time``.
            frame_end_time: only frames with ``time <= frame_end_time``.
            offset: number of matching rows to skip.
            limit: maximum number of rows to return.

        Returns:
            ``(frames, total)`` where ``total`` counts all matches, ignoring offset/limit.
        """
        # Device is joined only so that its name/code can be filtered on.
        statement = (
            select(DeviceFrame)
            .join(Device, DeviceFrame.device_id == Device.id)
        )

        if device_name:
            statement = statement.where(Device.name.like(f"%{device_name}%"))
        if device_code:
            statement = statement.where(Device.code.like(f"%{device_code}%"))
        if frame_start_time is not None:
            statement = statement.where(DeviceFrame.time >= frame_start_time)
        if frame_end_time is not None:
            statement = statement.where(DeviceFrame.time <= frame_end_time)

        # Total number of matching rows, computed before pagination is applied.
        total_statement = select(func.count()).select_from(statement.subquery())
        total_result = await self.db.execute(total_statement)
        total = total_result.scalar_one()

        # A deterministic ORDER BY is required for offset/limit pages to be stable
        # (SQL gives no row order otherwise).
        page_statement = statement.order_by(DeviceFrame.id).offset(offset).limit(limit)
        results = await self.db.execute(page_statement)
        frames = list(results.scalars().all())
        return frames, total

    async def get_frame(self, frame_id: int) -> Optional[DeviceFrame]:
        """Return the ``DeviceFrame`` with id ``frame_id``, or ``None`` if there is none."""
        result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id))
        return result.scalar_one_or_none()

    async def get_frame_annotator(self, frame_id: int) -> Optional[np.ndarray]:
        """Load a stored frame and draw its analysis results onto it.

        Returns:
            ``None`` if the frame does not exist. Otherwise, the decoded image with
            one labelled box per analysis result, or the plain image when the frame
            has no results.

        Raises:
            ValueError: if the stored file cannot be decoded as an image.
        """
        device_frame = await self.get_frame(frame_id)
        if device_frame is None:
            return None

        # Read the image file asynchronously, then decode the bytes with OpenCV.
        async with aiofiles.open(device_frame.frame_path, mode='rb') as f:
            file_content = await f.read()
        np_array = np.frombuffer(file_content, dtype=np.uint8)
        frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
        if frame_image is None:
            # imdecode signals failure by returning None rather than raising.
            raise ValueError(f"Could not decode frame image: {device_frame.frame_path}")

        frame_analysis_result_service = FrameAnalysisResultService(self.db)
        results = await frame_analysis_result_service.get_results_by_frame(device_frame.id)
        if not results:
            return frame_image

        height, width = frame_image.shape[:2]
        annotator = Annotator(deepcopy(frame_image))
        for result in results:
            # location holds comma-separated normalized x_min, y_min, x_max, y_max;
            # scale them back to pixel coordinates.
            xyxyn = [float(coord) for coord in result.location.split(",")]
            box = [
                int(xyxyn[0] * width),
                int(xyxyn[1] * height),
                int(xyxyn[2] * width),
                int(xyxyn[3] * height),
            ]
            annotator.box_label(box, label=f'{result.object_class_name} {result.confidence:.2f}',
                                color=colors(result.object_class_id))
        return annotator.result()