Newer
Older
safe-algo-pro / services / model_service.py
zhangyingjie on 17 Oct 4 KB 修改线程控制问题
from datetime import datetime
from typing import List, Sequence, Optional, Tuple, Type

from sqlalchemy import func
from sqlmodel import Session, select

from common.biz_exception import BizException
from entity.device_model_relation import DeviceModelRelation
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate
from common.global_thread_pool import GlobalThreadPool
from common.consts import NotifyChangeType


class ModelService:
    """CRUD service for ``AlgoModel`` records, with change notifications.

    Callbacks registered via ``register_change_callback`` are submitted to the
    global thread pool whenever ``notify_change`` is called (in this class,
    only after a successful ``update_model``).
    """

    def __init__(self, db: Session):
        """Bind the service to a database session.

        :param db: SQLModel session used for all queries and commits.
        """
        self.db = db
        # Callbacks invoked as ``callback(algo_model_id, change_type)``.
        self.__model_change_callbacks = []
        self.thread_pool = GlobalThreadPool()

    def register_change_callback(self, callback):
        """Register a callback to be notified when an algorithm model changes.

        :param callback: callable accepting ``(algo_model_id, change_type)``.
        """
        self.__model_change_callbacks.append(callback)

    def notify_change(self, algo_model_id, change_type):
        """Notify every registered callback that an algorithm model changed.

        Each callback is submitted to the global thread pool's executor, so
        this method does not wait for callbacks to finish.

        :param algo_model_id: id of the model that changed.
        :param change_type: kind of change (a ``NotifyChangeType`` value).
        """
        # NOTE(review): the futures returned by ``submit`` are discarded, so
        # any exception raised inside a callback is not observed here.
        for callback in self.__model_change_callbacks:
            self.thread_pool.executor.submit(callback, algo_model_id, change_type)

    def get_model_list(self,
                       name: Optional[str] = None,
                       remark: Optional[str] = None,
                       ) -> Sequence[AlgoModel]:
        """Return all models matching the optional name/remark filters.

        :param name: substring to match against ``AlgoModel.name``.
        :param remark: substring to match against ``AlgoModel.remark``.
        :return: every matching ``AlgoModel`` row.
        """
        statement = self.model_query(name, remark)
        return self.db.exec(statement).all()

    def get_model_page(self,
                       name: Optional[str] = None,
                       remark: Optional[str] = None,
                       offset: int = 0,
                       limit: int = 10
                       ) -> Tuple[Sequence[AlgoModel], int]:
        """Return one page of matching models together with the total count.

        :param name: substring to match against ``AlgoModel.name``.
        :param remark: substring to match against ``AlgoModel.remark``.
        :param offset: number of rows to skip.
        :param limit: maximum number of rows to return.
        :return: ``(rows_for_this_page, total_matching_rows)``.
        """
        statement = self.model_query(name, remark)

        # Count over the unpaginated filtered query so the total reflects
        # all matches, not just the current page.
        total_statement = select(func.count()).select_from(statement.subquery())
        total = self.db.exec(total_statement).one()

        page_statement = statement.offset(offset).limit(limit)
        return self.db.exec(page_statement).all(), total

    def model_query(self, name, remark):
        """Build a SELECT over ``AlgoModel`` with optional substring filters.

        Falsy values (``None`` or ``""``) disable the corresponding filter.

        :param name: substring to match against ``AlgoModel.name``.
        :param remark: substring to match against ``AlgoModel.remark``.
        :return: an un-executed ``select(AlgoModel)`` statement.
        """
        # NOTE(review): ``%`` and ``_`` in user input are not escaped, so they
        # act as LIKE wildcards. The values are bound as parameters, so this
        # is not SQL injection.
        statement = select(AlgoModel)
        if name:
            statement = statement.where(AlgoModel.name.like(f"%{name}%"))
        if remark:
            statement = statement.where(AlgoModel.remark.like(f"%{remark}%"))
        return statement

    def create_model(self, model_data: AlgoModelCreate):
        """Insert a new algorithm model and return the persisted instance.

        :param model_data: validated creation payload.
        :return: the refreshed ``AlgoModel`` (with generated id).
        """
        model = AlgoModel.model_validate(model_data)
        # Take one timestamp so create_time and update_time match exactly.
        now = datetime.now()
        model.create_time = now
        model.update_time = now
        self.db.add(model)
        self.db.commit()
        self.db.refresh(model)
        return model

    def update_model(self, model_data: AlgoModelUpdate):
        """Apply a partial update to an existing model and notify listeners.

        Only fields explicitly set on ``model_data`` are copied, so omitted
        fields keep their stored values.

        :param model_data: update payload; ``model_data.id`` selects the row.
        :return: the refreshed ``AlgoModel``, or ``None`` if no row has that id.
        """
        model = self.db.get(AlgoModel, model_data.id)
        if not model:
            return None

        # ``model_dump`` is the Pydantic v2 / SQLModel replacement for the
        # deprecated ``dict()``, matching the ``model_validate`` call used in
        # ``create_model``.
        update_data = model_data.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(model, key, value)

        model.update_time = datetime.now()
        self.db.add(model)
        self.db.commit()
        self.db.refresh(model)
        # Notify only after the commit succeeded, so listeners see persisted data.
        self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE)

        return model

    def delete_model(self, model_id: int):
        """Delete a model unless it is bound to a device in an active relation.

        :param model_id: id of the model to delete.
        :return: the deleted ``AlgoModel``, or ``None`` if no row has that id.
        :raises BizException: if a ``DeviceModelRelation`` with ``is_use == 1``
            references this model.
        """
        model = self.db.get(AlgoModel, model_id)
        if not model:
            return None

        # Refuse deletion while any device relation referencing the model is enabled.
        statement = (
            select(DeviceModelRelation)
            .where(DeviceModelRelation.algo_model_id == model_id)
            .where(DeviceModelRelation.is_use == 1)
        )
        if self.db.exec(statement).first():
            raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除")

        # NOTE(review): unlike update_model, no change notification is sent
        # here. Confirm whether listeners need a delete event.
        self.db.delete(model)
        self.db.commit()
        return model

    def get_models_in_use(self) -> Sequence[AlgoModel]:
        """Return every model that has at least one enabled device relation.

        :return: distinct ``AlgoModel`` rows joined to a
            ``DeviceModelRelation`` with ``is_use == 1``.
        """
        # GROUP BY the primary key collapses models bound to several devices
        # into a single row.
        statement = (
            select(AlgoModel)
            .join(DeviceModelRelation, DeviceModelRelation.algo_model_id == AlgoModel.id)
            .where(DeviceModelRelation.is_use == 1)
            .group_by(AlgoModel.id)
        )
        return self.db.exec(statement).all()

    def get_model_by_id(self, model_id: int) -> Optional[AlgoModel]:
        """Return the model with the given primary key, or ``None`` if absent."""
        return self.db.get(AlgoModel, model_id)