# safe-algo-pro / services / model_service.py
import asyncio
import io
import os
import uuid
import zipfile
from datetime import datetime
from pathlib import Path
from typing import List, Sequence, Optional, Tuple, Type

import aiofiles
from fastapi import UploadFile
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select, delete

from common.biz_exception import BizException
from common.consts import NotifyChangeType
from common.global_thread_pool import GlobalThreadPool
from common.string_utils import snake_to_camel
from entity.device_model_relation import DeviceModelRelation
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo


class ModelService:
    """Service for algorithm models (``AlgoModel``).

    Wraps an ``AsyncSession`` for CRUD queries, unpacks uploaded model archives
    onto local disk, and submits change notifications for registered callbacks
    to the global thread pool.
    """

    # Weight-file extensions accepted inside an uploaded archive. A tuple (not a
    # set) keeps the order stable in the user-facing error message.
    _SUPPORTED_MODEL_EXTENSIONS = (".pt", ".onnx", ".engine")
    # Extraction targets, relative to the process working directory.
    _MODEL_DIR = Path('weights/')
    _MODEL_HANDLER_DIR = Path('model_handler/')

    def __init__(self, db: AsyncSession):
        self.db = db
        self.__model_change_callbacks = []  # callbacks invoked on model changes
        self.thread_pool = GlobalThreadPool()

    def register_change_callback(self, callback):
        """Register ``callback(algo_model_id, change_type)`` to run on model changes."""
        self.__model_change_callbacks.append(callback)

    def notify_change(self, algo_model_id, change_type):
        """Submit every registered callback to the global thread pool.

        Callbacks are only submitted: this method does not wait for them and
        discards the returned futures, so their results/exceptions are not observed.
        """
        for callback in self.__model_change_callbacks:
            self.thread_pool.executor.submit(callback, algo_model_id, change_type)

    async def get_model_list(self,
                             name: Optional[str] = None,
                             remark: Optional[str] = None,
                             ) -> Sequence[AlgoModel]:
        """Return all models whose name/remark contain the given substrings."""
        statement = self.model_query(name, remark)
        results = await self.db.execute(statement)
        return results.scalars().all()

    async def get_model_page(self,
                             name: Optional[str] = None,
                             remark: Optional[str] = None,
                             offset: int = 0,
                             limit: int = 10
                             ) -> Tuple[Sequence[AlgoModelInfo], int]:
        """Return one page of models (with usage status) plus the total match count.

        :param name: optional substring filter on ``AlgoModel.name``.
        :param remark: optional substring filter on ``AlgoModel.remark``.
        :param offset: number of matching rows to skip.
        :param limit: maximum number of rows to return.
        :return: ``(page, total)``; each page item's ``usage_status`` is
                 "使用中" (in use) or "未使用" (unused).
        """
        statement = self.model_query(name, remark)

        # Count over the filtered query, before pagination is applied.
        total_statement = select(func.count()).select_from(statement.subquery())
        total = (await self.db.execute(total_statement)).scalar_one()

        # LIMIT/OFFSET pages are only stable with a deterministic ORDER BY.
        page_statement = statement.order_by(AlgoModel.id).offset(offset).limit(limit)
        rows = (await self.db.execute(page_statement)).scalars().all()

        # One usage query for the whole page instead of one query per model.
        in_use_ids = await self._get_in_use_model_ids([model.id for model in rows])
        model_info_list: List[AlgoModelInfo] = [
            AlgoModelInfo(
                **model.dict(),
                usage_status="使用中" if model.id in in_use_ids else "未使用",
            )
            for model in rows
        ]
        return model_info_list, total

    def model_query(self, name, remark):
        """Build a ``SELECT`` over ``AlgoModel`` filtered by optional name/remark substrings."""
        statement = select(AlgoModel)
        if name:
            statement = statement.where(AlgoModel.name.like(f"%{name}%"))
        if remark:
            statement = statement.where(AlgoModel.remark.like(f"%{remark}%"))
        return statement

    @staticmethod
    def _regular_zip_members(zip_ref: zipfile.ZipFile) -> List[str]:
        """Return archive member names that are regular files, in archive order.

        Directory entries and macOS ``__MACOSX/`` metadata are skipped so they are
        never mistaken for a weight file or handler script.
        """
        return [
            info.filename
            for info in zip_ref.infolist()
            if not info.is_dir() and not info.filename.startswith("__MACOSX/")
        ]

    async def process_zip(self, file: UploadFile):
        """Extract model weights (and an optional handler script) from an uploaded zip.

        The first member whose extension (case-insensitive) is one of
        ``_SUPPORTED_MODEL_EXTENSIONS`` is extracted under ``weights/``; the first
        ``.py`` member, if any, is extracted under ``model_handler/``.

        :return: ``(model_file_path, handle_file_path)`` as strings; the second
                 item is ``None`` when the archive has no Python script.
        :raises BizException: (400) if the upload is not a valid zip or contains
                 no supported weight file.
        """
        self._MODEL_DIR.mkdir(parents=True, exist_ok=True)
        self._MODEL_HANDLER_DIR.mkdir(parents=True, exist_ok=True)

        # The upload is read fully into memory anyway, so parse it from an
        # in-memory buffer instead of round-tripping through a temp file on disk.
        content = await file.read()
        loop = asyncio.get_running_loop()
        handle_file_path = None

        try:
            with zipfile.ZipFile(io.BytesIO(content), 'r') as zip_ref:
                members = self._regular_zip_members(zip_ref)

                model_file = next(
                    (m for m in members
                     if Path(m).suffix.lower() in self._SUPPORTED_MODEL_EXTENSIONS),
                    None,
                )
                if model_file is None:
                    raise BizException(
                        status_code=400,
                        message=f"Model weight file ({', '.join(self._SUPPORTED_MODEL_EXTENSIONS)}) is required in the zip."
                    )

                # extract() sanitises the member name and returns the path it actually
                # wrote; extraction runs in the default executor to avoid blocking the loop.
                model_file_path = await loop.run_in_executor(
                    None, zip_ref.extract, model_file, self._MODEL_DIR)

                handle_file = next((m for m in members if m.lower().endswith(".py")), None)
                if handle_file is not None:
                    handle_file_path = await loop.run_in_executor(
                        None, zip_ref.extract, handle_file, self._MODEL_HANDLER_DIR)
        except zipfile.BadZipFile as err:
            raise BizException(status_code=400, message="Invalid zip file.") from err

        return str(model_file_path), handle_file_path

    async def create_model(self, model_data: AlgoModelCreate, file: UploadFile):
        """Unpack the uploaded archive, then insert and return a new ``AlgoModel``."""
        await self.process_model_file(file, model_data)
        model = AlgoModel.model_validate(model_data)
        # One timestamp so create_time and update_time match exactly on insert.
        now = datetime.now()
        model.create_time = now
        model.update_time = now

        self.db.add(model)
        await self.db.commit()
        await self.db.refresh(model)

        return model

    @staticmethod
    def _apply_model_files(model, model_file_path, handle_file_path):
        """Set ``path`` and ``handle_task`` on *model* from extracted file paths.

        The handler task name is the CamelCase form of the script's file stem,
        or ``'BaseModelHandler'`` when no script was supplied.
        """
        model.path = model_file_path
        if handle_file_path:
            model.handle_task = snake_to_camel(Path(handle_file_path).stem)
        else:
            model.handle_task = 'BaseModelHandler'

    async def process_model_file(self, file, model):
        """Extract *file* (a zip upload) and record the resulting paths on *model*."""
        model_file_path, handle_file_path = await self.process_zip(file)
        self._apply_model_files(model, model_file_path, handle_file_path)

    async def update_model(self, model_data: AlgoModelUpdate, file: UploadFile):
        """Apply the set fields of *model_data* (and an optional new archive).

        :return: the refreshed model, or ``None`` if no model has ``model_data.id``.
        """
        model = await self.get_model_by_id(model_data.id)
        if not model:
            return None

        # Extract first: a bad archive then raises before the ORM object is
        # mutated, so no half-applied changes are left pending in the session.
        extracted = await self.process_zip(file) if file else None

        for key, value in model_data.dict(exclude_unset=True).items():
            setattr(model, key, value)
        if extracted is not None:
            # Paths from a new archive take precedence over submitted field values.
            self._apply_model_files(model, *extracted)

        model.update_time = datetime.now()
        self.db.add(model)
        await self.db.commit()
        await self.db.refresh(model)
        self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE)

        return model

    async def delete_model(self, model_id: int):
        """Delete a model unless a device relation with ``is_use == 1`` references it.

        :return: the deleted model, or ``None`` if it does not exist.
        :raises BizException: if the model is currently in use.
        """
        model = await self.get_model_by_id(model_id)
        if not model:
            return None

        if await self.get_model_usage(model_id):
            raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除")

        await self.db.execute(delete(AlgoModel).where(AlgoModel.id == model_id))
        await self.db.commit()

        return model

    async def get_models_in_use(self) -> Sequence[AlgoModel]:
        """Return every model referenced by at least one relation with ``is_use == 1``."""
        # IN semi-join yields each model once without GROUP BY over non-aggregated columns.
        in_use_ids = (
            select(DeviceModelRelation.algo_model_id)
            .where(DeviceModelRelation.is_use == 1)
        )
        statement = select(AlgoModel).where(AlgoModel.id.in_(in_use_ids))
        results = await self.db.execute(statement)
        return results.scalars().all()

    async def _get_in_use_model_ids(self, model_ids) -> set:
        """Return the subset of *model_ids* with a relation where ``is_use == 1``."""
        if not model_ids:
            return set()
        statement = (
            select(DeviceModelRelation.algo_model_id)
            .where(
                DeviceModelRelation.is_use == 1,
                DeviceModelRelation.algo_model_id.in_(model_ids),
            )
            .distinct()
        )
        result = await self.db.execute(statement)
        return set(result.scalars().all())

    async def get_model_usage(self, algo_model_id) -> bool:
        """Return True if any device relation with ``is_use == 1`` references the model."""
        statement = (
            select(DeviceModelRelation.algo_model_id)
            .where(
                DeviceModelRelation.is_use == 1,
                DeviceModelRelation.algo_model_id == algo_model_id,
            )
            .limit(1)  # existence check only; no need to fetch every row
        )
        result = await self.db.execute(statement)
        return result.first() is not None

    async def get_model_by_id(self, model_id):
        """Return the ``AlgoModel`` with *model_id*, or ``None`` if absent."""
        result = await self.db.execute(select(AlgoModel).where(AlgoModel.id == model_id))
        return result.scalar_one_or_none()