diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/app_instance.py b/app_instance.py new file mode 100644 index 0000000..b72926d --- /dev/null +++ b/app_instance.py @@ -0,0 +1,77 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException + +from algo.algo_runner import AlgoRunner +from algo.scene_runner import SceneRunner + +from common.biz_exception import BizExceptionHandlers +from common.global_logger import logger +from db.database import get_db +from services.device_model_relation_service import DeviceModelRelationService +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.model_service import ModelService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + +_app = None # 创建一个私有变量来存储 app 实例 + + +def create_app() -> FastAPI: + global _app + if _app is None: + _app = FastAPI() + + @asynccontextmanager + async def lifespan(app: FastAPI): + main_loop = asyncio.get_running_loop() + + with next(get_db()) as db: + device_service = DeviceService(db) + model_service = ModelService(db) + model_relation_service = DeviceModelRelationService(db) + scene_service = SceneService(db) + scene_relation_service = DeviceSceneRelationService(db) + + app.state.device_service = device_service + app.state.model_service = model_service + app.state.model_relation_service = model_relation_service + app.state.scene_service = scene_service + app.state.scene_relation_service = scene_relation_service + + tcp_manager = TcpManager(device_service=device_service) + app.state.tcp_manager = tcp_manager + await tcp_manager.start() + + algo_runner = AlgoRunner( + device_service=device_service, + model_service=model_service, + relation_service=model_relation_service, + ) + app.state.algo_runner = algo_runner + await algo_runner.start() + + scene_runner = SceneRunner( + device_service=device_service, + scene_service=scene_service, + relation_service=scene_relation_service, + tcp_manager=tcp_manager, + main_loop=main_loop + ) + app.state.scene_runner = scene_runner + await scene_runner.start() + + yield # 允许请求处理 + + logger.info("Shutting down application...") + + _app.router.lifespan_context = lifespan + _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) + + return _app + + +def get_app() -> FastAPI: + return _app or create_app() diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/app_instance.py b/app_instance.py new file mode 100644 index 0000000..b72926d --- /dev/null +++ b/app_instance.py @@ -0,0 +1,77 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException + +from algo.algo_runner import AlgoRunner +from algo.scene_runner import SceneRunner + +from common.biz_exception import BizExceptionHandlers +from common.global_logger import logger +from db.database import get_db +from services.device_model_relation_service import DeviceModelRelationService +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.model_service import ModelService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + +_app = None # 创建一个私有变量来存储 app 实例 + + +def create_app() -> FastAPI: + global _app + if _app is None: + _app = FastAPI() + + @asynccontextmanager + async def lifespan(app: FastAPI): + main_loop = asyncio.get_running_loop() + + with next(get_db()) as db: + device_service = DeviceService(db) + model_service = ModelService(db) + model_relation_service = DeviceModelRelationService(db) + scene_service = SceneService(db) + scene_relation_service = DeviceSceneRelationService(db) + + app.state.device_service = device_service + app.state.model_service = model_service + app.state.model_relation_service = model_relation_service + app.state.scene_service = scene_service + app.state.scene_relation_service = scene_relation_service + + tcp_manager = TcpManager(device_service=device_service) + app.state.tcp_manager = tcp_manager + await tcp_manager.start() + + algo_runner = AlgoRunner( + device_service=device_service, + model_service=model_service, + relation_service=model_relation_service, + ) + app.state.algo_runner = algo_runner + await algo_runner.start() + + scene_runner = SceneRunner( + device_service=device_service, + scene_service=scene_service, + relation_service=scene_relation_service, + tcp_manager=tcp_manager, + main_loop=main_loop + ) + app.state.scene_runner = scene_runner + await scene_runner.start() + + yield # 允许请求处理 + + logger.info("Shutting down application...") + + _app.router.lifespan_context = lifespan + _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) + + return _app + + +def get_app() -> FastAPI: + return _app or create_app() diff --git a/common/string_utils.py b/common/string_utils.py index 42d987b..7ff7f03 100644 --- a/common/string_utils.py +++ b/common/string_utils.py @@ -1,5 +1,12 @@ +import importlib import re +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + def camel_to_snake(name): # 将大写字母前加上下划线,并将整个字符串转换为小写 diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/app_instance.py b/app_instance.py new file mode 100644 index 0000000..b72926d --- /dev/null +++ b/app_instance.py @@ -0,0 +1,77 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException + +from algo.algo_runner import AlgoRunner +from algo.scene_runner import SceneRunner + +from common.biz_exception import BizExceptionHandlers +from common.global_logger import logger +from db.database import get_db +from services.device_model_relation_service import DeviceModelRelationService +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.model_service import ModelService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + +_app = None # 创建一个私有变量来存储 app 实例 + + +def create_app() -> FastAPI: + global _app + if _app is None: + _app = FastAPI() + + @asynccontextmanager + async def lifespan(app: FastAPI): + main_loop = asyncio.get_running_loop() + + with next(get_db()) as db: + device_service = DeviceService(db) + model_service = ModelService(db) + model_relation_service = DeviceModelRelationService(db) + scene_service = SceneService(db) + scene_relation_service = DeviceSceneRelationService(db) + + app.state.device_service = device_service + app.state.model_service = model_service + app.state.model_relation_service = model_relation_service + app.state.scene_service = scene_service + app.state.scene_relation_service = scene_relation_service + + tcp_manager = TcpManager(device_service=device_service) + app.state.tcp_manager = tcp_manager + await tcp_manager.start() + + algo_runner = AlgoRunner( + device_service=device_service, + model_service=model_service, + relation_service=model_relation_service, + ) + app.state.algo_runner = algo_runner + await algo_runner.start() + + scene_runner = SceneRunner( + device_service=device_service, + scene_service=scene_service, + relation_service=scene_relation_service, + tcp_manager=tcp_manager, + main_loop=main_loop + ) + app.state.scene_runner = scene_runner + await scene_runner.start() + + yield # 允许请求处理 + + logger.info("Shutting down application...") + + _app.router.lifespan_context = lifespan + _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) + + return _app + + +def get_app() -> FastAPI: + return _app or create_app() diff --git a/common/string_utils.py b/common/string_utils.py index 42d987b..7ff7f03 100644 --- a/common/string_utils.py +++ b/common/string_utils.py @@ -1,5 +1,12 @@ +import importlib import re +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + def camel_to_snake(name): # 将大写字母前加上下划线,并将整个字符串转换为小写 diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index b1d8fab..d1acdd4 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/app_instance.py b/app_instance.py new file mode 100644 index 0000000..b72926d --- /dev/null +++ b/app_instance.py @@ -0,0 +1,77 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException + +from algo.algo_runner import AlgoRunner +from algo.scene_runner import SceneRunner + +from common.biz_exception import BizExceptionHandlers +from common.global_logger import logger +from db.database import get_db +from services.device_model_relation_service import DeviceModelRelationService +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.model_service import ModelService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + +_app = None # 创建一个私有变量来存储 app 实例 + + +def create_app() -> FastAPI: + global _app + if _app is None: + _app = FastAPI() + + @asynccontextmanager + async def lifespan(app: FastAPI): + main_loop = asyncio.get_running_loop() + + with next(get_db()) as db: + device_service = DeviceService(db) + model_service = ModelService(db) + model_relation_service = DeviceModelRelationService(db) + scene_service = SceneService(db) + scene_relation_service = DeviceSceneRelationService(db) + + app.state.device_service = device_service + app.state.model_service = model_service + app.state.model_relation_service = model_relation_service + app.state.scene_service = scene_service + app.state.scene_relation_service = scene_relation_service + + tcp_manager = TcpManager(device_service=device_service) + app.state.tcp_manager = tcp_manager + await tcp_manager.start() + + algo_runner = AlgoRunner( + device_service=device_service, + model_service=model_service, + relation_service=model_relation_service, + ) + app.state.algo_runner = algo_runner + await algo_runner.start() + + scene_runner = SceneRunner( + device_service=device_service, + scene_service=scene_service, + relation_service=scene_relation_service, + tcp_manager=tcp_manager, + main_loop=main_loop + ) + app.state.scene_runner = scene_runner + await scene_runner.start() + + yield # 允许请求处理 + + logger.info("Shutting down application...") + + _app.router.lifespan_context = lifespan + _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) + + return _app + + +def get_app() -> FastAPI: + return _app or create_app() diff --git a/common/string_utils.py b/common/string_utils.py index 42d987b..7ff7f03 100644 --- a/common/string_utils.py +++ b/common/string_utils.py @@ -1,5 +1,12 @@ +import importlib import re +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + def camel_to_snake(name): # 将大写字母前加上下划线,并将整个字符串转换为小写 diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index b1d8fab..d1acdd4 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/main.py b/main.py index 61cbe9e..8e5e464 100644 --- a/main.py +++ b/main.py @@ -1,58 +1,21 @@ -import asyncio -from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException - -from algo.algo_runner_manager import get_algo_runner -from apis.router import router import uvicorn import logging -from common.biz_exception import BizExceptionHandlers +from app_instance import get_app from common.global_logger import logger +app = get_app() -app = FastAPI() - -# # 初始化 AlgoRunner -# algo_runner = AlgoRunner() -# -# -# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner -# @app.on_event("startup") -# async def startup_event(): -# algo_runner.start() - -algo_runner = get_algo_runner() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - # 应用启动时的初始化 - await algo_runner.start() - # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start())) - # 允许请求处理 - yield - # 应用关闭时的清理逻辑 - logger.info("Shutting down application...") - - -# 包含所有模块的路由 +# 延迟导入 router 并注册路由 +from apis.router import router app.include_router(router, prefix="/api") -app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) -app.router.lifespan_context = lifespan - if __name__ == "__main__": - # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers uvicorn_logger.setLevel(logging.DEBUG) - # 重定向 uvicorn 的 access 日志 - # access_logger = logging.getLogger("uvicorn.access") - # access_logger.handlers = logger.handlers - # access_logger.setLevel(logging.DEBUG) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/app_instance.py b/app_instance.py new file mode 100644 index 0000000..b72926d --- /dev/null +++ b/app_instance.py @@ -0,0 +1,77 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException + +from algo.algo_runner import AlgoRunner +from algo.scene_runner import SceneRunner + +from common.biz_exception import BizExceptionHandlers +from common.global_logger import logger +from db.database import get_db +from services.device_model_relation_service import DeviceModelRelationService +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.model_service import ModelService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + +_app = None # 创建一个私有变量来存储 app 实例 + + +def create_app() -> FastAPI: + global _app + if _app is None: + _app = FastAPI() + + @asynccontextmanager + async def lifespan(app: FastAPI): + main_loop = asyncio.get_running_loop() + + with next(get_db()) as db: + device_service = DeviceService(db) + model_service = ModelService(db) + model_relation_service = DeviceModelRelationService(db) + scene_service = SceneService(db) + scene_relation_service = DeviceSceneRelationService(db) + + app.state.device_service = device_service + app.state.model_service = model_service + app.state.model_relation_service = model_relation_service + app.state.scene_service = scene_service + app.state.scene_relation_service = scene_relation_service + + tcp_manager = TcpManager(device_service=device_service) + app.state.tcp_manager = tcp_manager + await tcp_manager.start() + + algo_runner = AlgoRunner( + device_service=device_service, + model_service=model_service, + relation_service=model_relation_service, + ) + app.state.algo_runner = algo_runner + await algo_runner.start() + + scene_runner = SceneRunner( + device_service=device_service, + scene_service=scene_service, + relation_service=scene_relation_service, + tcp_manager=tcp_manager, + main_loop=main_loop + ) + app.state.scene_runner = scene_runner + await scene_runner.start() + + yield # 允许请求处理 + + logger.info("Shutting down application...") + + _app.router.lifespan_context = lifespan + _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) + + return _app + + +def get_app() -> FastAPI: + return _app or create_app() diff --git a/common/string_utils.py b/common/string_utils.py index 42d987b..7ff7f03 100644 --- a/common/string_utils.py +++ b/common/string_utils.py @@ -1,5 +1,12 @@ +import importlib import re +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + def camel_to_snake(name): # 将大写字母前加上下划线,并将整个字符串转换为小写 diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index b1d8fab..d1acdd4 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/main.py b/main.py index 61cbe9e..8e5e464 100644 --- a/main.py +++ b/main.py @@ -1,58 +1,21 @@ -import asyncio -from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException - -from algo.algo_runner_manager import get_algo_runner -from apis.router import router import uvicorn import logging -from common.biz_exception import BizExceptionHandlers +from app_instance import get_app from common.global_logger import logger +app = get_app() -app = FastAPI() - -# # 初始化 AlgoRunner -# algo_runner = AlgoRunner() -# -# -# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner -# @app.on_event("startup") -# async def startup_event(): -# algo_runner.start() - -algo_runner = get_algo_runner() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - # 应用启动时的初始化 - await algo_runner.start() - # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start())) - # 允许请求处理 - yield - # 应用关闭时的清理逻辑 - logger.info("Shutting down application...") - - -# 包含所有模块的路由 +# 延迟导入 router 并注册路由 +from apis.router import router app.include_router(router, prefix="/api") -app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) -app.router.lifespan_context = lifespan - if __name__ == "__main__": - # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers uvicorn_logger.setLevel(logging.DEBUG) - # 重定向 uvicorn 的 access 日志 - # access_logger = logging.getLogger("uvicorn.access") - # access_logger.handlers = logger.handlers - # access_logger.setLevel(logging.DEBUG) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/requirements.txt b/requirements.txt index 9a7d88c..af16eda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,12 @@ sqlmodel openpyxl python-multipart -docker \ No newline at end of file +docker +numpy +ultralytics +opencv-python +pydantic +pandas +starlette +uvicorn +sqlalchemy \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/app_instance.py b/app_instance.py new file mode 100644 index 0000000..b72926d --- /dev/null +++ b/app_instance.py @@ -0,0 +1,77 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException + +from algo.algo_runner import AlgoRunner +from algo.scene_runner import SceneRunner + +from common.biz_exception import BizExceptionHandlers +from common.global_logger import logger +from db.database import get_db +from services.device_model_relation_service import DeviceModelRelationService +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.model_service import ModelService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + +_app = None # 创建一个私有变量来存储 app 实例 + + +def create_app() -> FastAPI: + global _app + if _app is None: + _app = FastAPI() + + @asynccontextmanager + async def lifespan(app: FastAPI): + main_loop = asyncio.get_running_loop() + + with next(get_db()) as db: + device_service = DeviceService(db) + model_service = ModelService(db) + model_relation_service = DeviceModelRelationService(db) + scene_service = SceneService(db) + scene_relation_service = DeviceSceneRelationService(db) + + app.state.device_service = device_service + app.state.model_service = model_service + app.state.model_relation_service = model_relation_service + app.state.scene_service = scene_service + app.state.scene_relation_service = scene_relation_service + + tcp_manager = TcpManager(device_service=device_service) + app.state.tcp_manager = tcp_manager + await tcp_manager.start() + + algo_runner = AlgoRunner( + device_service=device_service, + model_service=model_service, + relation_service=model_relation_service, + ) + app.state.algo_runner = algo_runner + await algo_runner.start() + + scene_runner = SceneRunner( + device_service=device_service, + scene_service=scene_service, + relation_service=scene_relation_service, + tcp_manager=tcp_manager, + main_loop=main_loop + ) + app.state.scene_runner = scene_runner + await scene_runner.start() + + yield # 允许请求处理 + + logger.info("Shutting down application...") + + _app.router.lifespan_context = lifespan + _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) + + return _app + + +def get_app() -> FastAPI: + return _app or create_app() diff --git a/common/string_utils.py b/common/string_utils.py index 42d987b..7ff7f03 100644 --- a/common/string_utils.py +++ b/common/string_utils.py @@ -1,5 +1,12 @@ +import importlib import re +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + def camel_to_snake(name): # 将大写字母前加上下划线,并将整个字符串转换为小写 diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index b1d8fab..d1acdd4 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/main.py b/main.py index 61cbe9e..8e5e464 100644 --- a/main.py +++ b/main.py @@ -1,58 +1,21 @@ -import asyncio -from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException - -from algo.algo_runner_manager import get_algo_runner -from apis.router import router import uvicorn import logging -from common.biz_exception import BizExceptionHandlers +from app_instance import get_app from common.global_logger import logger +app = get_app() -app = FastAPI() - -# # 初始化 AlgoRunner -# algo_runner = AlgoRunner() -# -# -# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner -# @app.on_event("startup") -# async def startup_event(): -# algo_runner.start() - -algo_runner = get_algo_runner() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - # 应用启动时的初始化 - await algo_runner.start() - # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start())) - # 允许请求处理 - yield - # 应用关闭时的清理逻辑 - logger.info("Shutting down application...") - - -# 包含所有模块的路由 +# 延迟导入 router 并注册路由 +from apis.router import router app.include_router(router, prefix="/api") -app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) -app.router.lifespan_context = lifespan - if __name__ == "__main__": - # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers uvicorn_logger.setLevel(logging.DEBUG) - # 重定向 uvicorn 的 access 日志 - # access_logger = logging.getLogger("uvicorn.access") - # access_logger.handlers = logger.handlers - # access_logger.setLevel(logging.DEBUG) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/requirements.txt b/requirements.txt index 9a7d88c..af16eda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,12 @@ sqlmodel openpyxl python-multipart -docker \ No newline at end of file +docker +numpy +ultralytics +opencv-python +pydantic +pandas +starlette +uvicorn +sqlalchemy \ No newline at end of file diff --git a/scene_handler/base_scene_handler.py b/scene_handler/base_scene_handler.py new file mode 100644 index 0000000..5451a16 --- /dev/null +++ b/scene_handler/base_scene_handler.py @@ -0,0 +1,29 @@ +from asyncio import Event + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from entity.device import Device + +from ultralytics import YOLO + +from tcp.tcp_manager import TcpManager + + +class BaseSceneHandler: + + def __init__(self, device: Device, + thread_id: str, + tcp_manager: TcpManager, + main_loop ): + self.device = device + self.thread_id = thread_id + self.tcp_manager =tcp_manager + self.main_loop = main_loop + + def stop_task(self): + pass + + def run(self): + pass \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/app_instance.py b/app_instance.py new file mode 100644 index 0000000..b72926d --- /dev/null +++ b/app_instance.py @@ -0,0 +1,77 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException + +from algo.algo_runner import AlgoRunner +from algo.scene_runner import SceneRunner + +from common.biz_exception import BizExceptionHandlers +from common.global_logger import logger +from db.database import get_db +from services.device_model_relation_service import DeviceModelRelationService +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.model_service import ModelService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + +_app = None # 创建一个私有变量来存储 app 实例 + + +def create_app() -> FastAPI: + global _app + if _app is None: + _app = FastAPI() + + @asynccontextmanager + async def lifespan(app: FastAPI): + main_loop = asyncio.get_running_loop() + + with next(get_db()) as db: + device_service = DeviceService(db) + model_service = ModelService(db) + model_relation_service = DeviceModelRelationService(db) + scene_service = SceneService(db) + scene_relation_service = DeviceSceneRelationService(db) + + app.state.device_service = device_service + app.state.model_service = model_service + app.state.model_relation_service = model_relation_service + app.state.scene_service = scene_service + app.state.scene_relation_service = scene_relation_service + + tcp_manager = TcpManager(device_service=device_service) + app.state.tcp_manager = tcp_manager + await tcp_manager.start() + + algo_runner = AlgoRunner( + device_service=device_service, + model_service=model_service, + relation_service=model_relation_service, + ) + app.state.algo_runner = algo_runner + await algo_runner.start() + + scene_runner = SceneRunner( + device_service=device_service, + scene_service=scene_service, + relation_service=scene_relation_service, + tcp_manager=tcp_manager, + main_loop=main_loop + ) + app.state.scene_runner = scene_runner + await scene_runner.start() + + yield # 允许请求处理 + + logger.info("Shutting down application...") + + _app.router.lifespan_context = lifespan + _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) + + return _app + + +def get_app() -> FastAPI: + return _app or create_app() diff --git a/common/string_utils.py b/common/string_utils.py index 42d987b..7ff7f03 100644 --- a/common/string_utils.py +++ b/common/string_utils.py @@ -1,5 +1,12 @@ +import importlib import re +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + def camel_to_snake(name): # 将大写字母前加上下划线,并将整个字符串转换为小写 diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index b1d8fab..d1acdd4 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/main.py b/main.py index 61cbe9e..8e5e464 100644 --- a/main.py +++ b/main.py @@ -1,58 +1,21 @@ -import asyncio -from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException - -from algo.algo_runner_manager import get_algo_runner -from apis.router import router import uvicorn import logging -from common.biz_exception import BizExceptionHandlers +from app_instance import get_app from common.global_logger import logger +app = get_app() -app = FastAPI() - -# # 初始化 AlgoRunner -# algo_runner = AlgoRunner() -# -# -# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner -# @app.on_event("startup") -# async def startup_event(): -# algo_runner.start() - -algo_runner = get_algo_runner() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - # 应用启动时的初始化 - await algo_runner.start() - # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start())) - # 允许请求处理 - yield - # 应用关闭时的清理逻辑 - logger.info("Shutting down application...") - - -# 包含所有模块的路由 +# 延迟导入 router 并注册路由 +from apis.router import router app.include_router(router, prefix="/api") -app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) -app.router.lifespan_context = lifespan - if __name__ == "__main__": - # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers uvicorn_logger.setLevel(logging.DEBUG) - # 重定向 uvicorn 的 access 日志 - # access_logger = logging.getLogger("uvicorn.access") - # access_logger.handlers = logger.handlers - # access_logger.setLevel(logging.DEBUG) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/requirements.txt b/requirements.txt index 9a7d88c..af16eda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,12 @@ sqlmodel openpyxl python-multipart -docker \ No newline at end of file +docker +numpy +ultralytics +opencv-python +pydantic +pandas +starlette +uvicorn +sqlalchemy \ No newline at end of file diff --git a/scene_handler/base_scene_handler.py b/scene_handler/base_scene_handler.py new file mode 100644 index 0000000..5451a16 --- /dev/null +++ b/scene_handler/base_scene_handler.py @@ -0,0 +1,29 @@ +from asyncio import Event + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from entity.device import Device + +from ultralytics import YOLO + +from tcp.tcp_manager import TcpManager + + +class BaseSceneHandler: + + def __init__(self, device: Device, + thread_id: str, + tcp_manager: TcpManager, + main_loop ): + self.device = device + self.thread_id = thread_id + self.tcp_manager =tcp_manager + self.main_loop = main_loop + + def stop_task(self): + pass + + def run(self): + pass \ No newline at end of file diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py new file mode 100644 index 0000000..68df7df --- /dev/null +++ b/scene_handler/limit_space_scene_handler.py @@ -0,0 +1,65 @@ +import asyncio +from asyncio import Event + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from entity.device import Device + +from ultralytics import YOLO + +from scene_handler.base_scene_handler import BaseSceneHandler +from tcp.tcp_manager import TcpManager + + +class LimitSpaceSceneHandler(BaseSceneHandler): + + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + # self.device = device + # self.thread_id = thread_id + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) + self.device_status_manager = DeviceStatusManager() + + self.person_model = YOLO('weights/yolov8s.pt') + + self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态 + + def stop_task(self, **kwargs): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def send_tcp_message(self, message: bytes, have_response=False): + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device.id, + message=message, + have_response=have_response), + self.main_loop) + + def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue + + self.device_status_manager.set_status(device_id=self.device.id) + results = self.person_model.predict(source=frame, imgsz=640, + save_txt=False, + save=False, + verbose=False, stream=True) + result = (list(results)) + if len(result[0]) > 0: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device.id, + message=b'\xaa\x01\x00\x93\x07\x00\x9B', + have_response=False), + self.main_loop) diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/app_instance.py b/app_instance.py new file mode 100644 index 0000000..b72926d --- /dev/null +++ b/app_instance.py @@ -0,0 +1,77 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException + +from algo.algo_runner import AlgoRunner +from algo.scene_runner import SceneRunner + +from common.biz_exception import BizExceptionHandlers +from common.global_logger import logger +from db.database import get_db +from services.device_model_relation_service import DeviceModelRelationService +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.model_service import ModelService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + +_app = None # 创建一个私有变量来存储 app 实例 + + +def create_app() -> FastAPI: + global _app + if _app is None: + _app = FastAPI() + + @asynccontextmanager + async def lifespan(app: FastAPI): + main_loop = asyncio.get_running_loop() + + with next(get_db()) as db: + device_service = DeviceService(db) + model_service = ModelService(db) + model_relation_service = DeviceModelRelationService(db) + scene_service = SceneService(db) + scene_relation_service = DeviceSceneRelationService(db) + + app.state.device_service = device_service + app.state.model_service = model_service + app.state.model_relation_service = model_relation_service + app.state.scene_service = scene_service + app.state.scene_relation_service = scene_relation_service + + tcp_manager = TcpManager(device_service=device_service) + app.state.tcp_manager = tcp_manager + await tcp_manager.start() + + algo_runner = AlgoRunner( + device_service=device_service, + model_service=model_service, + relation_service=model_relation_service, + ) + app.state.algo_runner = algo_runner + await algo_runner.start() + + scene_runner = SceneRunner( + device_service=device_service, + scene_service=scene_service, + relation_service=scene_relation_service, + tcp_manager=tcp_manager, + main_loop=main_loop + ) + app.state.scene_runner = scene_runner + await scene_runner.start() + + yield # 允许请求处理 + + logger.info("Shutting down application...") + + _app.router.lifespan_context = lifespan + _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) + + return _app + + +def get_app() -> FastAPI: + return _app or create_app() diff --git a/common/string_utils.py b/common/string_utils.py index 42d987b..7ff7f03 100644 --- a/common/string_utils.py +++ b/common/string_utils.py @@ -1,5 +1,12 @@ +import importlib import re +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + def camel_to_snake(name): # 将大写字母前加上下划线,并将整个字符串转换为小写 diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index b1d8fab..d1acdd4 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/main.py b/main.py index 61cbe9e..8e5e464 100644 --- a/main.py +++ b/main.py @@ -1,58 +1,21 @@ -import asyncio -from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException - -from algo.algo_runner_manager import get_algo_runner -from apis.router import router import uvicorn import logging -from common.biz_exception import BizExceptionHandlers +from app_instance import get_app from common.global_logger import logger +app = get_app() -app = FastAPI() - -# # 初始化 AlgoRunner -# algo_runner = AlgoRunner() -# -# -# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner -# @app.on_event("startup") -# async def startup_event(): -# algo_runner.start() - -algo_runner = get_algo_runner() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - # 应用启动时的初始化 - await algo_runner.start() - # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start())) - # 允许请求处理 - yield - # 应用关闭时的清理逻辑 - logger.info("Shutting down application...") - - -# 包含所有模块的路由 +# 延迟导入 router 并注册路由 +from apis.router import router app.include_router(router, prefix="/api") -app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) -app.router.lifespan_context = lifespan - if __name__ == "__main__": - # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers uvicorn_logger.setLevel(logging.DEBUG) - # 重定向 uvicorn 的 access 日志 - # access_logger = logging.getLogger("uvicorn.access") - # access_logger.handlers = logger.handlers - # access_logger.setLevel(logging.DEBUG) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/requirements.txt b/requirements.txt index 9a7d88c..af16eda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,12 @@ sqlmodel openpyxl python-multipart -docker \ No newline at end of file +docker +numpy +ultralytics +opencv-python +pydantic +pandas +starlette +uvicorn +sqlalchemy \ No newline at end of file diff --git a/scene_handler/base_scene_handler.py b/scene_handler/base_scene_handler.py new file mode 100644 index 0000000..5451a16 --- /dev/null +++ b/scene_handler/base_scene_handler.py @@ -0,0 +1,29 @@ +from asyncio import Event + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from entity.device import Device + +from ultralytics import YOLO + +from tcp.tcp_manager import TcpManager + + +class BaseSceneHandler: + + def __init__(self, device: Device, + thread_id: str, + tcp_manager: TcpManager, + main_loop ): + self.device = device + self.thread_id = thread_id + self.tcp_manager =tcp_manager + self.main_loop = main_loop + + def stop_task(self): + pass + + def run(self): + pass \ No newline at end of file diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py new file mode 100644 index 0000000..68df7df --- /dev/null +++ b/scene_handler/limit_space_scene_handler.py @@ -0,0 +1,65 @@ +import asyncio +from asyncio import Event + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from entity.device import Device + +from ultralytics import YOLO + +from scene_handler.base_scene_handler import BaseSceneHandler +from tcp.tcp_manager import TcpManager + + +class LimitSpaceSceneHandler(BaseSceneHandler): + + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + # self.device = device + # self.thread_id = thread_id + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) + self.device_status_manager = DeviceStatusManager() + + self.person_model = YOLO('weights/yolov8s.pt') + + self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态 + + def stop_task(self, **kwargs): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def send_tcp_message(self, message: bytes, have_response=False): + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device.id, + message=message, + have_response=have_response), + self.main_loop) + + def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue + + self.device_status_manager.set_status(device_id=self.device.id) + results = self.person_model.predict(source=frame, imgsz=640, + save_txt=False, + save=False, + verbose=False, stream=True) + result = (list(results)) + if len(result[0]) > 0: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device.id, + message=b'\xaa\x01\x00\x93\x07\x00\x9B', + have_response=False), + self.main_loop) diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py index bda6a95..a62758e 100644 --- a/tcp/tcp_client_connector.py +++ b/tcp/tcp_client_connector.py @@ -139,7 +139,6 @@ # 可以根据需求选择是否接收响应 if have_response: data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout) - print(format_bytes(data)) # if not data: # raise ConnectionResetError("Connection lost or no data received") self.parse_response(data) @@ -159,6 +158,6 @@ # asyncio.run(client.connect()) # Run the asynchronous connect method # 示例数据 - data = b'\x07 \x00\x01\x00\x01\xaa\x01\x00"0\r`' + data = b'\x07\x00\x01\x00\x01\xaa\x01\x00"0\r`' result = parse_gas_data(data) print(result) diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml index d0876a7..33f5c97 100644 --- a/.idea/safe-algo-pro.iml +++ b/.idea/safe-algo-pro.iml @@ -5,4 +5,7 @@ + + \ No newline at end of file diff --git a/algo/algo_runner.py b/algo/algo_runner.py index d59d93a..40b8d1e 100644 --- a/algo/algo_runner.py +++ b/algo/algo_runner.py @@ -5,7 +5,7 @@ from algo.device_detection_task import DeviceDetectionTask from algo.model_manager import ModelManager -from common.consts import NotifyChangeType +from common.consts import NotifyChangeType, DEVICE_MODE from common.global_thread_pool import GlobalThreadPool from db.database import get_db from entity.device import Device @@ -14,23 +14,25 @@ from services.model_service import ModelService from common.global_logger import logger - class AlgoRunner: - def __init__(self): - with next(get_db()) as db: - self.device_service = DeviceService(db) - self.model_service = ModelService(db) - self.relation_service = DeviceModelRelationService(db) - self.model_manager = ModelManager(db) + def __init__(self, + device_service : DeviceService, + model_service: ModelService, + relation_service: DeviceModelRelationService): + + self.device_service = device_service + self.model_service = model_service + self.relation_service = relation_service + self.model_manager = ModelManager(model_service) self.thread_pool = GlobalThreadPool() self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程 self.device_model_relations = {} # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) - self.model_service.register_change_callback(self.on_model_change) - self.relation_service.register_change_callback(self.on_relation_change) + # self.device_service.register_change_callback(self.on_device_change) + # self.model_service.register_change_callback(self.on_model_change) + # self.relation_service.register_change_callback(self.on_relation_change) async def start(self): logger.info("Starting AlgoRunner...") @@ -44,10 +46,6 @@ for device in devices: self.start_device_thread(device) - # @staticmethod - # def get_device_thread_id(device_id): - # return f'device_{device_id}' - def start_device_thread(self, device: Device): device = copy.deepcopy(device) """为单个设备启动检测线程""" @@ -56,7 +54,10 @@ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') return - thread_id = f'device_{device.id}_{uuid.uuid4()}' + if not device.mode == DEVICE_MODE.ALGO: + return + + thread_id = f'device_{device.id}_algo_{uuid.uuid4()}' logger.info(f'start thread {thread_id}, device info: {device}') # if self.thread_pool.check_task_is_running(thread_id=thread_id): @@ -135,9 +136,10 @@ elif change_type == NotifyChangeType.DEVICE_DELETE: self.stop_device_thread(device_id) elif change_type == NotifyChangeType.DEVICE_UPDATE: - # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 - old_device = self.device_tasks[device_id].device - new_device = self.device_service.get_device(device_id) + if device_id in self.device_tasks: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) self.restart_device_thread(device_id) diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py index d7d8015..9109d75 100644 --- a/algo/algo_runner_manager.py +++ b/algo/algo_runner_manager.py @@ -1,7 +1,8 @@ from algo.algo_runner import AlgoRunner -algo_runner = AlgoRunner() +# algo_runner = AlgoRunner() def get_algo_runner(): - return algo_runner + pass + # return algo_runner diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index b5a2d9e..e222ca7 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -9,7 +9,7 @@ from common.device_status_manager import DeviceStatusManager from common.global_logger import logger from common.global_thread_pool import GlobalThreadPool -from common.string_utils import camel_to_snake +from common.string_utils import camel_to_snake, get_class from db.database import get_db from entity.device import Device from entity.frame_analysis_result import FrameAnalysisResultCreate @@ -17,11 +17,7 @@ from services.frame_analysis_result_service import FrameAnalysisResultService -def get_class(module_name, class_name): - # 动态导入模块 - module = importlib.import_module(module_name) - # 使用 getattr 从模块中获取类 - return getattr(module, class_name) + @dataclass diff --git a/algo/model_manager.py b/algo/model_manager.py index 94a058c..e3bc4c9 100644 --- a/algo/model_manager.py +++ b/algo/model_manager.py @@ -19,9 +19,9 @@ class ModelManager: - def __init__(self, db, model_warm_up=5): - self.db = db - self.model_service = ModelService(self.db) + def __init__(self, model_service: ModelService, model_warm_up=5): + # self.db = db + self.model_service = model_service self.models: Dict[int, AlgoModelExec] = {} self.model_warm_up = model_warm_up @@ -37,6 +37,8 @@ logger.info('loading models') self.models = {} self.query_model_inuse() + if not self.models: + logger.info("no model in use") for algo_model_id, algo_model_exec in self.models.items(): self.load_model(algo_model_exec) diff --git a/algo/scene_runner.py b/algo/scene_runner.py new file mode 100644 index 0000000..2fb2b2b --- /dev/null +++ b/algo/scene_runner.py @@ -0,0 +1,146 @@ +import concurrent +import copy +import uuid +from typing import Dict + +from common.consts import DEVICE_MODE, NotifyChangeType +from common.global_logger import logger +from common.global_thread_pool import GlobalThreadPool +from common.string_utils import get_class, camel_to_snake +from entity.device import Device +from scene_handler.base_scene_handler import BaseSceneHandler +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + + +class SceneRunner: + def __init__(self, + device_service: DeviceService, + scene_service: SceneService, + relation_service: DeviceSceneRelationService, + tcp_manager: TcpManager, + main_loop): + self.device_service = device_service + self.scene_service = scene_service + self.relation_service = relation_service + self.tcp_manager = tcp_manager + self.main_loop = main_loop + + self.thread_pool = GlobalThreadPool() + self.device_tasks: Dict[int, BaseSceneHandler] = {} + self.device_scene_relations = {} + + # 注册设备和模型的变化回调 + # self.device_service.register_change_callback(self.on_device_change) + # self.scene_service.register_change_callback(self.on_scene_change) + # self.relation_service.register_change_callback(self.on_relation_change) + + async def start(self): + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE] + for device in devices: + self.start_device_thread(device) + + def start_device_thread(self, device: Device): + device = copy.deepcopy(device) + """为单个设备启动检测线程""" + + if not device.input_stream_url: + logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程') + return + + if not device.mode == DEVICE_MODE.SCENE: + return + + thread_id = f'device_{device.id}_scene_{uuid.uuid4()}' + scene = self.relation_service.get_device_scene(device.id) + if scene: + self.device_scene_relations[device.id] = scene + try: + handle_task_name = scene.scene_handle_task + handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name) + handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop) + future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id) + self.device_tasks[device.id] = handler_instance + logger.info(f'start thread {thread_id}, device info: {device}') + except Exception as e: + logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}") + + def stop_device_thread(self, device_id): + """停止指定设备的检测线程,并确保其成功停止""" + if device_id in self.device_tasks: + logger.info(f'stop device {device_id} thread') + # 获取线程 ID 和 future 对象 + thread_id = self.device_tasks[device_id].thread_id + future = self.thread_pool.get_task_future(thread_id=thread_id) + + if future: + if not future.done(): + # 任务正在运行,调用 stop_detection_task 停止任务 + self.device_tasks[device_id].stop_task() + try: + # 设置超时时间等待任务停止(例如10秒) + result = future.result(timeout=30) + logger.info(f"Task {thread_id} stopped successfully.") + except concurrent.futures.TimeoutError as te: + logger.error(f"Task {thread_id} did not stop within the timeout.") + except Exception as e: + logger.exception(f"Task {thread_id} encountered an error while stopping: {e}") + finally: + # 确保无论任务是否停止,都将其从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.info(f"Task {thread_id} has already stopped.") + # 任务已停止,直接从任务列表中移除 + del self.device_tasks[device_id] + else: + logger.warning(f"No task found for {thread_id} .") + else: + logger.warning(f"No task exists for device {device_id} in device_tasks.") + + def restart_device_thread(self, device_id): + try: + self.stop_device_thread(device_id) + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + except Exception as e: + logger.error(e) + + def on_device_change(self, device_id, change_type): + logger.info(f"on device change, device {device_id} {change_type}") + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启 + if device_id in self.device_tasks: + old_device = self.device_tasks[device_id].device + new_device = self.device_service.get_device(device_id) + + self.restart_device_thread(device_id) + + def on_scene_change(self, scene_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.SCENE_UPDATE: + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_scene_relations.items() + if relation_info.scene_id == scene_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + # 判断模型绑定关系是否真的发生了变化,避免不必要的重启 + if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/apis/device.py b/apis/device.py index f786851..85fb10c 100644 --- a/apis/device.py +++ b/apis/device.py @@ -3,15 +3,18 @@ from fastapi import APIRouter, Depends, Query from sqlmodel import Session -from algo.algo_runner import AlgoRunner from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo from services.device_service import DeviceService -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_service @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) @@ -45,15 +48,13 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): device = service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): device = service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") @@ -61,8 +62,7 @@ @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.device_service +def delete_device(device_id: int, service: DeviceService = Depends(get_service)): device = service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index 9152eca..443176b 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -8,10 +8,13 @@ from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation from services.device_model_relation_service import DeviceModelRelationService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner +from app_instance import get_app router = APIRouter() +app = get_app() + +def get_service(): + return app.state.device_model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) @@ -35,12 +38,10 @@ @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) def update_by_device(relation_data: List[DeviceModelRelationCreate], device_id: int = Query(...), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.relation_service + service: DeviceModelRelationService = Depends(get_service)): relations = service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - # @router.delete("/delete_by_device", response_model=StandardResponse[int]) # def delete_device(device_id: int, db: Session = Depends(get_db)): # service = DeviceModelRelationService(db) diff --git a/apis/model.py b/apis/model.py index 264e10a..78c064f 100644 --- a/apis/model.py +++ b/apis/model.py @@ -4,15 +4,17 @@ from sqlmodel import Session from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param +from app_instance import get_app from db.database import get_db from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo from services.model_service import ModelService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() +app = get_app() +def get_service(): + return app.state.model_service @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) def get_model_list( @@ -60,8 +62,7 @@ @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), - algo_runner: AlgoRunner = Depends(get_algo_runner)): - service = algo_runner.model_service + service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) model = service.update_model(model_data, file) if not model: diff --git a/apis/scene.py b/apis/scene.py index b0a4d22..f84b54b 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -9,8 +9,6 @@ from entity.scene import SceneInfo from services.scene_service import SceneService -from algo.algo_runner import AlgoRunner -from algo.algo_runner_manager import get_algo_runner router = APIRouter() diff --git a/app_instance.py b/app_instance.py new file mode 100644 index 0000000..b72926d --- /dev/null +++ b/app_instance.py @@ -0,0 +1,77 @@ +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI, HTTPException + +from algo.algo_runner import AlgoRunner +from algo.scene_runner import SceneRunner + +from common.biz_exception import BizExceptionHandlers +from common.global_logger import logger +from db.database import get_db +from services.device_model_relation_service import DeviceModelRelationService +from services.device_scene_relation_service import DeviceSceneRelationService +from services.device_service import DeviceService +from services.model_service import ModelService +from services.scene_service import SceneService +from tcp.tcp_manager import TcpManager + +_app = None # 创建一个私有变量来存储 app 实例 + + +def create_app() -> FastAPI: + global _app + if _app is None: + _app = FastAPI() + + @asynccontextmanager + async def lifespan(app: FastAPI): + main_loop = asyncio.get_running_loop() + + with next(get_db()) as db: + device_service = DeviceService(db) + model_service = ModelService(db) + model_relation_service = DeviceModelRelationService(db) + scene_service = SceneService(db) + scene_relation_service = DeviceSceneRelationService(db) + + app.state.device_service = device_service + app.state.model_service = model_service + app.state.model_relation_service = model_relation_service + app.state.scene_service = scene_service + app.state.scene_relation_service = scene_relation_service + + tcp_manager = TcpManager(device_service=device_service) + app.state.tcp_manager = tcp_manager + await tcp_manager.start() + + algo_runner = AlgoRunner( + device_service=device_service, + model_service=model_service, + relation_service=model_relation_service, + ) + app.state.algo_runner = algo_runner + await algo_runner.start() + + scene_runner = SceneRunner( + device_service=device_service, + scene_service=scene_service, + relation_service=scene_relation_service, + tcp_manager=tcp_manager, + main_loop=main_loop + ) + app.state.scene_runner = scene_runner + await scene_runner.start() + + yield # 允许请求处理 + + logger.info("Shutting down application...") + + _app.router.lifespan_context = lifespan + _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) + + return _app + + +def get_app() -> FastAPI: + return _app or create_app() diff --git a/common/string_utils.py b/common/string_utils.py index 42d987b..7ff7f03 100644 --- a/common/string_utils.py +++ b/common/string_utils.py @@ -1,5 +1,12 @@ +import importlib import re +def get_class(module_name, class_name): + # 动态导入模块 + module = importlib.import_module(module_name) + # 使用 getattr 从模块中获取类 + return getattr(module, class_name) + def camel_to_snake(name): # 将大写字母前加上下划线,并将整个字符串转换为小写 diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db index b1d8fab..d1acdd4 100644 --- a/db/safe-algo-pro.db +++ b/db/safe-algo-pro.db Binary files differ diff --git a/main.py b/main.py index 61cbe9e..8e5e464 100644 --- a/main.py +++ b/main.py @@ -1,58 +1,21 @@ -import asyncio -from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException - -from algo.algo_runner_manager import get_algo_runner -from apis.router import router import uvicorn import logging -from common.biz_exception import BizExceptionHandlers +from app_instance import get_app from common.global_logger import logger +app = get_app() -app = FastAPI() - -# # 初始化 AlgoRunner -# algo_runner = AlgoRunner() -# -# -# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner -# @app.on_event("startup") -# async def startup_event(): -# algo_runner.start() - -algo_runner = get_algo_runner() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - # 应用启动时的初始化 - await algo_runner.start() - # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start())) - # 允许请求处理 - yield - # 应用关闭时的清理逻辑 - logger.info("Shutting down application...") - - -# 包含所有模块的路由 +# 延迟导入 router 并注册路由 +from apis.router import router app.include_router(router, prefix="/api") -app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler) -app.router.lifespan_context = lifespan - if __name__ == "__main__": - # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.handlers = logger.handlers uvicorn_logger.setLevel(logging.DEBUG) - # 重定向 uvicorn 的 access 日志 - # access_logger = logging.getLogger("uvicorn.access") - # access_logger.handlers = logger.handlers - # access_logger.setLevel(logging.DEBUG) uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None) diff --git a/requirements.txt b/requirements.txt index 9a7d88c..af16eda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,12 @@ sqlmodel openpyxl python-multipart -docker \ No newline at end of file +docker +numpy +ultralytics +opencv-python +pydantic +pandas +starlette +uvicorn +sqlalchemy \ No newline at end of file diff --git a/scene_handler/base_scene_handler.py b/scene_handler/base_scene_handler.py new file mode 100644 index 0000000..5451a16 --- /dev/null +++ b/scene_handler/base_scene_handler.py @@ -0,0 +1,29 @@ +from asyncio import Event + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from entity.device import Device + +from ultralytics import YOLO + +from tcp.tcp_manager import TcpManager + + +class BaseSceneHandler: + + def __init__(self, device: Device, + thread_id: str, + tcp_manager: TcpManager, + main_loop ): + self.device = device + self.thread_id = thread_id + self.tcp_manager =tcp_manager + self.main_loop = main_loop + + def stop_task(self): + pass + + def run(self): + pass \ No newline at end of file diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py new file mode 100644 index 0000000..68df7df --- /dev/null +++ b/scene_handler/limit_space_scene_handler.py @@ -0,0 +1,65 @@ +import asyncio +from asyncio import Event + +from algo.model_manager import AlgoModelExec +from algo.stream_loader import OpenCVStreamLoad +from common.device_status_manager import DeviceStatusManager +from common.global_logger import logger +from entity.device import Device + +from ultralytics import YOLO + +from scene_handler.base_scene_handler import BaseSceneHandler +from tcp.tcp_manager import TcpManager + + +class LimitSpaceSceneHandler(BaseSceneHandler): + + def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop): + super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop) + # self.device = device + # self.thread_id = thread_id + self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code, + device_thread_id=thread_id) + self.device_status_manager = DeviceStatusManager() + + self.person_model = YOLO('weights/yolov8s.pt') + + self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态 + + def stop_task(self, **kwargs): + logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}') + self.__stop_event.set() + self.stream_loader.stop() # 停止视频流加载的线程 + + def send_tcp_message(self, message: bytes, have_response=False): + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device.id, + message=message, + have_response=have_response), + self.main_loop) + + def run(self): + while not self.stream_loader.init: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + self.stream_loader.init_cap() + for frame in self.stream_loader: + if self.__stop_event.is_set(): + break # 如果触发了停止事件,则退出循环 + # print('frame') + if frame is None: + continue + + self.device_status_manager.set_status(device_id=self.device.id) + results = self.person_model.predict(source=frame, imgsz=640, + save_txt=False, + save=False, + verbose=False, stream=True) + result = (list(results)) + if len(result[0]) > 0: + asyncio.run_coroutine_threadsafe( + self.tcp_manager.send_message_to_device(device_id=self.device.id, + message=b'\xaa\x01\x00\x93\x07\x00\x9B', + have_response=False), + self.main_loop) diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py index bda6a95..a62758e 100644 --- a/tcp/tcp_client_connector.py +++ b/tcp/tcp_client_connector.py @@ -139,7 +139,6 @@ # 可以根据需求选择是否接收响应 if have_response: data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout) - print(format_bytes(data)) # if not data: # raise ConnectionResetError("Connection lost or no data received") self.parse_response(data) @@ -159,6 +158,6 @@ # asyncio.run(client.connect()) # Run the asynchronous connect method # 示例数据 - data = b'\x07 \x00\x01\x00\x01\xaa\x01\x00"0\r`' + data = b'\x07\x00\x01\x00\x01\xaa\x01\x00"0\r`' result = parse_gas_data(data) print(result) diff --git a/tcp/tcp_manager.py b/tcp/tcp_manager.py index 35ab60a..f28446d 100644 --- a/tcp/tcp_manager.py +++ b/tcp/tcp_manager.py @@ -1,28 +1,24 @@ import asyncio from typing import List, Dict -from algo.algo_runner_manager import get_algo_runner + from common.consts import DEVICE_TYPE, NotifyChangeType -from db.database import get_db from entity.device import Device from services.device_service import DeviceService - from common.global_logger import logger from tcp.tcp_client_connector import TcpClientConnector class TcpManager: - def __init__(self): + def __init__(self, device_service: DeviceService): self.devices: List[Device] = [] self.connector_map: Dict[int, TcpClientConnector] = {} - # 从全局algo_runner中获取device_service,确保能收到设备更新通知 - algo_runner = get_algo_runner() - self.device_service = algo_runner.device_service + self.device_service = device_service # 注册设备和模型的变化回调 - self.device_service.register_change_callback(self.on_device_change) + # self.device_service.register_change_callback(self.on_device_change) async def load_and_connect_devices(self): """从数据库加载设备并连接所有设备""" @@ -55,7 +51,7 @@ await self.start_device_connect(device) async def on_device_change(self, device_id, change_type): - """设备变化时的回调处理 todo 线程处理待优化""" + """设备变化时的回调处理""" if change_type == NotifyChangeType.DEVICE_CREATE: # 新增设备,加载新设备并连接 new_device = self.device_service.get_device(device_id) @@ -72,6 +68,16 @@ """断开设备连接""" await connector.disconnect() + async def send_message_to_device(self, device_id, message: bytes, have_response): + if device_id not in self.connector_map: + device = self.device_service.get_device(device_id) + await self.start_device_connect(device) + connector = self.connector_map[device_id] + if connector: + await connector.send_message(message, have_response=have_response) + + + if __name__ == '__main__': tcp_manager = TcpManager() asyncio.run(tcp_manager.start())