Newer
Older
safe-algo-pro / model_handler / base_model_handler.py
zhangyingjie on 14 Oct 1 KB 完善模型检测流程
from algo.model_manager import AlgoModelExec


class BaseModelHandler:
    """Default pre-process -> inference -> post-process pipeline for a model.

    Subclasses can override a single stage (``pre_process``,
    ``model_inference`` or ``post_process``); ``run`` chains all three.
    """

    # Colour tuple passed to ``annotator.box_label`` for every drawn box.
    # Override in a subclass to draw in a different colour.
    box_color = (255, 0, 0)

    def __init__(self, model: "AlgoModelExec"):
        """Store the wrapped model and cache its class-name lookup.

        Args:
            model: wrapper whose ``algo_model_exec`` attribute performs the
                prediction and whose ``input_size`` is used as the inference
                image size.
        """
        self.model = model
        # Indexed by integer class id in ``post_process``; presumably a
        # dict/list of class names -- TODO confirm against AlgoModelExec.
        self.model_names = model.algo_model_exec.model_name

    def pre_process(self, frame):
        """Hook for subclasses; the base implementation returns ``frame`` unchanged."""
        return frame

    def model_inference(self, frame):
        """Run the model on one frame and return its first result.

        Raises:
            IndexError: if the model yields no result for the frame.
        """
        results_generator = self.model.algo_model_exec.predict(source=frame, imgsz=self.model.input_size,
                                                               save_txt=False,
                                                               save=False,
                                                               verbose=False, stream=True)
        # Consume the generator completely (rather than taking only the first
        # item) so the predictor runs to completion for this call.
        results = list(results_generator)
        return results[0]

    def post_process(self, frame, model_result, annotator):
        """Convert detection boxes into plain dicts and optionally draw them.

        Args:
            frame: the original frame (unused here; kept for subclasses).
            model_result: object exposing ``boxes``; each box provides
                ``cls``, ``conf``, ``xyxyn`` and ``xyxy``.
            annotator: object with a ``box_label`` method, or ``None`` to
                skip drawing.

        Returns:
            One dict per box with keys ``object_class_id``,
            ``object_class_name``, ``confidence`` and ``location`` (the
            ``xyxyn`` coordinates as a list). Empty when there are no boxes.
        """
        boxes = model_result.boxes
        if boxes is None:
            # Treat a result without boxes as "no detections" instead of
            # failing when trying to iterate over None.
            return []

        results = []
        # Single pass: build the output dict and draw the label per box.
        for box in boxes:
            class_id = int(box.cls)
            confidence = float(box.conf)
            results.append(
                {
                    'object_class_id': class_id,
                    'object_class_name': self.model_names[class_id],
                    'confidence': confidence,
                    'location': box.xyxyn.cpu().squeeze().tolist(),
                }
            )
            if annotator is not None:
                # The label shows the numeric class id (not the name).
                annotator.box_label(box.xyxy.cpu().squeeze(),
                                    f"{class_id} {confidence:.2f}",
                                    color=self.box_color,
                                    rotated=False)
        return results

    def run(self, frame, annotator):
        """Process one frame end-to-end.

        Returns:
            ``(frame, results)``: the original frame as passed in (not the
            pre-processed one) and the list of dicts from ``post_process``.
        """
        processed_frame = self.pre_process(frame=frame)
        model_result = self.model_inference(frame=processed_frame)
        result = self.post_process(frame=frame, model_result=model_result, annotator=annotator)
        return frame, result