Newer
Older
go-algo-server / model_wrapper.py
zhangyingjie on 6 Mar 2 KB 初版提交(本地测试版)
from ultralytics import YOLO
import numpy as np

from global_logger import logger


class ModelWrapper:
    """Thin wrapper around an ultralytics ``YOLO`` model.

    Loads the weights once, optionally runs a few warm-up predictions on
    blank images, and exposes single-frame / batch prediction helpers that
    return the ``boxes`` attribute of each prediction result.
    """

    def __init__(self, model_path, model_size=640, model_names=None, model_warm_up=5, batch_size=1):
        """Load the model and warm it up.

        Args:
            model_path: Weights path (or model name) passed to ``YOLO(...)``.
            model_size: Inference image size passed as ``imgsz`` to
                ``predict``; an int for a square size, or a list/tuple.
            model_names: Optional class-id -> label table used by
                ``get_label`` (a dict, or a list/tuple indexed by id).
                When ``None``, ``get_label`` falls back to the loaded
                model's own ``names`` attribute if it has one.
            model_warm_up: Number of warm-up ``predict`` calls; a value
                ``<= 0`` skips warm-up entirely.
            batch_size: Number of blank images per warm-up call.
        """
        self.model_path = model_path
        self.model_size = model_size
        self.model_names = model_names
        self.model_warm_up = model_warm_up
        self.batch_size = batch_size

        logger.info(f'start load model {self.model_path}...')
        self.model = YOLO(model_path)
        self._warm_up()
        logger.info(f'load model {self.model_path} success!')

    @staticmethod
    def _normalize_imgsz(model_size):
        """Return ``model_size`` as a two-element ``[height, width]`` list.

        Accepts an int (square size), or a list/tuple of one or two ints.
        """
        if isinstance(model_size, (list, tuple)):
            sizes = list(model_size)
            # A single-element sequence means a square size.
            if len(sizes) == 1:
                sizes = sizes * 2
            return sizes
        return [model_size, model_size]

    @staticmethod
    def _make_dummy_batch(imgsz, batch_size):
        """Build a list of blank ``imgsz[0] x imgsz[1] x 3`` uint8 images.

        The same zero-filled array is repeated ``batch_size`` times (at least
        once). uint8 is used because decoded video/image frames are typically
        uint8; the previous float64 default did not match real inputs.
        """
        dummy = np.zeros((imgsz[0], imgsz[1], 3), dtype=np.uint8)
        return [dummy] * max(1, batch_size)

    def _warm_up(self):
        """Run ``model_warm_up`` throw-away predictions on blank images.

        Does nothing when ``model_warm_up <= 0``.
        """
        if self.model_warm_up <= 0:
            return
        logger.info(f'warming up model {self.model_path}')
        imgsz = self._normalize_imgsz(self.model_size)
        dummy_inputs = self._make_dummy_batch(imgsz, self.batch_size)
        for _ in range(self.model_warm_up):
            self.model.predict(source=dummy_inputs, imgsz=imgsz, verbose=False, save=False, save_txt=False)
        logger.info(f'warm up model {self.model_path} success!')

    # Backward-compatible alias for the original (dunder-style) method name.
    __warm_up__ = _warm_up

    def _stream_predict(self, source):
        """Return the streaming result generator of ``self.model.predict``."""
        return self.model.predict(source=source, imgsz=self.model_size,
                                  save_txt=False,
                                  save=False,
                                  verbose=False, stream=True)

    def predict(self, frame):
        """Run inference on ``frame`` and return the first result's boxes.

        Returns:
            The ``boxes`` attribute of the first result, or an empty list
            when the model yields no results.
        """
        # The generator is drained completely; if ``frame`` is a batch, every
        # item is inferred but only the first result's boxes are returned.
        results = list(self._stream_predict(frame))
        return results[0].boxes if results else []

    def batch_predict(self, frames):
        """Run inference on ``frames`` and return one ``boxes`` per result."""
        return [result.boxes for result in self._stream_predict(frames)]

    def get_label(self, cls):
        """Map a class id to a human-readable label.

        Looks ``cls`` up in ``model_names`` (or, if that is ``None``, in the
        loaded model's ``names`` attribute when present). Falls back to
        ``str(cls)`` when no table is available or the id is unknown.
        """
        names = self.model_names
        if names is None:
            names = getattr(getattr(self, 'model', None), 'names', None)
        if names is None:
            return str(cls)
        if isinstance(names, (list, tuple)):
            if isinstance(cls, int) and 0 <= cls < len(names):
                return names[cls]
            return str(cls)
        return names.get(cls, str(cls))