Newer
Older
safe-algo-pro / common / global_thread_pool.py
import asyncio
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
import threading

from common.global_logger import logger


def generate_thread_id():
    """生成唯一的线程 ID"""
    return str(uuid.uuid4())


def wrapper(func, *args, **kwargs):
    return func(*args, **kwargs)


class GlobalThreadPool:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, max_workers=5):
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super(GlobalThreadPool, cls).__new__(cls)
                    cls._instance._initialize(max_workers)
        return cls._instance

    def _initialize(self, max_workers):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.loop = asyncio.get_running_loop()  # 获取当前的事件循环
        self.task_map = {}

    # def __new__(cls, *args, **kwargs):
    #     with cls._lock:
    #         if cls._instance is None:
    #             # 第一次创建实例时调用父类的 __new__ 来创建实例
    #             cls._instance = super(GlobalThreadPool, cls).__new__(cls)
    #             # 在此进行一次性的初始化,比如线程池的创建
    #             max_workers = kwargs.get('max_workers', 10)
    #             cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
    #             cls._instance.task_map = {}  # 初始化任务映射
    #         return cls._instance

    def submit_task(self, fn, *args, thread_id=None, **kwargs):
        """提交任务到线程池,并记录线程 ID"""
        if thread_id is None:
            thread_id = generate_thread_id()
        if self.check_task_is_running(thread_id):
            raise ValueError(f"线程 ID {thread_id} 已存在")
        # future = self.executor.submit(fn, *args, **kwargs)
        future = self.loop.run_in_executor(None, wrapper, fn, *args, **kwargs)

        self.task_map[thread_id] = future  # 记录线程 ID 和 Future 对象的映射
        future.add_done_callback(lambda f: self._handle_exception(f, thread_id))
        return thread_id

    def check_task_is_running(self, thread_id):
        future = self.task_map.get(thread_id)
        if future:
            if future.running():
                return True
            else:
                del self.task_map[thread_id]
                return False
        else:
            return False

    def check_task_stopped(self, thread_id):
        """判断任务是否已停止"""
        future = self.task_map.get(thread_id)
        if future:
            if future.done():
                try:
                    # 确保任务是正常完成的,而不是因为异常停止
                    future.result()  # 如果任务抛出异常,这里会捕获
                    logger.info(f"Task {thread_id} has stopped successfully.")
                except Exception as e:
                    logger.error(f"Task {thread_id} encountered an error: {e}")
                return True  # 无论成功还是失败,任务已停止
            else:
                return False  # 任务仍在运行
        else:
            logger.warning(f"No task found with thread ID {thread_id}.")
            return True  # 如果找不到该任务,认为它已经停止(或者不存在)

    def get_task_future(self, thread_id):
        """获取指定线程 ID 的 future 对象"""
        future = self.task_map.get(thread_id)
        if future:
            return future
        else:
            logger.warning(f"No task found with thread ID {thread_id}.")
            return None

    def stop_task(self, thread_id):
        """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
        future = self.task_map.get(thread_id)
        if future:
            future.cancel()  # 尝试取消任务
            logger.info(f"任务 {thread_id} 已取消")
            del self.task_map[thread_id]  # 从任务映射中删除
        else:
            logger.info(f"未找到线程 ID {thread_id}")

    def shutdown(self, wait=True):
        """关闭线程池"""
        self.executor.shutdown(wait=wait)
        GlobalThreadPool._instance = None

    def _handle_exception(self, future, thread_id):
        """
        处理任务完成时的异常
        :param future: 完成的任务 future 对象
        :param camera_id: 对应的摄像头 ID
        """
        try:
            # 获取任务结果,如果任务有异常,这里会抛出
            result = future.result()
        except Exception as e:
            logger.error(f"Task for thread {thread_id} raised an exception: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")