# safe-algo-pro / scene_handler / limit_space_scene_handler.py
import asyncio
import base64
import threading
import traceback
from asyncio import Event
from copy import deepcopy, copy
from datetime import datetime

import cv2

from algo.model_manager import AlgoModelExec
from algo.stream_loader import OpenCVStreamLoad
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
from common.http_utils import send_request
from common.image_plotting import Annotator
from entity.device import Device

from ultralytics import YOLO

from scene_handler.base_scene_handler import BaseSceneHandler
from services.global_config import GlobalConfig
from tcp.tcp_manager import TcpManager

# Drawing colors in OpenCV's BGR channel order: red is (0, 0, 255), so green must be
# (0, 255, 0) -- the previous value (255, 0, 0) rendered as blue.
COLOR_RED = (0, 0, 255)
COLOR_GREEN = (0, 255, 0)

def _alarm_entry(alarm_type, text, sound_hex):
    """Build one ALARM_DICT entry; the drawn label uses the same text as the content."""
    return {
        'alarmType': alarm_type,
        'alarmContent': text,
        'alarmSoundMessage': bytes.fromhex(sound_hex),
        'label': text,
    }


# Alarm definitions keyed by internal alarm name.
# alarmSoundMessage is the raw frame sent to the device over TCP. In every entry below
# the last byte equals the sum of bytes 1..5 modulo 256 (the leading 0xAA excluded).
ALARM_DICT = {
    'hat_and_mask': _alarm_entry('11', '未佩戴呼吸防护设备', 'aa 01 00 93 12 00 a6'),
    'no_jiandu': _alarm_entry('12', '没有监护人员', 'aa 01 00 93 13 00 a7'),
    'break': _alarm_entry('3', '非法闯入', 'aa 01 00 93 00 00 94'),
    'smoke': _alarm_entry('6', '吸烟', 'aa 01 00 93 03 00 97'),
    'no_blower': _alarm_entry('13', '没有检测到通风设备', 'aa 01 00 93 1a 00 ae'),
    'no_extinguisher': _alarm_entry('14', '没有检测到灭火器', 'aa 01 00 93 1b 00 af'),
}


def intersection_area(bbox1, bbox2):
    """Return the overlap area of two ``(x1, y1, x2, y2)`` boxes, or 0 when they do not overlap."""
    ax1, ay1, ax2, ay2 = bbox1
    bx1, by1, bx2, by2 = bbox2

    # Clamp each overlap extent at zero so disjoint boxes contribute no area.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    return overlap_w * overlap_h


def bbox_area(bbox):
    """Return the area of an ``(x1, y1, x2, y2)`` box (no clamping: inverted corners give a non-positive value)."""
    left, top, right, bottom = bbox
    width = right - left
    height = bottom - top
    return width * height


def get_person_head(person_bbox, heads, min_overlap_ratio=0.8):
    """Pick the detected head that best belongs to a person box.

    A head qualifies when at least ``min_overlap_ratio`` of its own area lies inside
    ``person_bbox``. Among qualifying heads the one with the largest overlap area wins;
    equal overlaps are broken by the higher confidence.

    :param person_bbox: person box as ``(x1, y1, x2, y2)``.
    :param heads: head detections; ``head.xyxy.cpu().squeeze()`` must yield the box and
        ``float(head.conf)`` its confidence.
    :param min_overlap_ratio: minimum fraction of a head's area that must overlap the
        person box (defaults to 0.8, the previously hard-coded threshold).
    :return: the best matching head detection, or ``None`` if no head qualifies.
    """
    best_head = None
    max_overlap = 0

    for head in heads:
        head_bbox = head.xyxy.cpu().squeeze()
        head_area = bbox_area(head_bbox)
        # Degenerate boxes (zero or negative area) can never be matched, and a zero
        # area would otherwise divide by zero below.
        if head_area <= 0:
            continue

        overlap_area = intersection_area(person_bbox, head_bbox)
        overlap_ratio = overlap_area / head_area

        if overlap_ratio >= min_overlap_ratio and (
                best_head is None
                or overlap_area > max_overlap
                or (overlap_area == max_overlap and float(head.conf) > float(best_head.conf))):
            best_head = head
            max_overlap = overlap_area
    return best_head


