# safe-algo-pro/services/device_model_relation_service.py
from datetime import datetime
from typing import List

from sqlmodel import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from common.consts import NotifyChangeType
from common.global_thread_pool import GlobalThreadPool
from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate
from entity.model import AlgoModel


class DeviceModelRelationService:
    """Data-access service for device <-> algorithm-model relations.

    Wraps an ``AsyncSession`` and reads/writes ``DeviceModelRelation`` rows
    keyed by ``device_id``. Callers may register change callbacks; they are
    dispatched to the global thread pool after ``update_relations_by_device``
    has committed.
    """

    def __init__(self, db: AsyncSession):
        self.db = db
        # Callbacks registered via register_change_callback; fired by notify_change.
        self.__relation_change_callbacks = []
        self.thread_pool = GlobalThreadPool()

    def register_change_callback(self, callback):
        """Register a callback to be notified when a device's relations change.

        :param callback: callable invoked as ``callback(device_id, change_type)``.
        """
        self.__relation_change_callbacks.append(callback)

    def notify_change(self, device_id, change_type):
        """Dispatch every registered callback with ``(device_id, change_type)``.

        Each callback is submitted to ``self.thread_pool.executor`` and is not
        awaited here; the returned futures are discarded.
        NOTE(review): presumably a ``concurrent.futures`` executor, in which case
        an exception raised by a callback is silently kept in its future —
        confirm and consider logging failures.
        """
        for callback in self.__relation_change_callbacks:
            self.thread_pool.executor.submit(callback, device_id, change_type)

    async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]:
        """Return the relations of ``device_id`` joined with their algorithm models.

        Relations whose ``algo_model_id`` has no matching ``AlgoModel`` row are
        excluded, because the query uses an inner join.

        :param device_id: id of the device whose relations to fetch.
        :return: one ``DeviceModelRelationInfo`` per (relation, model) pair.
        """
        statement = (
            select(DeviceModelRelation, AlgoModel)
            .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id)
            .where(DeviceModelRelation.device_id == device_id)
        )

        # Execute the join query; each row is a (DeviceModelRelation, AlgoModel) pair.
        results = await self.db.execute(statement)

        return [
            DeviceModelRelationInfo(
                id=relation.id,
                # Must come from the relation's device_id column, not its own PK.
                device_id=relation.device_id,
                algo_model_id=relation.algo_model_id,
                is_use=relation.is_use,
                threshold=relation.threshold,
                algo_model_name=model.name,
                algo_model_version=model.version,
                algo_model_path=model.path,
                algo_model_remark=model.remark,
            )
            for relation, model in results.all()
        ]

    async def add_relations_by_device(
        self,
        device_id: int,
        relations: List[DeviceModelRelationCreate],
        *,
        commit: bool = True,
    ) -> List[DeviceModelRelation]:
        """Insert one ``DeviceModelRelation`` per item of ``relations``.

        :param device_id: device every new row is attached to.
        :param relations: items providing ``algo_model_id``, ``is_use`` and
            ``threshold`` for each new row.
        :param commit: when True (default), commit and refresh the new rows.
            When False, only flush (so primary keys are assigned) and leave the
            transaction open for the caller to commit or roll back.
        :return: the newly created ``DeviceModelRelation`` instances.
        """
        # One timestamp for the whole batch so createtime == updatetime.
        now = datetime.now()
        new_relations = [
            DeviceModelRelation(
                algo_model_id=relation.algo_model_id,
                is_use=relation.is_use,
                threshold=relation.threshold,
                device_id=device_id,  # every row belongs to the given device
                createtime=now,
                updatetime=now,
            )
            for relation in relations
        ]
        self.db.add_all(new_relations)

        if not commit:
            await self.db.flush()
            return new_relations

        await self.db.commit()
        # Reload server-side state (e.g. generated ids) after the commit.
        for relation in new_relations:
            await self.db.refresh(relation)
        return new_relations

    async def delete_relations_by_device(self, device_id: int, *, commit: bool = True) -> int:
        """Delete every relation row of ``device_id``.

        :param device_id: device whose relations are removed.
        :param commit: when True (default), commit immediately; when False,
            leave the transaction open for the caller.
        :return: the result's ``rowcount`` (number of rows matched by the DELETE,
            as reported by the driver).
        """
        statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id)
        result = await self.db.execute(statement)
        if commit:
            await self.db.commit()
        return result.rowcount

    async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]):
        """Replace all relations of ``device_id`` with ``relations`` in one transaction.

        The delete and the inserts are committed together. If any step fails,
        the session is rolled back and the exception re-raised, so the device
        never ends up with its old relations deleted but the new ones missing.
        Registered callbacks are notified only after a successful commit.

        :param device_id: device whose relations are replaced.
        :param relations: the complete new set of relations.
        :return: the newly created ``DeviceModelRelation`` instances.
        """
        try:
            await self.delete_relations_by_device(device_id, commit=False)
            new_relations = await self.add_relations_by_device(device_id, relations, commit=False)
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise

        # SQLAlchemy expires instances on commit unless expire_on_commit=False;
        # refresh so callers can read attributes without further lazy I/O.
        for relation in new_relations:
            await self.db.refresh(relation)

        self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE)
        return new_relations