Newer
Older
safe-algo-pro / model_handler / base_model_handler.py
from algo.model_manager import AlgoModelExec
from common.global_logger import logger
from common.image_plotting import colors


class BaseModelHandler:
    """Default detection pipeline around a loaded model.

    ``run`` chains three overridable stages: ``pre_process`` (per frame),
    ``model_inference`` (whole batch) and ``post_process`` (per frame), which
    turns the predicted boxes into plain dicts and optionally draws them.
    """

    def __init__(self, model: AlgoModelExec):
        """Bind the handler to ``model``.

        Args:
            model: wrapper whose ``algo_model_exec`` attribute is the underlying
                predictor (it must expose ``names`` and ``predict``) and whose
                ``input_size`` is passed to inference as ``imgsz``.
        """
        self.model = model
        # Class id -> class name mapping reported by the predictor.
        self.model_names = model.algo_model_exec.names
        self.model_ids = list(self.model_names.keys())

    def pre_process(self, frame):
        """Return ``frame`` unchanged; subclasses override this to transform input."""
        return frame

    def model_inference(self, frames):
        """Run the predictor on a batch of frames.

        Args:
            frames: the (pre-processed) frames, passed to ``predict`` as ``source``.

        Returns:
            list: the ``boxes`` attribute of every result, in the order the
            predictor yields them (one entry per frame).
        """
        # stream=True makes predict yield results lazily; collect only the boxes
        # so post_process can index them by frame position.
        results_generator = self.model.algo_model_exec.predict(
            source=frames,
            imgsz=self.model.input_size,
            save_txt=False,
            save=False,
            verbose=True,
            stream=True,
        )
        return [result.boxes for result in results_generator]

    def post_process(self, frames, model_results, annotators=None):
        """Convert per-frame boxes into result dicts and optionally draw them.

        Args:
            frames: the original frames; only their count/order is used here.
            model_results: per-frame boxes as returned by ``model_inference``;
                an entry may be ``None`` when the result carries no boxes.
            annotators: optional sequence aligned with ``frames``. Each entry is
                an annotator used to draw that frame's boxes, or ``None`` to skip
                drawing. Passing ``None`` for the whole argument disables drawing.

        Returns:
            list[list[dict]]: for every frame, one dict per box with keys
            ``object_class_id``, ``object_class_name``, ``confidence`` and
            ``location`` (normalized x1, y1, x2, y2 as a comma-separated string).
        """
        results = []
        for idx, _ in enumerate(frames):
            frame_boxes = model_results[idx]
            annotator = annotators[idx] if annotators is not None else None
            if frame_boxes is None:
                # Results without box output: report "no detections" instead of
                # failing when iterating None.
                frame_boxes = ()

            frame_result = []
            for box in frame_boxes:
                # Convert the tensor fields once and reuse them for both the
                # result dict and the drawn label.
                class_id = int(box.cls)
                class_name = self.model_names[class_id]
                confidence = float(box.conf)
                frame_result.append(
                    {
                        'object_class_id': class_id,
                        'object_class_name': class_name,
                        'confidence': confidence,
                        'location': ", ".join(f"{x:.6f}" for x in box.xyxyn.cpu().squeeze().tolist()),
                    }
                )
                if annotator is not None:
                    annotator.box_label(box.xyxy.cpu().squeeze(),
                                        f"{class_name} {confidence:.2f}",
                                        color=colors(class_id),
                                        rotated=False)
            results.append(frame_result)
        return results

    def run(self, frames, annotators=None):
        """Run the full pre-process -> inference -> post-process pipeline.

        Note that ``frames`` is iterated twice (here and in ``post_process``),
        so it must be a re-iterable sequence rather than a one-shot iterator.

        Args:
            frames: the input frames.
            annotators: optional per-frame annotators (see ``post_process``).

        Returns:
            tuple: ``(frames, results)``, where ``frames`` is the input object as
            received and ``results`` is the output of ``post_process``.
        """
        processed_frames = [self.pre_process(frame=frame) for frame in frames]
        if not processed_frames:
            # Nothing to infer on; avoid handing an empty source to the predictor.
            return frames, []
        result_boxes = self.model_inference(frames=processed_frames)
        results = self.post_process(frames=frames, model_results=result_boxes, annotators=annotators)
        return frames, results