Newer
Older
safe-algo-pro / services / scene_service.py
import os
import shutil
import uuid
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional, Sequence, Tuple, List

from fastapi import UploadFile
from sqlalchemy import func
from sqlmodel import Session, select

from common.biz_exception import BizException
from common.consts import NotifyChangeType
from common.global_thread_pool import GlobalThreadPool
from common.string_utils import snake_to_camel
from entity.device_scene_relation import DeviceSceneRelation
from entity.scene import Scene, SceneInfo, SceneCreate, SceneUpdate


class SceneService:
    """CRUD and package-handling service for ``Scene`` records.

    Wraps a SQLModel ``Session``. Other components may register callbacks that
    ``notify_change`` submits to the shared ``GlobalThreadPool`` executor.
    """

    # Extensions (compared lower-cased) of model files extracted from an
    # uploaded scene package.
    SUPPORTED_MODEL_EXTENSIONS = frozenset({".pt", ".onnx", ".engine"})

    def __init__(self, db: Session):
        self.db = db
        # Callbacks invoked as callback(scene_id, change_type) by notify_change.
        self.__scene_change_callbacks = []
        self.thread_pool = GlobalThreadPool()

    def register_change_callback(self, callback):
        """Register a callback to be notified when a scene changes.

        :param callback: callable invoked as ``callback(scene_id, change_type)``.
        """
        self.__scene_change_callbacks.append(callback)

    def notify_change(self, scene_id, change_type):
        """Submit every registered callback to the thread pool executor.

        This does not wait for the callbacks. The objects returned by
        ``submit`` are discarded, so callback failures are not surfaced here
        (NOTE(review): assumes ``executor`` is a ``concurrent.futures``
        executor — confirm in GlobalThreadPool).
        """
        for callback in self.__scene_change_callbacks:
            self.thread_pool.executor.submit(callback, scene_id, change_type)

    def get_scene_list(self,
                       name: Optional[str] = None,
                       remark: Optional[str] = None,
                       ) -> Sequence[Scene]:
        """Return all scenes matching the optional name/remark substring filters."""
        statement = self.scene_query(name, remark)
        return self.db.exec(statement).all()

    def get_scene_page(self,
                       name: Optional[str] = None,
                       remark: Optional[str] = None,
                       offset: int = 0,
                       limit: int = 10
                       ) -> Tuple[Sequence[SceneInfo], int]:
        """Return one page of scenes together with the total number of matches.

        Each ``SceneInfo`` carries a ``usage_status``: the in-use label when at
        least one device_scene_relation row references the scene, otherwise
        the unused label.

        :param name: optional substring filter on ``Scene.name``.
        :param remark: optional substring filter on ``Scene.remark``.
        :param offset: number of matching rows to skip.
        :param limit: maximum number of rows to return.
        :return: ``(page, total)`` where ``total`` ignores pagination.
        """
        statement = self.scene_query(name, remark)

        # Total number of matching rows, independent of offset/limit.
        total_statement = select(func.count()).select_from(statement.subquery())
        total = self.db.exec(total_statement).one()

        # OFFSET/LIMIT paging is only deterministic with a stable ORDER BY;
        # without it rows may repeat or disappear between pages.
        page_statement = statement.order_by(Scene.id).offset(offset).limit(limit)
        scenes = self.db.exec(page_statement).all()

        # Resolve usage for the whole page with one query instead of one per row.
        used_scene_ids = self._get_used_scene_ids([scene.id for scene in scenes])

        scene_info_list: List[SceneInfo] = [
            SceneInfo(
                **scene.model_dump(),
                usage_status="使用中" if scene.id in used_scene_ids else "未使用",
            )
            for scene in scenes
        ]
        return scene_info_list, total

    def _get_used_scene_ids(self, scene_ids) -> set:
        """Return the subset of ``scene_ids`` referenced by any device_scene_relation row."""
        if not scene_ids:
            return set()
        statement = (
            select(DeviceSceneRelation.scene_id)
            .where(DeviceSceneRelation.scene_id.in_(scene_ids))
            .distinct()
        )
        return set(self.db.exec(statement).all())

    def scene_query(self, name, remark):
        """Build a ``select(Scene)`` filtered by LIKE-substring on name/remark when given."""
        statement = select(Scene)
        if name:
            statement = statement.where(Scene.name.like(f"%{name}%"))
        if remark:
            statement = statement.where(Scene.remark.like(f"%{remark}%"))
        return statement

    @staticmethod
    def _is_ignored_zip_entry(name: str) -> bool:
        """Return True for zip entries that are not real payload files.

        This covers directory entries (names ending in ``/``) and macOS
        Finder metadata stored under ``__MACOSX/``.
        """
        return name.endswith("/") or name.startswith("__MACOSX/")

    def process_zip(self, file: UploadFile):
        """Extract model files and the scene handler script from an uploaded zip.

        Entries whose extension is in ``SUPPORTED_MODEL_EXTENSIONS``
        (case-insensitive) are extracted under ``weights/``. The first ``.py``
        entry is extracted under ``scene_handler/``. The upload is staged in a
        temporary file in the working directory, and that file is always removed.

        :param file: uploaded zip archive.
        :return: ``(model_file_paths, handle_file_path)``. These are the on-disk
            paths as strings, as returned by ``ZipFile.extract``.
        :raises BizException: (400) if the upload is not a valid zip, or if it
            has no ``.py`` handler. In the latter case nothing is extracted.
        """
        model_dir = Path('weights/')
        scene_handle_dir = Path('scene_handler/')
        model_dir.mkdir(parents=True, exist_ok=True)
        scene_handle_dir.mkdir(parents=True, exist_ok=True)

        temp_path = Path(f"temp_upload_{uuid.uuid4()}.zip")
        try:
            # Stream the upload to disk rather than buffering it all in memory;
            # doing it inside the try guarantees cleanup if the write fails.
            with open(temp_path, "wb") as temp_file:
                shutil.copyfileobj(file.file, temp_file)

            with zipfile.ZipFile(temp_path, 'r') as zip_ref:
                members = [
                    name for name in zip_ref.namelist()
                    if not self._is_ignored_zip_entry(name)
                ]

                # Validate before extracting anything so a rejected upload
                # does not leave orphan model files behind.
                handle_file = next((m for m in members if m.endswith(".py")), None)
                if handle_file is None:
                    raise BizException(
                        status_code=400,
                        message="handle file (.py) is required in the zip."
                    )

                # extract() sanitizes absolute paths and ".." components and
                # returns the real destination, so report that path.
                model_file_paths = [
                    zip_ref.extract(member, model_dir)
                    for member in members
                    if Path(member).suffix.lower() in self.SUPPORTED_MODEL_EXTENSIONS
                ]
                handle_file_path = zip_ref.extract(handle_file, scene_handle_dir)
        except zipfile.BadZipFile as e:
            raise BizException(status_code=400, message="Invalid zip file.") from e
        finally:
            temp_path.unlink(missing_ok=True)

        return [str(path) for path in model_file_paths], str(handle_file_path)

    def _extract_handle_task(self, file) -> str:
        """Extract the uploaded package and derive its ``handle_task`` name.

        The name is ``snake_to_camel`` applied to the handler script's file stem.
        """
        _, handle_file_path = self.process_zip(file)
        return snake_to_camel(Path(handle_file_path).stem)

    def process_scene_file(self, file, scene):
        """Extract ``file`` and set ``scene.handle_task`` from its handler script name."""
        scene.handle_task = self._extract_handle_task(file)

    def create_scene(self, scene_data: SceneCreate, file: UploadFile):
        """Create and persist a scene from ``scene_data`` and its package ``file``.

        Note: ``handle_task`` is written onto ``scene_data`` itself before it
        is validated into a ``Scene``.

        :return: the refreshed, persisted ``Scene``.
        :raises BizException: if ``file`` is not a valid scene package.
        """
        self.process_scene_file(file, scene_data)
        scene = Scene.model_validate(scene_data)
        # A single timestamp keeps create_time and update_time identical on insert.
        now = datetime.now()
        scene.create_time = now
        scene.update_time = now

        self.db.add(scene)
        self.db.commit()
        self.db.refresh(scene)

        return scene

    def update_scene(self, scene_data: SceneUpdate, file: UploadFile):
        """Apply the explicitly-set fields of ``scene_data`` to an existing scene.

        If ``file`` is given, the handle_task derived from its handler script
        overrides any ``handle_task`` in ``scene_data``. After commit,
        registered callbacks are notified with ``NotifyChangeType.SCENE_UPDATE``.

        :return: the refreshed scene, or ``None`` if ``scene_data.id`` does not exist.
        :raises BizException: if ``file`` is not a valid scene package.
        """
        scene = self.db.get(Scene, scene_data.id)
        if not scene:
            return None

        # Process the upload before mutating the ORM object so an invalid
        # package does not leave a half-updated scene in the session.
        handle_task = self._extract_handle_task(file) if file else None

        update_data = scene_data.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(scene, key, value)
        if handle_task is not None:
            scene.handle_task = handle_task
        scene.update_time = datetime.now()

        self.db.add(scene)
        self.db.commit()
        self.db.refresh(scene)
        self.notify_change(scene.id, NotifyChangeType.SCENE_UPDATE)

        return scene

    def delete_scene(self, scene_id: int):
        """Delete a scene unless any device is still bound to it.

        :return: the deleted scene, or ``None`` if it does not exist.
        :raises BizException: if any device_scene_relation row references the scene.
        """
        scene = self.db.get(Scene, scene_id)
        if not scene:
            return None
        # Any binding row blocks deletion; rows are not filtered by any state column.
        if self.get_scene_usage(scene_id):
            raise BizException(message=f"场景 {scene.name} 正在被设备使用,无法删除")

        self.db.delete(scene)
        self.db.commit()
        return scene

    def get_scenes_in_use(self) -> Sequence[Scene]:
        """Return every scene referenced by at least one device_scene_relation row."""
        statement = (
            select(Scene)
            .join(DeviceSceneRelation, DeviceSceneRelation.scene_id == Scene.id)
            .group_by(Scene.id)
        )
        return self.db.exec(statement).all()

    def get_scene_usage(self, scene_id) -> bool:
        """Return True if any device_scene_relation row references ``scene_id``."""
        # Only existence matters, so fetch at most one row.
        statement = (
            select(DeviceSceneRelation)
            .where(DeviceSceneRelation.scene_id == scene_id)
            .limit(1)
        )
        return self.db.exec(statement).first() is not None

    def get_scene_by_id(self, scene_id):
        """Return the scene with primary key ``scene_id``, or ``None`` if absent."""
        return self.db.get(Scene, scene_id)