from ultralytics import YOLO
import numpy as np
from global_logger import logger


class ModelWrapper:
    """Thin wrapper around an Ultralytics ``YOLO`` model.

    Loads the weights, optionally runs warm-up passes on blank frames, and
    exposes single-frame and batched prediction helpers that return only the
    ``boxes`` of each result.
    """

    def __init__(self, model_path, model_size=640, model_names=None, model_warm_up=5, batch_size=1):
        """Load the model and run the warm-up passes.

        Args:
            model_path: Weights path/name passed straight to ``YOLO``.
            model_size: Inference size; an int for a square input, or a
                two-element ``[h, w]`` list/tuple.
            model_names: Optional class-id -> label mapping (dict, or a
                list/tuple indexed by class id) used by :meth:`get_label`.
                When ``None``, the model's own ``names`` are used.
            model_warm_up: Number of warm-up inference passes; ``<= 0`` skips
                warm-up entirely.
            batch_size: Number of blank frames fed per warm-up pass.
        """
        self.model_path = model_path
        self.model_size = model_size
        self.model_names = model_names
        self.model_warm_up = model_warm_up
        self.batch_size = batch_size
        logger.info(f'start load model {self.model_path}...')
        self.model = YOLO(model_path)
        self.__warm_up__()
        logger.info(f'load model {self.model_path} success!')

    def __warm_up__(self):
        """Run ``model_warm_up`` predictions on all-black frames.

        Each pass feeds ``batch_size`` identical blank frames at the configured
        inference size. Nothing is done when ``model_warm_up <= 0``.
        """
        if self.model_warm_up <= 0:
            return
        logger.info(f'warming up model {self.model_path}')
        imgsz = self.model_size
        # Accept both list and tuple for an explicit [h, w] size; a scalar
        # means a square input.
        if not isinstance(imgsz, (list, tuple)):
            imgsz = [imgsz, imgsz]
        # uint8 HWC frame: the dtype Ultralytics expects for numpy sources,
        # so warm-up exercises the same preprocessing path as real frames.
        dummy_input = np.zeros((imgsz[0], imgsz[1], 3), dtype=np.uint8)
        dummy_inputs = [dummy_input] * self.batch_size
        for _ in range(self.model_warm_up):
            self.model.predict(source=dummy_inputs, imgsz=imgsz, verbose=False, save=False, save_txt=False)
        logger.info(f'warm up model {self.model_path} success!')

    def predict(self, frame):
        """Run inference on a single frame.

        Args:
            frame: Any source accepted by ``YOLO.predict``.

        Returns:
            The ``boxes`` of the first result, or an empty list when the
            predictor yields no results.
        """
        results_generator = self.model.predict(source=frame, imgsz=self.model_size, save_txt=False, save=False,
                                               verbose=False, stream=True)
        # Drain the generator fully (rather than taking only the first item)
        # so the predictor's generator runs to completion.
        results = list(results_generator)
        return results[0].boxes if results else []

    def batch_predict(self, frames):
        """Run inference on several frames.

        Args:
            frames: Any multi-image source accepted by ``YOLO.predict``.

        Returns:
            A list with the ``boxes`` of each result, in the order the
            predictor yields them.
        """
        results_generator = self.model.predict(source=frames, imgsz=self.model_size, save_txt=False, save=False,
                                               verbose=False, stream=True)
        return [result.boxes for result in results_generator]

    def get_label(self, cls):
        """Return the label for class id ``cls``, or ``str(cls)`` if unknown.

        Uses ``model_names`` when provided. Otherwise it uses the model's own
        ``names`` attribute, and failing that an empty table.
        """
        names = self.model_names
        if names is None:
            # Previously this raised AttributeError (None has no .get) when
            # the constructor default was used.
            names = getattr(self.model, 'names', None) or {}
        if isinstance(names, (list, tuple)):
            names = dict(enumerate(names))
        return names.get(cls, str(cls))