# safe-algo-pro / algo / stream_loader.py
from datetime import datetime, timedelta

import cv2
import time
import numpy as np
from threading import Thread, Event

import queue

from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool


class OpenCVStreamLoad:
    """Background reader that pulls frames from a video stream into a bounded queue.

    On construction the loader connects to ``camera_url`` (retrying until it
    succeeds) and submits :meth:`update` to the global thread pool. That worker
    keeps every ``vid_stride``-th frame, stores the latest one in ``self.frame``
    and pushes it into ``frame_queue``. It reconnects whenever a read fails.
    Iterating the loader yields batches of ``batch_size`` frames, or ``[]`` when
    not enough frames are buffered yet (the iterator never ends by itself).
    """

    # Seconds between two "FPS (read)" log lines.
    FPS_LOG_INTERVAL = 10
    # Emit a "cap success" debug heartbeat once per this many kept frames.
    HEARTBEAT_FRAMES = 250

    def __init__(self, camera_url, camera_code, device_thread_id='',
                 batch_size=4,
                 queue_size=100,
                 retry_interval=1,
                 vid_stride=1):
        """Create the loader, connect, and start the background reader.

        Args:
            camera_url: Stream URL passed to ``cv2.VideoCapture``; must be non-empty.
            camera_code: Device identifier, used in FPS log lines.
            device_thread_id: Identifier of the owning device thread, used in logs.
            batch_size: Number of frames returned per ``next()`` call.
            queue_size: Capacity of ``frame_queue``. New frames are dropped while it is full.
            retry_interval: Seconds to wait between reconnection attempts.
            vid_stride: Keep one frame out of every ``vid_stride`` frames (>= 1).

        Raises:
            ValueError: If ``vid_stride`` is smaller than 1.
        """
        assert camera_url is not None and camera_url != ''
        if vid_stride < 1:
            # A stride of 0 would raise ZeroDivisionError inside update() and
            # send the reader into an endless reconnect loop.
            raise ValueError(f"vid_stride must be >= 1, got {vid_stride}")
        self.url = camera_url
        self.camera_code = camera_code
        self.retry_interval = retry_interval
        self.vid_stride = vid_stride
        self.device_thread_id = device_thread_id

        self.init = False
        self.frame = None
        self.frame_queue = queue.Queue(maxsize=queue_size)
        self.fps_ts = None  # time.monotonic() at the start of the current FPS window
        self.cap = None
        self.batch_size = batch_size
        self.frames_read = 0

        # Set by stop(); the reader loop and the reconnect loop both watch it.
        self.__stop_event = Event()
        self.thread_pool = GlobalThreadPool()
        self.init_cap()

    def init_cap(self):
        """Connect once, record stream properties and start :meth:`update`.

        Does nothing if already initialised. If no capture is obtained
        (``get_connect`` returned ``None`` because a stop was requested), the
        reader is not started.
        """
        if self.init:
            return
        self.cap = self.get_connect()
        if self.cap is None:
            return
        self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.frame_fps = int(self.cap.get(cv2.CAP_PROP_FPS))
        _, self.frame = self.cap.read()
        self.thread_pool.submit_task(self.update)
        self.init = True

    def create_capture(self):
        """Build a capture through a GStreamer pipeline (rtspsrc -> nvv4l2decoder).

        Alternative to the default backend used by :meth:`get_connect`. It is not
        called anywhere in the code shown here.

        Returns:
            The ``cv2.VideoCapture`` object, or ``None`` if construction raised.
        """
        try:
            gst_pipeline = (
                f"rtspsrc location={self.url} ! "
                f"rtph264depay ! h264parse ! "
                f"nvv4l2decoder ! nvvidconv ! video/x-raw, format=(string)BGRx ! videoconvert ! "
                f"appsink"
            )
            return cv2.VideoCapture(gst_pipeline, cv2.CAP_GSTREAMER)
        except Exception as e:
            logger.error(e)
            return None

    def get_connect(self):
        """Open ``self.url``, retrying every ``retry_interval`` seconds until it works.

        Returns:
            An opened ``cv2.VideoCapture``, or ``None`` once :meth:`stop` has been called.
        """
        attempt = 0
        while not self.__stop_event.is_set():
            cap = None
            try:
                logger.info(f"{self.url} attempting to connect... (attempt {attempt + 1})")
                # Default OpenCV backend; see create_capture() for a GStreamer variant.
                cap = cv2.VideoCapture(self.url)
                if cap.isOpened():
                    if not self.__stop_event.is_set():
                        logger.info(f"{self.url} connected successfully!")
                        return cap
                    # stop() ran while we were connecting. Don't hand back a
                    # capture that nobody would ever release.
                    cap.release()
                    break
                logger.warning(f"Failed to open stream on attempt {attempt + 1}.")
            except Exception as e:
                logger.error(f"Error creating VideoCapture for {self.url}: {e}")
            if cap is not None:
                cap.release()  # always free a capture that failed to open
            # Event.wait (not time.sleep) so stop() interrupts the back-off at once.
            self.__stop_event.wait(self.retry_interval)
            attempt += 1

        logger.info(f"{self.url} stopping connection attempts...thread {self.device_thread_id} stopped")
        return None

    def log_fps(self):
        """Log the average read FPS once every ``FPS_LOG_INTERVAL`` seconds.

        The first call only starts the measurement window. Later calls divide the
        frames counted in ``frames_read`` by the real elapsed time, then reset it.
        """
        now = time.monotonic()
        if self.fps_ts is None:
            self.fps_ts = now
            self.frames_read = 0
            return
        elapsed = now - self.fps_ts
        if elapsed >= self.FPS_LOG_INTERVAL:
            fps = self.frames_read / elapsed
            self.frames_read = 0
            logger.info(f"FPS (read) for device {self.camera_code}: {fps:.2f}")
            self.fps_ts = now

    def update(self):
        """Reader loop: pull frames until :meth:`stop` is called, reconnecting on failure.

        Each iteration advances the stream by one frame. Kept frames
        (every ``vid_stride``-th) are decoded with ``read()``. The others are
        only ``grab()``-ed so the stream position still advances.
        """
        vid_n = 0   # frames consumed from the stream so far
        log_n = 0   # kept frames since the last heartbeat log
        while not self.__stop_event.is_set():
            if self.cap is None:
                # Reconnect rather than spin on a missing capture.
                self.cap = self.get_connect()
                continue
            keep = vid_n % self.vid_stride == 0
            try:
                if keep:
                    ret, frame = self.cap.read()
                else:
                    ret, frame = self.cap.grab(), None
                if not ret:
                    logger.info(f"{self.url} disconnect, try to reconnect...")
                    self.cap.release()
                    # Clear before the (possibly long) reconnect so consumers
                    # don't keep seeing a stale frame during the outage.
                    self.frame = None
                    self.cap = self.get_connect()
                    continue
                vid_n += 1
                if keep:
                    self.frame = frame
                    self.frames_read += 1
                    try:
                        self.frame_queue.put_nowait(frame)
                    except queue.Full:
                        pass  # consumer is behind: drop this frame
                    log_n = (log_n + 1) % self.HEARTBEAT_FRAMES
                    if log_n == 0:
                        logger.debug(f'{self.url} cap success')
            except Exception as e:
                logger.error(f"{self.url} update fail", exc_info=e)
                if self.cap is not None:
                    self.cap.release()
                self.frame = None
                self.cap = self.get_connect()
            self.log_fps()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch of ``batch_size`` queued frames, or ``[]`` if too few are buffered."""
        if self.frame_queue.qsize() < self.batch_size:
            return []
        batch_frames = []
        while len(batch_frames) < self.batch_size:
            try:
                batch_frames.append(self.frame_queue.get_nowait())
            except queue.Empty:
                break
        return batch_frames

    def stop(self):
        """Signal the reader and reconnect loops to stop, then release the capture."""
        logger.info(f'stop stream loader {self.url},device thread {self.device_thread_id} stopped')
        self.__stop_event.set()
        cap = self.cap
        if cap is not None:
            try:
                cap.release()
                logger.info(f"{self.url} successfully released video capture.")
            except Exception as e:
                logger.error(f"Failed to release video capture: {e}")