from datetime import datetime
from typing import List, Sequence, Optional, Tuple

from sqlalchemy import func
from sqlmodel import Session, select

from common.biz_exception import BizException
from entity.device_model_relation import DeviceModelRelation
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate
from common.global_thread_pool import GlobalThreadPool
from common.consts import NotifyChangeType


class ModelService:
    """CRUD service for ``AlgoModel`` records, with change notification.

    Callbacks registered through :meth:`register_change_callback` are
    submitted to the ``GlobalThreadPool`` executor whenever
    :meth:`notify_change` is called (in this class, only by
    :meth:`update_model`).
    """

    def __init__(self, db: Session):
        self.db = db
        # Callbacks, each later invoked as ``callback(algo_model_id, change_type)``.
        self.model_change_callbacks = []
        self.thread_pool = GlobalThreadPool()

    def register_change_callback(self, callback):
        """Register a callback to be notified when an algorithm model changes.

        The callback will be called as ``callback(algo_model_id, change_type)``.
        """
        self.model_change_callbacks.append(callback)

    def notify_change(self, algo_model_id, change_type):
        """Dispatch a model-change event to every registered callback.

        Each callback is handed to ``self.thread_pool.executor.submit``; the
        returned futures are not awaited here.
        """
        # NOTE(review): the futures are discarded, so an exception raised inside
        # a callback is not reported by this method — consider attaching a
        # done-callback that logs failures.
        for callback in self.model_change_callbacks:
            self.thread_pool.executor.submit(callback, algo_model_id, change_type)

    def get_model_list(self,
                       name: Optional[str] = None,
                       remark: Optional[str] = None,
                       ) -> Sequence[AlgoModel]:
        """Return all models whose name/remark contain the given substrings.

        :param name: optional substring filter on ``AlgoModel.name``.
        :param remark: optional substring filter on ``AlgoModel.remark``.
        """
        statement = self.model_query(name, remark)
        results = self.db.exec(statement)
        return results.all()

    def get_model_page(self,
                       name: Optional[str] = None,
                       remark: Optional[str] = None,
                       offset: int = 0,
                       limit: int = 10
                       ) -> Tuple[Sequence[AlgoModel], int]:
        """Return one page of filtered models together with the total match count.

        :param name: optional substring filter on ``AlgoModel.name``.
        :param remark: optional substring filter on ``AlgoModel.remark``.
        :param offset: number of matching rows to skip.
        :param limit: maximum number of rows to return.
        :returns: ``(rows_for_this_page, total_matching_rows)``.
        """
        statement = self.model_query(name, remark)

        # Count over the filtered (un-paginated) query so the total reflects
        # every match, not just the current page.
        total_statement = select(func.count()).select_from(statement.subquery())
        total = self.db.exec(total_statement).one()

        # Apply pagination only after the total has been computed.
        statement = statement.offset(offset).limit(limit)

        results = self.db.exec(statement)
        return results.all(), total

    def model_query(self, name: Optional[str], remark: Optional[str]):
        """Build a ``SELECT`` over ``AlgoModel`` filtered by optional substrings.

        Falsy ``name``/``remark`` values (``None`` or ``""``) add no filter.
        """
        # NOTE(review): the user-supplied text is embedded in a LIKE pattern
        # without escaping, so ``%`` and ``_`` in it act as wildcards.
        statement = select(AlgoModel)
        if name:
            statement = statement.where(AlgoModel.name.like(f"%{name}%"))
        if remark:
            statement = statement.where(AlgoModel.remark.like(f"%{remark}%"))
        return statement

    def create_model(self, model_data: AlgoModelCreate) -> AlgoModel:
        """Insert a new model, stamping its create/update times, and return it."""
        model = AlgoModel.model_validate(model_data)
        # Take a single timestamp so create_time and update_time match exactly
        # on a freshly created row.
        now = datetime.now()
        model.create_time = now
        model.update_time = now
        self.db.add(model)
        self.db.commit()
        self.db.refresh(model)
        return model

    def update_model(self, model_data: AlgoModelUpdate) -> Optional[AlgoModel]:
        """Apply the explicitly-set fields of ``model_data`` to an existing model.

        :returns: the refreshed model, or ``None`` if no model has
            ``model_data.id``.

        After a successful commit, registered callbacks are notified with
        ``NotifyChangeType.MODEL_UPDATE``.
        """
        model = self.db.get(AlgoModel, model_data.id)
        if not model:
            return None

        # Only copy fields the caller actually set; never re-assign the primary key.
        update_data = model_data.model_dump(exclude_unset=True, exclude={"id"})
        for key, value in update_data.items():
            setattr(model, key, value)

        model.update_time = datetime.now()
        self.db.add(model)
        self.db.commit()
        self.db.refresh(model)
        self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE)
        return model

    def delete_model(self, model_id: int) -> Optional[AlgoModel]:
        """Delete a model unless a device currently has it enabled.

        :returns: the deleted model, or ``None`` if ``model_id`` does not exist.
        :raises BizException: if a ``DeviceModelRelation`` with ``is_use == 1``
            references this model.
        """
        model = self.db.get(AlgoModel, model_id)
        if not model:
            return None

        # Look for any enabled binding in device_model_relation.
        statement = (
            select(DeviceModelRelation)
            .where(DeviceModelRelation.algo_model_id == model_id)
            .where(DeviceModelRelation.is_use == 1)
        )
        relation_in_use = self.db.exec(statement).first()

        # Refuse deletion while an enabled binding exists.
        if relation_in_use:
            raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除")

        self.db.delete(model)
        self.db.commit()
        return model

    def get_models_in_use(self) -> Sequence[AlgoModel]:
        """Return every model that has at least one enabled device binding.

        A model bound to several devices appears once, because the join is
        grouped by ``AlgoModel.id``.
        """
        statement = (
            select(AlgoModel)
            .join(DeviceModelRelation, DeviceModelRelation.algo_model_id == AlgoModel.id)
            .where(DeviceModelRelation.is_use == 1)
            .group_by(AlgoModel.id)
        )
        results = self.db.exec(statement).all()
        return results