Newer
Older
safe-algo-pro / services / push_config_service.py
import asyncio
from datetime import datetime

from sqlmodel import select
from sqlalchemy.ext.asyncio import AsyncSession

from entity.push_config import PushConfigCreate, PushConfig


class PushConfigService:
    """Process-wide singleton for reading and writing ``PushConfig`` rows.

    Besides query helpers it keeps a list of change listeners that are
    notified whenever :meth:`set_push_config` persists a configuration.

    NOTE(review): because ``__new__`` always returns the same instance and
    ``__init__`` only initialises state once, the ``db`` session passed to the
    *first* construction is kept for the life of the process; sessions passed
    to later constructions are ignored. Confirm callers rely on this.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Always hand back the same instance so registered callbacks are shared.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, db: AsyncSession):
        # __init__ runs on every construction; initialise state only once so
        # the callback list (and the first session) survive later calls.
        if not hasattr(self, 'initialized'):
            self.db = db
            self.__push_change_callbacks = []  # registered change listeners
            # Strong references to scheduled async callbacks: the event loop
            # only keeps weak references to tasks, so an unreferenced task can
            # be garbage-collected before it finishes.
            self.__background_tasks = set()
            self.initialized = True

    def register_change_callback(self, callback):
        """Register a listener called with the saved config on every change.

        ``callback`` may be a plain callable or one returning a coroutine
        (e.g. an ``async def`` function); see :meth:`notify_change`.
        """
        self.__push_change_callbacks.append(callback)

    def notify_change(self, config):
        """Invoke every registered callback with ``config``.

        Callbacks run in registration order. Synchronous results are used
        as-is; if a callback returns a coroutine (``async def`` functions, or
        wrappers around them) it is scheduled with ``asyncio.create_task``,
        which requires a running event loop. An exception raised by a
        callback propagates to the caller and skips the remaining callbacks.
        """
        for callback in self.__push_change_callbacks:
            result = callback(config)
            if asyncio.iscoroutine(result):
                task = asyncio.create_task(result)
                # Hold a reference until the task completes, then release it.
                self.__background_tasks.add(task)
                task.add_done_callback(self.__background_tasks.discard)

    async def set_push_config(self, push_config_create: PushConfigCreate):
        """Create or update the config for ``push_config_create.push_type``.

        If a row for that push type exists, only the fields explicitly set on
        ``push_config_create`` are copied onto it; otherwise a new
        ``PushConfig`` is built from it with matching create/update times.
        The row is committed, refreshed, passed to :meth:`notify_change` and
        returned.

        Raises:
            Any exception from ``commit()``; the session is rolled back first
            so the shared, long-lived session remains usable afterwards.
        """
        push_config = await self.get_push_config(push_config_create.push_type)
        now = datetime.now()
        if push_config is not None:
            update_data = push_config_create.model_dump(exclude_unset=True)
            for key, value in update_data.items():
                setattr(push_config, key, value)
        else:
            push_config = PushConfig.model_validate(push_config_create)
            push_config.create_time = now
        push_config.update_time = now

        self.db.add(push_config)
        try:
            await self.db.commit()
        except Exception:
            # A failed flush leaves the session unusable until rolled back.
            await self.db.rollback()
            raise
        await self.db.refresh(push_config)

        self.notify_change(push_config)
        return push_config

    async def get_push_config(self, push_type):
        """Return the first ``PushConfig`` whose ``push_type`` matches, or None."""
        statement = select(PushConfig).where(PushConfig.push_type == push_type)
        results = await self.db.execute(statement)
        return results.scalars().first()

    async def get_push_config_list(self):
        """Return every ``PushConfig`` row as a list."""
        statement = select(PushConfig)
        results = await self.db.execute(statement)
        return results.scalars().all()