def is_overlapping(bbox1, bbox2):
    """Return True when two ``(x1, y1, x2, y2)`` boxes overlap; boxes that only touch count as overlapping."""
    ax1, ay1, ax2, ay2 = bbox1
    bx1, by1, bx2, by2 = bbox2

    # The boxes are disjoint only if one lies strictly to the side of, or strictly
    # above/below, the other.
    apart_horizontally = ax2 < bx1 or bx2 < ax1
    apart_vertically = ay2 < by1 or by2 < ay1
    return not (apart_horizontally or apart_vertically)


def image_to_base64(numpy_image, format='jpg'):
    """Encode an image with OpenCV and return the encoded bytes as a Base64 string.

    :param numpy_image: image array accepted by ``cv2.imencode``.
    :param format: file extension without the leading dot, e.g. ``'jpg'`` (default) or ``'png'``.
    :return: Base64 text (``str``) of the encoded image.
    :raises ValueError: if OpenCV reports that encoding failed.
    """
    ok, buffer = cv2.imencode(f'.{format}', numpy_image)
    if not ok:
        raise ValueError("Could not encode image")
    return base64.b64encode(buffer).decode('utf-8')


def handle_alarm_info(type, frame, frame_alarm, person_box, conf=None, person_id=None):
    """Record one occurrence of alarm ``type`` for the current frame and return its annotator.

    ``frame_alarm`` maps alarm type -> ``{'count': int, 'annotator': Annotator}``. The first
    time a type is seen, an annotator over a deep copy of ``frame`` is created for it, so
    each alarm type draws on its own copy of the image. The type's count is incremented,
    and when ``person_box`` is given it is drawn in red, labelled with the alarm label
    followed by the confidence (2 decimals) and ``id=<person_id>`` when those are provided.
    """
    entry = frame_alarm.get(type)
    if entry is None:
        entry = {'count': 0,
                 'annotator': Annotator(deepcopy(frame), None, 18, "Arial.ttf", False, example="人")}
        frame_alarm[type] = entry

    entry['count'] += 1
    annotator = entry['annotator']
    if person_box is None:
        return annotator

    label = ALARM_DICT[type]['label']
    if conf is not None:
        label += f'{conf:.2f}'
    if person_id is not None:
        label += f' id={person_id}'
    annotator.box_label(person_box, label, color=COLOR_RED, rotated=False)
    return annotator


