# safe-algo-pro / algo / model_manager.py
import os.path
from dataclasses import dataclass
from typing import Optional, Dict

import numpy as np
from ultralytics import YOLO

from entity.model import AlgoModel
from services.model_service import ModelService
from common.global_logger import logger


@dataclass
class AlgoModelExec:
    """Runtime entry for one model: its record plus the loaded YOLO instance.

    ModelManager keeps one of these per model id. The entry is created with
    only the record filled in; the YOLO instance is attached later by
    ``ModelManager.load_model`` and stays ``None`` if the model file is
    missing.
    """
    # Model id; ModelManager fills it from the record's ``id`` attribute.
    algo_model_id: int
    # The model record; ``name`` and ``path`` are read from it when loading.
    algo_model_info: AlgoModel
    # Loaded YOLO instance, or None while the model has not been loaded.
    algo_model_exec: Optional[YOLO] = None
    # Image size used for the warm-up frames and passed to predict as ``imgsz``.
    input_size: int = 640


class ModelManager:
    """Keeps the in-use models and their loaded YOLO instances, keyed by model id.

    Model records come from the injected ``ModelService``; each record is
    wrapped in an ``AlgoModelExec`` and stored in ``self.models``. Loading a
    model builds a YOLO instance from the record's ``path`` and optionally
    runs a few warm-up predictions on blank frames.
    """

    def __init__(self, model_service: ModelService, model_warm_up=5, batch_size=4):
        """Create an empty manager.

        Args:
            model_service: Service used to query model records.
            model_warm_up: Number of warm-up ``predict`` calls made after a
                model is loaded; a value <= 0 disables warm-up.
            batch_size: Number of blank frames sent in each warm-up call.
        """
        self.model_service = model_service
        self.models: Dict[int, AlgoModelExec] = {}
        self.model_warm_up = model_warm_up
        # Store the caller's value; this used to be hard-coded to 4, which
        # silently ignored the ``batch_size`` argument.
        self.batch_size = batch_size

    async def query_model_inuse(self):
        """Fetch the in-use model records and register an unloaded entry for each.

        Existing entries with the same id are overwritten; no YOLO instance is
        created here.
        """
        for algo_model in await self.model_service.get_models_in_use():
            self.models[algo_model.id] = AlgoModelExec(
                algo_model_id=algo_model.id,
                algo_model_info=algo_model
            )

    async def load_models(self):
        """Reset the registry, re-query the in-use models and load each of them."""
        logger.info('loading models')
        self.models = {}
        await self.query_model_inuse()
        if not self.models:
            logger.info("no model in use")
        for algo_model_exec in self.models.values():
            self.load_model(algo_model_exec)

    def load_model(self, algo_model_exec: AlgoModelExec):
        """Load the YOLO model for one entry and warm it up if configured.

        If the record's ``path`` does not exist on disk the problem is logged
        and the entry is left without a YOLO instance.

        Args:
            algo_model_exec: Entry whose ``algo_model_info`` supplies the model
                ``name`` and ``path``; its ``algo_model_exec`` field receives
                the loaded YOLO instance.
        """
        model_name = algo_model_exec.algo_model_info.name
        model_path = algo_model_exec.algo_model_info.path
        logger.info(f'loading model {model_name}: {model_path}')

        if not os.path.exists(model_path):
            logger.info(f'model path:{model_path} not exists')
            return

        algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
        if self.model_warm_up > 0:
            logger.info(f'warming up model {model_name}')
            self._warm_up(algo_model_exec)
            logger.info(f'warm up model {model_name} success!')
        logger.info(f'load model {model_name} success!')

    def _warm_up(self, algo_model_exec: AlgoModelExec):
        """Run ``self.model_warm_up`` predictions on a batch of blank frames."""
        size = algo_model_exec.input_size
        # Accept a [h, w] list or (h, w) tuple as well as a single int.
        imgsz = list(size) if isinstance(size, (list, tuple)) else [size, size]
        # uint8 HWC frame, matching the image arrays normally fed to predict
        # (np.zeros would otherwise default to float64).
        dummy_input = np.zeros((imgsz[0], imgsz[1], 3), dtype=np.uint8)
        # The same blank frame is reused for every slot in the batch; keep at
        # least one frame so a non-positive batch_size never yields an empty
        # source.
        dummy_inputs = [dummy_input] * max(self.batch_size, 1)
        for _ in range(self.model_warm_up):
            algo_model_exec.algo_model_exec.predict(
                source=dummy_inputs,
                imgsz=imgsz,
                verbose=False,
                save=False,
                save_txt=False,
            )

    def reload_model(self, model_id):
        """Drop and re-load the YOLO instance of an already registered model.

        The stored ``algo_model_info`` is reused; the model service is not
        queried again. Unknown ids are ignored.

        Args:
            model_id: Key of the entry in ``self.models``.
        """
        algo_model_exec = self.models.get(model_id)
        if algo_model_exec:
            algo_model_exec.algo_model_exec = None
            self.load_model(algo_model_exec)

    async def load_new_model(self, model_id):
        """Fetch one model record by id, register it and load it.

        Args:
            model_id: Id passed to ``ModelService.get_model_by_id``.

        Returns:
            The registered ``AlgoModelExec`` entry, or ``None`` when the
            service returns no record. The entry is returned even if its model
            file was missing and no YOLO instance could be attached.
        """
        algo_model = await self.model_service.get_model_by_id(model_id)
        if not algo_model:
            return None

        algo_model_exec = AlgoModelExec(
            algo_model_id=algo_model.id,
            algo_model_info=algo_model
        )
        self.models[algo_model.id] = algo_model_exec
        self.load_model(algo_model_exec)
        return algo_model_exec