# safe-algo-pro/model_handler/base_model_handler.py
from algo.model_manager import AlgoModelExec
from common.global_logger import logger
from common.image_plotting import colors


class BaseModelHandler:
    """Default pre-process -> inference -> post-process pipeline for a model.

    Subclasses may override any individual stage; ``run`` chains them together.
    """

    def __init__(self, model: AlgoModelExec):
        """Wrap *model* and cache its class-id -> class-name mapping.

        Args:
            model: Wrapper whose ``algo_model_exec`` exposes ``names`` and
                ``predict(...)``, and which provides ``input_size``.
        """
        self.model = model
        # Looked up by integer class id in post_process.
        self.model_names = model.algo_model_exec.names

    def pre_process(self, frame):
        """Return *frame* unchanged; hook for subclasses that need preprocessing."""
        return frame

    def model_inference(self, frame):
        """Run prediction on *frame* and return the first result.

        Args:
            frame: Input passed straight through as the predictor's ``source``.

        Returns:
            The first item yielded by ``predict(..., stream=True)``.

        Raises:
            IndexError: if the predictor yields no results.
        """
        results_generator = self.model.algo_model_exec.predict(
            source=frame,
            imgsz=self.model.input_size,
            save_txt=False,
            save=False,
            verbose=False,
            stream=True,
        )
        # Drain the stream completely (instead of taking just the first item) so
        # the predictor's generator runs to completion; a single input frame
        # presumably yields a single result.
        results = list(results_generator)
        return results[0]

    def post_process(self, frame, model_result, annotator):
        """Convert detected boxes into plain dicts and optionally draw them.

        Args:
            frame: The original frame; not used here, kept for subclasses.
            model_result: Result object returned by ``model_inference``; only its
                ``boxes`` attribute is read.
            annotator: Object with a ``box_label`` method used to draw each box,
                or ``None`` to skip drawing.

        Returns:
            A list with one dict per box, with keys ``'object_class_id'`` (int),
            ``'object_class_name'``, ``'confidence'`` (float) and ``'location'``:
            the box's ``xyxyn`` coordinates (presumably normalized x1, y1, x2, y2)
            formatted to 6 decimals and joined with ", ". Returns an empty list
            when the result carries no boxes.
        """
        boxes = model_result.boxes
        if boxes is None:
            # Results from models without a box head (e.g. classification) have
            # no boxes; report no detections rather than failing on iteration.
            return []

        results = []
        # Single pass: build the output record and draw the same box together,
        # converting class id / confidence only once per box.
        for box in boxes:
            class_id = int(box.cls)
            class_name = self.model_names[class_id]
            confidence = float(box.conf)
            results.append(
                {
                    'object_class_id': class_id,
                    'object_class_name': class_name,
                    'confidence': confidence,
                    'location': ", ".join(f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()),
                }
            )
            if annotator is not None:
                annotator.box_label(box.xyxy.cpu().squeeze(),
                                    f"{class_name} {confidence:.2f}",
                                    color=colors(class_id),
                                    rotated=False)
        return results

    def run(self, frame, annotator):
        """Run pre_process -> model_inference -> post_process on *frame*.

        Args:
            frame: Input frame.
            annotator: Passed through to ``post_process``; ``None`` disables drawing.

        Returns:
            Tuple ``(frame, detections)``: the original (not pre-processed) frame
            and the list produced by ``post_process``.
        """
        processed_frame = self.pre_process(frame=frame)
        model_result = self.model_inference(frame=processed_frame)
        result = self.post_process(frame=frame, model_result=model_result, annotator=annotator)
        return frame, result