class LimitSpaceSceneHandler(BaseSceneHandler):
    """Scene handler for confined-space ("limit space") work sites.

    Reads frames from the device's input stream and reports alarms through two
    channels, each throttled per alarm type:

    * TCP sound messages to the device (``send_alarm_message``), and
    * HTTP alarm records carrying an annotated image (``send_alarm_record``).

    NOTE(review): model inference is currently commented out in ``run``, so no alarms
    are produced at the moment.
    """

    def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop, range_points):
        """
        :param device: device whose stream is processed; its ``alarm_interval`` is used as
            the per-type throttle interval (in seconds) for both TCP messages and HTTP records.
        :param thread_id: identifier forwarded to the base handler and the stream loader.
        :param tcp_manager: used to send alarm sound messages; when falsy no TCP messages are sent.
        :param main_loop: asyncio event loop on which TCP send coroutines are scheduled.
        :param range_points: accepted for interface compatibility; not used by this handler.
        """
        super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop)
        self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code,
                                              device_thread_id=thread_id)
        self.device_status_manager = DeviceStatusManager()
        self.thread_pool = GlobalThreadPool()

        self.model = YOLO('weights/labor-v8-20241114.pt')
        # Class ids requested from the model (see model_predict) -> display names.
        self.model_classes = {
            0: '三脚架',
            3: '人',
            4: '作业信息公示牌',
            6: '危险告知牌',
            9: '反光衣',
            11: '呼吸面罩',
            13: '四合一',
            15: '头',
            16: '安全告知牌',
            18: '安全帽',
            20: '安全标识牌',
            24: '工服',
            34: '灭火器',
            43: '警戒线',
            58: '鼓风机',
        }

        # alarm type -> time the last HTTP alarm record was sent (throttling).
        self.alarm_interval_dict = {}
        # A negative value disables HTTP alarm records entirely (see send_alarm_record).
        self.alarm_interval = device.alarm_interval

        # alarm type -> time the last TCP sound message was sent (throttling).
        self.socket_interval_dict = {}
        self.socket_interval = device.alarm_interval
        self.socket_retry = 3

        # Stop flag checked by run() and set by stop_task(). run() is a blocking loop that
        # is presumably executed on a worker thread, so a thread-safe threading.Event is
        # used; asyncio.Event is not thread-safe and its `loop` argument was removed in
        # Python 3.10.
        self.__stop_event = threading.Event()

    def stop_task(self, **kwargs):
        """Request the processing loop to stop and stop the stream loader."""
        logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}')
        self.__stop_event.set()
        self.stream_loader.stop()  # stop the stream-loading thread

    def send_tcp_message(self, message: bytes, have_response=False):
        """Schedule ``message`` to be sent to this device on ``self.main_loop``.

        Uses ``asyncio.run_coroutine_threadsafe`` so it can be called from outside the
        event loop's thread. Returns the resulting ``concurrent.futures.Future``; a failed
        send is logged instead of being silently dropped.
        """
        future = asyncio.run_coroutine_threadsafe(
            self.tcp_manager.send_message_to_device(device_id=self.device.id,
                                                    message=message,
                                                    have_response=have_response),
            self.main_loop)
        future.add_done_callback(self._log_tcp_send_failure)
        return future

    def _log_tcp_send_failure(self, future):
        """Done-callback for send_tcp_message: log the exception of a failed send."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.error("send tcp message to device %s failed: %s", self.device.id, exc)

    def send_alarm_message(self, type):
        """Send the TCP sound message for alarm ``type`` (a key of ALARM_DICT).

        Skipped when there is no tcp_manager, or when a message of the same type was sent
        no more than ``socket_interval`` seconds ago.
        """
        if not self.tcp_manager:
            return
        last_sent = self.socket_interval_dict.get(type)
        if last_sent is not None and (datetime.now() - last_sent).total_seconds() <= int(self.socket_interval):
            return

        alarm = ALARM_DICT[type]
        logger.debug("send alarm message %s %s", alarm['alarmContent'], alarm['alarmSoundMessage'])
        self.send_tcp_message(alarm['alarmSoundMessage'], have_response=True)
        self.socket_interval_dict[type] = datetime.now()

    def send_alarm_record(self, type, frame_alarm):
        """POST an alarm record for ``type`` to the configured push URL, throttled per type.

        :param type: key of ALARM_DICT.
        :param frame_alarm: per-frame alarm state as built by handle_alarm_info,
            ``{type: {'count': int, 'annotator': Annotator}}``.

        Nothing is sent when ``alarm_interval`` is negative, when no push URL is
        configured, or when a record of the same type was sent no more than
        ``alarm_interval`` seconds ago.
        """
        # The interval is also passed through int() for the throttle comparison, so it may
        # arrive as a string; converting here as well avoids a TypeError on `< 0`.
        alarm_interval = int(self.alarm_interval)
        if alarm_interval < 0:
            return

        global_config = GlobalConfig()
        push_config = global_config.get_alarm_push_config()
        if not push_config or not push_config.push_url:
            return

        last_sent = self.alarm_interval_dict.get(type)
        if last_sent is not None and (datetime.now() - last_sent).total_seconds() <= alarm_interval:
            return

        logger.debug("send alarm record")
        alarm = ALARM_DICT[type]
        alarm_image = deepcopy(frame_alarm[type]['annotator'].result())

        data = {
            "device_id": self.device.id,
            "alarm_type": alarm['alarmType'],
            "alarm_content": alarm['alarmContent'],
            "alarm_value": frame_alarm[type]['count'],
            "alarm_image": image_to_base64(alarm_image),
            "alarm_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        }

        url = push_config.push_url
        send_request(url, data)

        # Log the payload without the (large) base64 image.
        data_copy = {key: value for key, value in data.items() if key != "alarm_image"}
        logger.debug(f"send to {url}: {data_copy}")

        self.alarm_interval_dict[type] = datetime.now()

    def model_predict(self, frame):
        """Run the detector on ``frame`` and return ``(boxes, class_ids, class_names)``.

        Only the classes listed in ``self.model_classes`` are requested, with a 0.5
        confidence threshold; the three lists are index-aligned and only the first
        result of the prediction is used.
        """
        results_gen = self.model(frame, save_txt=False, save=False, verbose=False, conf=0.5,
                                 classes=list(self.model_classes.keys()),
                                 imgsz=640,
                                 stream=True)
        result = list(results_gen)[0]  # materialise the streaming generator
        result_boxes = list(result.boxes)
        pred_ids = [int(box.cls) for box in result_boxes]
        pred_names = [self.model_classes[int(box.cls)] for box in result_boxes]
        return result_boxes, pred_ids, pred_names

    def process_alarm(self, frame, result_boxes, pred_ids, pred_names, frame_alarm):
        """Raise helmet-based alarms for one frame.

        Each detected person whose matched head (see get_person_head) overlaps no helmet
        box is recorded as a 'break' alarm in ``frame_alarm``; a person with a helmet, or
        whose head could not be matched, counts as a supervisor. When no supervisor is
        found -- including when no person is detected at all -- a 'no_jiandu' alarm is
        recorded. A TCP sound message is sent for each alarm type raised.

        ``pred_ids`` and ``pred_names`` are accepted for signature compatibility but unused.
        """
        # Class ids as listed in self.model_classes: 3 = person, 18 = helmet, 15 = head.
        persons = [box for box in result_boxes if int(box.cls) == 3]
        helmet_bboxes = [box.xyxy.cpu().squeeze() for box in result_boxes if int(box.cls) == 18]
        heads = [box for box in result_boxes if int(box.cls) == 15]

        has_jianduyuan = False
        has_others = False

        for person in persons:
            person_bbox = person.xyxy.cpu().squeeze()

            # A person whose head cannot be matched is treated as wearing a helmet.
            has_helmet = True
            person_head = get_person_head(person_bbox, heads)
            if person_head is not None:
                head_bbox = person_head.xyxy.cpu().squeeze()
                has_helmet = any(is_overlapping(head_bbox, helmet_bbox) for helmet_bbox in helmet_bboxes)

            if has_helmet:
                has_jianduyuan = True
            else:
                has_others = True
                handle_alarm_info('break', frame, frame_alarm, person_bbox, float(person.conf))

        if not has_jianduyuan:
            handle_alarm_info('no_jiandu', frame, frame_alarm, None)
            self.send_alarm_message('no_jiandu')
        if has_others:
            self.send_alarm_message('break')

    def process_labor(self, frame, result_boxes, pred_ids, pred_names):
        """Placeholder for labor-related checks; currently does nothing."""
        pass

    def run(self):
        """Open the stream, then process frames until stop_task() is called.

        Exceptions raised while handling a frame are printed and logged, and the loop
        continues with the next frame.
        """
        while not self.stream_loader.init:
            if self.__stop_event.is_set():
                return  # stop requested before the stream was opened: don't start consuming it
            self.stream_loader.init_cap()

        for frames in self.stream_loader:
            try:
                if self.__stop_event.is_set():
                    break
                if frames is None:
                    continue

                self.device_status_manager.set_status(device_id=self.device.id)

                # Detection is currently disabled, so frame_alarm always stays empty.
                # result_boxes, pred_ids, pred_names = self.model_predict(frames)
                frame_alarm = {}
                # self.process_alarm(frames, result_boxes, pred_ids, pred_names, frame_alarm)
                # self.process_labor(frames, result_boxes, pred_ids, pred_names)

                for key, alarm_info in frame_alarm.items():
                    if alarm_info['count'] > 0:
                        self.thread_pool.submit_task(self.send_alarm_record, key, frame_alarm)

            except Exception as ex:
                traceback.print_exc()
                logger.error(ex)