diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/app_instance.py b/app_instance.py
new file mode 100644
index 0000000..b72926d
--- /dev/null
+++ b/app_instance.py
@@ -0,0 +1,77 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException
+
+from algo.algo_runner import AlgoRunner
+from algo.scene_runner import SceneRunner
+
+from common.biz_exception import BizExceptionHandlers
+from common.global_logger import logger
+from db.database import get_db
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+_app = None # 创建一个私有变量来存储 app 实例
+
+
+def create_app() -> FastAPI:
+ global _app
+ if _app is None:
+ _app = FastAPI()
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ main_loop = asyncio.get_running_loop()
+
+ with next(get_db()) as db:
+ device_service = DeviceService(db)
+ model_service = ModelService(db)
+ model_relation_service = DeviceModelRelationService(db)
+ scene_service = SceneService(db)
+ scene_relation_service = DeviceSceneRelationService(db)
+
+ app.state.device_service = device_service
+ app.state.model_service = model_service
+ app.state.model_relation_service = model_relation_service
+ app.state.scene_service = scene_service
+ app.state.scene_relation_service = scene_relation_service
+
+ tcp_manager = TcpManager(device_service=device_service)
+ app.state.tcp_manager = tcp_manager
+ await tcp_manager.start()
+
+ algo_runner = AlgoRunner(
+ device_service=device_service,
+ model_service=model_service,
+ relation_service=model_relation_service,
+ )
+ app.state.algo_runner = algo_runner
+ await algo_runner.start()
+
+ scene_runner = SceneRunner(
+ device_service=device_service,
+ scene_service=scene_service,
+ relation_service=scene_relation_service,
+ tcp_manager=tcp_manager,
+ main_loop=main_loop
+ )
+ app.state.scene_runner = scene_runner
+ await scene_runner.start()
+
+ yield # 允许请求处理
+
+ logger.info("Shutting down application...")
+
+ _app.router.lifespan_context = lifespan
+ _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
+
+ return _app
+
+
+def get_app() -> FastAPI:
+ return _app or create_app()
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/app_instance.py b/app_instance.py
new file mode 100644
index 0000000..b72926d
--- /dev/null
+++ b/app_instance.py
@@ -0,0 +1,77 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException
+
+from algo.algo_runner import AlgoRunner
+from algo.scene_runner import SceneRunner
+
+from common.biz_exception import BizExceptionHandlers
+from common.global_logger import logger
+from db.database import get_db
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+_app = None # 创建一个私有变量来存储 app 实例
+
+
+def create_app() -> FastAPI:
+ global _app
+ if _app is None:
+ _app = FastAPI()
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ main_loop = asyncio.get_running_loop()
+
+ with next(get_db()) as db:
+ device_service = DeviceService(db)
+ model_service = ModelService(db)
+ model_relation_service = DeviceModelRelationService(db)
+ scene_service = SceneService(db)
+ scene_relation_service = DeviceSceneRelationService(db)
+
+ app.state.device_service = device_service
+ app.state.model_service = model_service
+ app.state.model_relation_service = model_relation_service
+ app.state.scene_service = scene_service
+ app.state.scene_relation_service = scene_relation_service
+
+ tcp_manager = TcpManager(device_service=device_service)
+ app.state.tcp_manager = tcp_manager
+ await tcp_manager.start()
+
+ algo_runner = AlgoRunner(
+ device_service=device_service,
+ model_service=model_service,
+ relation_service=model_relation_service,
+ )
+ app.state.algo_runner = algo_runner
+ await algo_runner.start()
+
+ scene_runner = SceneRunner(
+ device_service=device_service,
+ scene_service=scene_service,
+ relation_service=scene_relation_service,
+ tcp_manager=tcp_manager,
+ main_loop=main_loop
+ )
+ app.state.scene_runner = scene_runner
+ await scene_runner.start()
+
+ yield # 允许请求处理
+
+ logger.info("Shutting down application...")
+
+ _app.router.lifespan_context = lifespan
+ _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
+
+ return _app
+
+
+def get_app() -> FastAPI:
+ return _app or create_app()
diff --git a/common/string_utils.py b/common/string_utils.py
index 42d987b..7ff7f03 100644
--- a/common/string_utils.py
+++ b/common/string_utils.py
@@ -1,5 +1,12 @@
+import importlib
import re
+def get_class(module_name, class_name):
+ # 动态导入模块
+ module = importlib.import_module(module_name)
+ # 使用 getattr 从模块中获取类
+ return getattr(module, class_name)
+
def camel_to_snake(name):
# 将大写字母前加上下划线,并将整个字符串转换为小写
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/app_instance.py b/app_instance.py
new file mode 100644
index 0000000..b72926d
--- /dev/null
+++ b/app_instance.py
@@ -0,0 +1,77 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException
+
+from algo.algo_runner import AlgoRunner
+from algo.scene_runner import SceneRunner
+
+from common.biz_exception import BizExceptionHandlers
+from common.global_logger import logger
+from db.database import get_db
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+_app = None # 创建一个私有变量来存储 app 实例
+
+
+def create_app() -> FastAPI:
+ global _app
+ if _app is None:
+ _app = FastAPI()
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ main_loop = asyncio.get_running_loop()
+
+ with next(get_db()) as db:
+ device_service = DeviceService(db)
+ model_service = ModelService(db)
+ model_relation_service = DeviceModelRelationService(db)
+ scene_service = SceneService(db)
+ scene_relation_service = DeviceSceneRelationService(db)
+
+ app.state.device_service = device_service
+ app.state.model_service = model_service
+ app.state.model_relation_service = model_relation_service
+ app.state.scene_service = scene_service
+ app.state.scene_relation_service = scene_relation_service
+
+ tcp_manager = TcpManager(device_service=device_service)
+ app.state.tcp_manager = tcp_manager
+ await tcp_manager.start()
+
+ algo_runner = AlgoRunner(
+ device_service=device_service,
+ model_service=model_service,
+ relation_service=model_relation_service,
+ )
+ app.state.algo_runner = algo_runner
+ await algo_runner.start()
+
+ scene_runner = SceneRunner(
+ device_service=device_service,
+ scene_service=scene_service,
+ relation_service=scene_relation_service,
+ tcp_manager=tcp_manager,
+ main_loop=main_loop
+ )
+ app.state.scene_runner = scene_runner
+ await scene_runner.start()
+
+ yield # 允许请求处理
+
+ logger.info("Shutting down application...")
+
+ _app.router.lifespan_context = lifespan
+ _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
+
+ return _app
+
+
+def get_app() -> FastAPI:
+ return _app or create_app()
diff --git a/common/string_utils.py b/common/string_utils.py
index 42d987b..7ff7f03 100644
--- a/common/string_utils.py
+++ b/common/string_utils.py
@@ -1,5 +1,12 @@
+import importlib
import re
+def get_class(module_name, class_name):
+ # 动态导入模块
+ module = importlib.import_module(module_name)
+ # 使用 getattr 从模块中获取类
+ return getattr(module, class_name)
+
def camel_to_snake(name):
# 将大写字母前加上下划线,并将整个字符串转换为小写
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
index b1d8fab..d1acdd4 100644
--- a/db/safe-algo-pro.db
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/app_instance.py b/app_instance.py
new file mode 100644
index 0000000..b72926d
--- /dev/null
+++ b/app_instance.py
@@ -0,0 +1,77 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException
+
+from algo.algo_runner import AlgoRunner
+from algo.scene_runner import SceneRunner
+
+from common.biz_exception import BizExceptionHandlers
+from common.global_logger import logger
+from db.database import get_db
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+_app = None # 创建一个私有变量来存储 app 实例
+
+
+def create_app() -> FastAPI:
+ global _app
+ if _app is None:
+ _app = FastAPI()
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ main_loop = asyncio.get_running_loop()
+
+ with next(get_db()) as db:
+ device_service = DeviceService(db)
+ model_service = ModelService(db)
+ model_relation_service = DeviceModelRelationService(db)
+ scene_service = SceneService(db)
+ scene_relation_service = DeviceSceneRelationService(db)
+
+ app.state.device_service = device_service
+ app.state.model_service = model_service
+ app.state.model_relation_service = model_relation_service
+ app.state.scene_service = scene_service
+ app.state.scene_relation_service = scene_relation_service
+
+ tcp_manager = TcpManager(device_service=device_service)
+ app.state.tcp_manager = tcp_manager
+ await tcp_manager.start()
+
+ algo_runner = AlgoRunner(
+ device_service=device_service,
+ model_service=model_service,
+ relation_service=model_relation_service,
+ )
+ app.state.algo_runner = algo_runner
+ await algo_runner.start()
+
+ scene_runner = SceneRunner(
+ device_service=device_service,
+ scene_service=scene_service,
+ relation_service=scene_relation_service,
+ tcp_manager=tcp_manager,
+ main_loop=main_loop
+ )
+ app.state.scene_runner = scene_runner
+ await scene_runner.start()
+
+ yield # 允许请求处理
+
+ logger.info("Shutting down application...")
+
+ _app.router.lifespan_context = lifespan
+ _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
+
+ return _app
+
+
+def get_app() -> FastAPI:
+ return _app or create_app()
diff --git a/common/string_utils.py b/common/string_utils.py
index 42d987b..7ff7f03 100644
--- a/common/string_utils.py
+++ b/common/string_utils.py
@@ -1,5 +1,12 @@
+import importlib
import re
+def get_class(module_name, class_name):
+ # 动态导入模块
+ module = importlib.import_module(module_name)
+ # 使用 getattr 从模块中获取类
+ return getattr(module, class_name)
+
def camel_to_snake(name):
# 将大写字母前加上下划线,并将整个字符串转换为小写
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
index b1d8fab..d1acdd4 100644
--- a/db/safe-algo-pro.db
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/main.py b/main.py
index 61cbe9e..8e5e464 100644
--- a/main.py
+++ b/main.py
@@ -1,58 +1,21 @@
-import asyncio
-from contextlib import asynccontextmanager
-from fastapi import FastAPI, HTTPException
-
-from algo.algo_runner_manager import get_algo_runner
-from apis.router import router
import uvicorn
import logging
-from common.biz_exception import BizExceptionHandlers
+from app_instance import get_app
from common.global_logger import logger
+app = get_app()
-app = FastAPI()
-
-# # 初始化 AlgoRunner
-# algo_runner = AlgoRunner()
-#
-#
-# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner
-# @app.on_event("startup")
-# async def startup_event():
-# algo_runner.start()
-
-algo_runner = get_algo_runner()
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- # 应用启动时的初始化
- await algo_runner.start()
- # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start()))
- # 允许请求处理
- yield
- # 应用关闭时的清理逻辑
- logger.info("Shutting down application...")
-
-
-# 包含所有模块的路由
+# 延迟导入 router 并注册路由
+from apis.router import router
app.include_router(router, prefix="/api")
-app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
-app.router.lifespan_context = lifespan
-
if __name__ == "__main__":
-
# 重定向 uvicorn 的日志
uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.handlers = logger.handlers
uvicorn_logger.setLevel(logging.DEBUG)
- # 重定向 uvicorn 的 access 日志
- # access_logger = logging.getLogger("uvicorn.access")
- # access_logger.handlers = logger.handlers
- # access_logger.setLevel(logging.DEBUG)
uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None)
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/app_instance.py b/app_instance.py
new file mode 100644
index 0000000..b72926d
--- /dev/null
+++ b/app_instance.py
@@ -0,0 +1,77 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException
+
+from algo.algo_runner import AlgoRunner
+from algo.scene_runner import SceneRunner
+
+from common.biz_exception import BizExceptionHandlers
+from common.global_logger import logger
+from db.database import get_db
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+_app = None # 创建一个私有变量来存储 app 实例
+
+
+def create_app() -> FastAPI:
+ global _app
+ if _app is None:
+ _app = FastAPI()
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ main_loop = asyncio.get_running_loop()
+
+ with next(get_db()) as db:
+ device_service = DeviceService(db)
+ model_service = ModelService(db)
+ model_relation_service = DeviceModelRelationService(db)
+ scene_service = SceneService(db)
+ scene_relation_service = DeviceSceneRelationService(db)
+
+ app.state.device_service = device_service
+ app.state.model_service = model_service
+ app.state.model_relation_service = model_relation_service
+ app.state.scene_service = scene_service
+ app.state.scene_relation_service = scene_relation_service
+
+ tcp_manager = TcpManager(device_service=device_service)
+ app.state.tcp_manager = tcp_manager
+ await tcp_manager.start()
+
+ algo_runner = AlgoRunner(
+ device_service=device_service,
+ model_service=model_service,
+ relation_service=model_relation_service,
+ )
+ app.state.algo_runner = algo_runner
+ await algo_runner.start()
+
+ scene_runner = SceneRunner(
+ device_service=device_service,
+ scene_service=scene_service,
+ relation_service=scene_relation_service,
+ tcp_manager=tcp_manager,
+ main_loop=main_loop
+ )
+ app.state.scene_runner = scene_runner
+ await scene_runner.start()
+
+ yield # 允许请求处理
+
+ logger.info("Shutting down application...")
+
+ _app.router.lifespan_context = lifespan
+ _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
+
+ return _app
+
+
+def get_app() -> FastAPI:
+ return _app or create_app()
diff --git a/common/string_utils.py b/common/string_utils.py
index 42d987b..7ff7f03 100644
--- a/common/string_utils.py
+++ b/common/string_utils.py
@@ -1,5 +1,12 @@
+import importlib
import re
+def get_class(module_name, class_name):
+ # 动态导入模块
+ module = importlib.import_module(module_name)
+ # 使用 getattr 从模块中获取类
+ return getattr(module, class_name)
+
def camel_to_snake(name):
# 将大写字母前加上下划线,并将整个字符串转换为小写
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
index b1d8fab..d1acdd4 100644
--- a/db/safe-algo-pro.db
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/main.py b/main.py
index 61cbe9e..8e5e464 100644
--- a/main.py
+++ b/main.py
@@ -1,58 +1,21 @@
-import asyncio
-from contextlib import asynccontextmanager
-from fastapi import FastAPI, HTTPException
-
-from algo.algo_runner_manager import get_algo_runner
-from apis.router import router
import uvicorn
import logging
-from common.biz_exception import BizExceptionHandlers
+from app_instance import get_app
from common.global_logger import logger
+app = get_app()
-app = FastAPI()
-
-# # 初始化 AlgoRunner
-# algo_runner = AlgoRunner()
-#
-#
-# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner
-# @app.on_event("startup")
-# async def startup_event():
-# algo_runner.start()
-
-algo_runner = get_algo_runner()
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- # 应用启动时的初始化
- await algo_runner.start()
- # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start()))
- # 允许请求处理
- yield
- # 应用关闭时的清理逻辑
- logger.info("Shutting down application...")
-
-
-# 包含所有模块的路由
+# 延迟导入 router 并注册路由
+from apis.router import router
app.include_router(router, prefix="/api")
-app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
-app.router.lifespan_context = lifespan
-
if __name__ == "__main__":
-
# 重定向 uvicorn 的日志
uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.handlers = logger.handlers
uvicorn_logger.setLevel(logging.DEBUG)
- # 重定向 uvicorn 的 access 日志
- # access_logger = logging.getLogger("uvicorn.access")
- # access_logger.handlers = logger.handlers
- # access_logger.setLevel(logging.DEBUG)
uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None)
diff --git a/requirements.txt b/requirements.txt
index 9a7d88c..af16eda 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,12 @@
sqlmodel
openpyxl
python-multipart
-docker
\ No newline at end of file
+docker
+numpy
+ultralytics
+opencv-python
+pydantic
+pandas
+starlette
+uvicorn
+sqlalchemy
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/app_instance.py b/app_instance.py
new file mode 100644
index 0000000..b72926d
--- /dev/null
+++ b/app_instance.py
@@ -0,0 +1,77 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException
+
+from algo.algo_runner import AlgoRunner
+from algo.scene_runner import SceneRunner
+
+from common.biz_exception import BizExceptionHandlers
+from common.global_logger import logger
+from db.database import get_db
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+_app = None # 创建一个私有变量来存储 app 实例
+
+
+def create_app() -> FastAPI:
+ global _app
+ if _app is None:
+ _app = FastAPI()
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ main_loop = asyncio.get_running_loop()
+
+ with next(get_db()) as db:
+ device_service = DeviceService(db)
+ model_service = ModelService(db)
+ model_relation_service = DeviceModelRelationService(db)
+ scene_service = SceneService(db)
+ scene_relation_service = DeviceSceneRelationService(db)
+
+ app.state.device_service = device_service
+ app.state.model_service = model_service
+ app.state.model_relation_service = model_relation_service
+ app.state.scene_service = scene_service
+ app.state.scene_relation_service = scene_relation_service
+
+ tcp_manager = TcpManager(device_service=device_service)
+ app.state.tcp_manager = tcp_manager
+ await tcp_manager.start()
+
+ algo_runner = AlgoRunner(
+ device_service=device_service,
+ model_service=model_service,
+ relation_service=model_relation_service,
+ )
+ app.state.algo_runner = algo_runner
+ await algo_runner.start()
+
+ scene_runner = SceneRunner(
+ device_service=device_service,
+ scene_service=scene_service,
+ relation_service=scene_relation_service,
+ tcp_manager=tcp_manager,
+ main_loop=main_loop
+ )
+ app.state.scene_runner = scene_runner
+ await scene_runner.start()
+
+ yield # 允许请求处理
+
+ logger.info("Shutting down application...")
+
+ _app.router.lifespan_context = lifespan
+ _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
+
+ return _app
+
+
+def get_app() -> FastAPI:
+ return _app or create_app()
diff --git a/common/string_utils.py b/common/string_utils.py
index 42d987b..7ff7f03 100644
--- a/common/string_utils.py
+++ b/common/string_utils.py
@@ -1,5 +1,12 @@
+import importlib
import re
+def get_class(module_name, class_name):
+ # 动态导入模块
+ module = importlib.import_module(module_name)
+ # 使用 getattr 从模块中获取类
+ return getattr(module, class_name)
+
def camel_to_snake(name):
# 将大写字母前加上下划线,并将整个字符串转换为小写
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
index b1d8fab..d1acdd4 100644
--- a/db/safe-algo-pro.db
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/main.py b/main.py
index 61cbe9e..8e5e464 100644
--- a/main.py
+++ b/main.py
@@ -1,58 +1,21 @@
-import asyncio
-from contextlib import asynccontextmanager
-from fastapi import FastAPI, HTTPException
-
-from algo.algo_runner_manager import get_algo_runner
-from apis.router import router
import uvicorn
import logging
-from common.biz_exception import BizExceptionHandlers
+from app_instance import get_app
from common.global_logger import logger
+app = get_app()
-app = FastAPI()
-
-# # 初始化 AlgoRunner
-# algo_runner = AlgoRunner()
-#
-#
-# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner
-# @app.on_event("startup")
-# async def startup_event():
-# algo_runner.start()
-
-algo_runner = get_algo_runner()
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- # 应用启动时的初始化
- await algo_runner.start()
- # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start()))
- # 允许请求处理
- yield
- # 应用关闭时的清理逻辑
- logger.info("Shutting down application...")
-
-
-# 包含所有模块的路由
+# 延迟导入 router 并注册路由
+from apis.router import router
app.include_router(router, prefix="/api")
-app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
-app.router.lifespan_context = lifespan
-
if __name__ == "__main__":
-
# 重定向 uvicorn 的日志
uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.handlers = logger.handlers
uvicorn_logger.setLevel(logging.DEBUG)
- # 重定向 uvicorn 的 access 日志
- # access_logger = logging.getLogger("uvicorn.access")
- # access_logger.handlers = logger.handlers
- # access_logger.setLevel(logging.DEBUG)
uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None)
diff --git a/requirements.txt b/requirements.txt
index 9a7d88c..af16eda 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,12 @@
sqlmodel
openpyxl
python-multipart
-docker
\ No newline at end of file
+docker
+numpy
+ultralytics
+opencv-python
+pydantic
+pandas
+starlette
+uvicorn
+sqlalchemy
\ No newline at end of file
diff --git a/scene_handler/base_scene_handler.py b/scene_handler/base_scene_handler.py
new file mode 100644
index 0000000..5451a16
--- /dev/null
+++ b/scene_handler/base_scene_handler.py
@@ -0,0 +1,29 @@
+from asyncio import Event
+
+from algo.model_manager import AlgoModelExec
+from algo.stream_loader import OpenCVStreamLoad
+from common.device_status_manager import DeviceStatusManager
+from common.global_logger import logger
+from entity.device import Device
+
+from ultralytics import YOLO
+
+from tcp.tcp_manager import TcpManager
+
+
+class BaseSceneHandler:
+
+ def __init__(self, device: Device,
+ thread_id: str,
+ tcp_manager: TcpManager,
+ main_loop ):
+ self.device = device
+ self.thread_id = thread_id
+ self.tcp_manager =tcp_manager
+ self.main_loop = main_loop
+
+ def stop_task(self):
+ pass
+
+ def run(self):
+ pass
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/app_instance.py b/app_instance.py
new file mode 100644
index 0000000..b72926d
--- /dev/null
+++ b/app_instance.py
@@ -0,0 +1,77 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException
+
+from algo.algo_runner import AlgoRunner
+from algo.scene_runner import SceneRunner
+
+from common.biz_exception import BizExceptionHandlers
+from common.global_logger import logger
+from db.database import get_db
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+_app = None # 创建一个私有变量来存储 app 实例
+
+
+def create_app() -> FastAPI:
+ global _app
+ if _app is None:
+ _app = FastAPI()
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ main_loop = asyncio.get_running_loop()
+
+ with next(get_db()) as db:
+ device_service = DeviceService(db)
+ model_service = ModelService(db)
+ model_relation_service = DeviceModelRelationService(db)
+ scene_service = SceneService(db)
+ scene_relation_service = DeviceSceneRelationService(db)
+
+ app.state.device_service = device_service
+ app.state.model_service = model_service
+ app.state.model_relation_service = model_relation_service
+ app.state.scene_service = scene_service
+ app.state.scene_relation_service = scene_relation_service
+
+ tcp_manager = TcpManager(device_service=device_service)
+ app.state.tcp_manager = tcp_manager
+ await tcp_manager.start()
+
+ algo_runner = AlgoRunner(
+ device_service=device_service,
+ model_service=model_service,
+ relation_service=model_relation_service,
+ )
+ app.state.algo_runner = algo_runner
+ await algo_runner.start()
+
+ scene_runner = SceneRunner(
+ device_service=device_service,
+ scene_service=scene_service,
+ relation_service=scene_relation_service,
+ tcp_manager=tcp_manager,
+ main_loop=main_loop
+ )
+ app.state.scene_runner = scene_runner
+ await scene_runner.start()
+
+ yield # 允许请求处理
+
+ logger.info("Shutting down application...")
+
+ _app.router.lifespan_context = lifespan
+ _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
+
+ return _app
+
+
+def get_app() -> FastAPI:
+ return _app or create_app()
diff --git a/common/string_utils.py b/common/string_utils.py
index 42d987b..7ff7f03 100644
--- a/common/string_utils.py
+++ b/common/string_utils.py
@@ -1,5 +1,12 @@
+import importlib
import re
+def get_class(module_name, class_name):
+ # 动态导入模块
+ module = importlib.import_module(module_name)
+ # 使用 getattr 从模块中获取类
+ return getattr(module, class_name)
+
def camel_to_snake(name):
# 将大写字母前加上下划线,并将整个字符串转换为小写
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
index b1d8fab..d1acdd4 100644
--- a/db/safe-algo-pro.db
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/main.py b/main.py
index 61cbe9e..8e5e464 100644
--- a/main.py
+++ b/main.py
@@ -1,58 +1,21 @@
-import asyncio
-from contextlib import asynccontextmanager
-from fastapi import FastAPI, HTTPException
-
-from algo.algo_runner_manager import get_algo_runner
-from apis.router import router
import uvicorn
import logging
-from common.biz_exception import BizExceptionHandlers
+from app_instance import get_app
from common.global_logger import logger
+app = get_app()
-app = FastAPI()
-
-# # 初始化 AlgoRunner
-# algo_runner = AlgoRunner()
-#
-#
-# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner
-# @app.on_event("startup")
-# async def startup_event():
-# algo_runner.start()
-
-algo_runner = get_algo_runner()
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- # 应用启动时的初始化
- await algo_runner.start()
- # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start()))
- # 允许请求处理
- yield
- # 应用关闭时的清理逻辑
- logger.info("Shutting down application...")
-
-
-# 包含所有模块的路由
+# 延迟导入 router 并注册路由
+from apis.router import router
app.include_router(router, prefix="/api")
-app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
-app.router.lifespan_context = lifespan
-
if __name__ == "__main__":
-
# 重定向 uvicorn 的日志
uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.handlers = logger.handlers
uvicorn_logger.setLevel(logging.DEBUG)
- # 重定向 uvicorn 的 access 日志
- # access_logger = logging.getLogger("uvicorn.access")
- # access_logger.handlers = logger.handlers
- # access_logger.setLevel(logging.DEBUG)
uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None)
diff --git a/requirements.txt b/requirements.txt
index 9a7d88c..af16eda 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,12 @@
sqlmodel
openpyxl
python-multipart
-docker
\ No newline at end of file
+docker
+numpy
+ultralytics
+opencv-python
+pydantic
+pandas
+starlette
+uvicorn
+sqlalchemy
\ No newline at end of file
diff --git a/scene_handler/base_scene_handler.py b/scene_handler/base_scene_handler.py
new file mode 100644
index 0000000..5451a16
--- /dev/null
+++ b/scene_handler/base_scene_handler.py
@@ -0,0 +1,29 @@
+from asyncio import Event
+
+from algo.model_manager import AlgoModelExec
+from algo.stream_loader import OpenCVStreamLoad
+from common.device_status_manager import DeviceStatusManager
+from common.global_logger import logger
+from entity.device import Device
+
+from ultralytics import YOLO
+
+from tcp.tcp_manager import TcpManager
+
+
+class BaseSceneHandler:
+
+ def __init__(self, device: Device,
+ thread_id: str,
+ tcp_manager: TcpManager,
+ main_loop ):
+ self.device = device
+ self.thread_id = thread_id
+ self.tcp_manager =tcp_manager
+ self.main_loop = main_loop
+
+ def stop_task(self):
+ pass
+
+ def run(self):
+ pass
\ No newline at end of file
diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py
new file mode 100644
index 0000000..68df7df
--- /dev/null
+++ b/scene_handler/limit_space_scene_handler.py
@@ -0,0 +1,65 @@
+import asyncio
+from asyncio import Event
+
+from algo.model_manager import AlgoModelExec
+from algo.stream_loader import OpenCVStreamLoad
+from common.device_status_manager import DeviceStatusManager
+from common.global_logger import logger
+from entity.device import Device
+
+from ultralytics import YOLO
+
+from scene_handler.base_scene_handler import BaseSceneHandler
+from tcp.tcp_manager import TcpManager
+
+
+class LimitSpaceSceneHandler(BaseSceneHandler):
+
+ def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop):
+ super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop)
+ # self.device = device
+ # self.thread_id = thread_id
+ self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code,
+ device_thread_id=thread_id)
+ self.device_status_manager = DeviceStatusManager()
+
+ self.person_model = YOLO('weights/yolov8s.pt')
+
+ self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态
+
+ def stop_task(self, **kwargs):
+ logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}')
+ self.__stop_event.set()
+ self.stream_loader.stop() # 停止视频流加载的线程
+
+ def send_tcp_message(self, message: bytes, have_response=False):
+ asyncio.run_coroutine_threadsafe(
+ self.tcp_manager.send_message_to_device(device_id=self.device.id,
+ message=message,
+ have_response=have_response),
+ self.main_loop)
+
+ def run(self):
+ while not self.stream_loader.init:
+ if self.__stop_event.is_set():
+ break # 如果触发了停止事件,则退出循环
+ self.stream_loader.init_cap()
+ for frame in self.stream_loader:
+ if self.__stop_event.is_set():
+ break # 如果触发了停止事件,则退出循环
+ # print('frame')
+ if frame is None:
+ continue
+
+ self.device_status_manager.set_status(device_id=self.device.id)
+ results = self.person_model.predict(source=frame, imgsz=640,
+ save_txt=False,
+ save=False,
+ verbose=False, stream=True)
+ result = (list(results))
+ if len(result[0]) > 0:
+ asyncio.run_coroutine_threadsafe(
+ self.tcp_manager.send_message_to_device(device_id=self.device.id,
+ message=b'\xaa\x01\x00\x93\x07\x00\x9B',
+ have_response=False),
+ self.main_loop)
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/app_instance.py b/app_instance.py
new file mode 100644
index 0000000..b72926d
--- /dev/null
+++ b/app_instance.py
@@ -0,0 +1,77 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException
+
+from algo.algo_runner import AlgoRunner
+from algo.scene_runner import SceneRunner
+
+from common.biz_exception import BizExceptionHandlers
+from common.global_logger import logger
+from db.database import get_db
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+_app = None # 创建一个私有变量来存储 app 实例
+
+
+def create_app() -> FastAPI:
+ global _app
+ if _app is None:
+ _app = FastAPI()
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ main_loop = asyncio.get_running_loop()
+
+ with next(get_db()) as db:
+ device_service = DeviceService(db)
+ model_service = ModelService(db)
+ model_relation_service = DeviceModelRelationService(db)
+ scene_service = SceneService(db)
+ scene_relation_service = DeviceSceneRelationService(db)
+
+ app.state.device_service = device_service
+ app.state.model_service = model_service
+ app.state.model_relation_service = model_relation_service
+ app.state.scene_service = scene_service
+ app.state.scene_relation_service = scene_relation_service
+
+ tcp_manager = TcpManager(device_service=device_service)
+ app.state.tcp_manager = tcp_manager
+ await tcp_manager.start()
+
+ algo_runner = AlgoRunner(
+ device_service=device_service,
+ model_service=model_service,
+ relation_service=model_relation_service,
+ )
+ app.state.algo_runner = algo_runner
+ await algo_runner.start()
+
+ scene_runner = SceneRunner(
+ device_service=device_service,
+ scene_service=scene_service,
+ relation_service=scene_relation_service,
+ tcp_manager=tcp_manager,
+ main_loop=main_loop
+ )
+ app.state.scene_runner = scene_runner
+ await scene_runner.start()
+
+ yield # 允许请求处理
+
+ logger.info("Shutting down application...")
+
+ _app.router.lifespan_context = lifespan
+ _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
+
+ return _app
+
+
+def get_app() -> FastAPI:
+ return _app or create_app()
diff --git a/common/string_utils.py b/common/string_utils.py
index 42d987b..7ff7f03 100644
--- a/common/string_utils.py
+++ b/common/string_utils.py
@@ -1,5 +1,12 @@
+import importlib
import re
+def get_class(module_name, class_name):
+ # 动态导入模块
+ module = importlib.import_module(module_name)
+ # 使用 getattr 从模块中获取类
+ return getattr(module, class_name)
+
def camel_to_snake(name):
# 将大写字母前加上下划线,并将整个字符串转换为小写
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
index b1d8fab..d1acdd4 100644
--- a/db/safe-algo-pro.db
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/main.py b/main.py
index 61cbe9e..8e5e464 100644
--- a/main.py
+++ b/main.py
@@ -1,58 +1,21 @@
-import asyncio
-from contextlib import asynccontextmanager
-from fastapi import FastAPI, HTTPException
-
-from algo.algo_runner_manager import get_algo_runner
-from apis.router import router
import uvicorn
import logging
-from common.biz_exception import BizExceptionHandlers
+from app_instance import get_app
from common.global_logger import logger
+app = get_app()
-app = FastAPI()
-
-# # 初始化 AlgoRunner
-# algo_runner = AlgoRunner()
-#
-#
-# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner
-# @app.on_event("startup")
-# async def startup_event():
-# algo_runner.start()
-
-algo_runner = get_algo_runner()
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- # 应用启动时的初始化
- await algo_runner.start()
- # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start()))
- # 允许请求处理
- yield
- # 应用关闭时的清理逻辑
- logger.info("Shutting down application...")
-
-
-# 包含所有模块的路由
+# 延迟导入 router 并注册路由
+from apis.router import router
app.include_router(router, prefix="/api")
-app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
-app.router.lifespan_context = lifespan
-
if __name__ == "__main__":
-
# 重定向 uvicorn 的日志
uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.handlers = logger.handlers
uvicorn_logger.setLevel(logging.DEBUG)
- # 重定向 uvicorn 的 access 日志
- # access_logger = logging.getLogger("uvicorn.access")
- # access_logger.handlers = logger.handlers
- # access_logger.setLevel(logging.DEBUG)
uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None)
diff --git a/requirements.txt b/requirements.txt
index 9a7d88c..af16eda 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,12 @@
sqlmodel
openpyxl
python-multipart
-docker
\ No newline at end of file
+docker
+numpy
+ultralytics
+opencv-python
+pydantic
+pandas
+starlette
+uvicorn
+sqlalchemy
\ No newline at end of file
diff --git a/scene_handler/base_scene_handler.py b/scene_handler/base_scene_handler.py
new file mode 100644
index 0000000..5451a16
--- /dev/null
+++ b/scene_handler/base_scene_handler.py
@@ -0,0 +1,29 @@
+from asyncio import Event
+
+from algo.model_manager import AlgoModelExec
+from algo.stream_loader import OpenCVStreamLoad
+from common.device_status_manager import DeviceStatusManager
+from common.global_logger import logger
+from entity.device import Device
+
+from ultralytics import YOLO
+
+from tcp.tcp_manager import TcpManager
+
+
+class BaseSceneHandler:
+
+ def __init__(self, device: Device,
+ thread_id: str,
+ tcp_manager: TcpManager,
+ main_loop ):
+ self.device = device
+ self.thread_id = thread_id
+ self.tcp_manager =tcp_manager
+ self.main_loop = main_loop
+
+ def stop_task(self):
+ pass
+
+ def run(self):
+ pass
\ No newline at end of file
diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py
new file mode 100644
index 0000000..68df7df
--- /dev/null
+++ b/scene_handler/limit_space_scene_handler.py
@@ -0,0 +1,65 @@
+import asyncio
+from asyncio import Event
+
+from algo.model_manager import AlgoModelExec
+from algo.stream_loader import OpenCVStreamLoad
+from common.device_status_manager import DeviceStatusManager
+from common.global_logger import logger
+from entity.device import Device
+
+from ultralytics import YOLO
+
+from scene_handler.base_scene_handler import BaseSceneHandler
+from tcp.tcp_manager import TcpManager
+
+
+class LimitSpaceSceneHandler(BaseSceneHandler):
+
+ def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop):
+ super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop)
+ # self.device = device
+ # self.thread_id = thread_id
+ self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code,
+ device_thread_id=thread_id)
+ self.device_status_manager = DeviceStatusManager()
+
+ self.person_model = YOLO('weights/yolov8s.pt')
+
+ self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态
+
+ def stop_task(self, **kwargs):
+ logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}')
+ self.__stop_event.set()
+ self.stream_loader.stop() # 停止视频流加载的线程
+
+ def send_tcp_message(self, message: bytes, have_response=False):
+ asyncio.run_coroutine_threadsafe(
+ self.tcp_manager.send_message_to_device(device_id=self.device.id,
+ message=message,
+ have_response=have_response),
+ self.main_loop)
+
+ def run(self):
+ while not self.stream_loader.init:
+ if self.__stop_event.is_set():
+ break # 如果触发了停止事件,则退出循环
+ self.stream_loader.init_cap()
+ for frame in self.stream_loader:
+ if self.__stop_event.is_set():
+ break # 如果触发了停止事件,则退出循环
+ # print('frame')
+ if frame is None:
+ continue
+
+ self.device_status_manager.set_status(device_id=self.device.id)
+ results = self.person_model.predict(source=frame, imgsz=640,
+ save_txt=False,
+ save=False,
+ verbose=False, stream=True)
+ result = (list(results))
+ if len(result[0]) > 0:
+ asyncio.run_coroutine_threadsafe(
+ self.tcp_manager.send_message_to_device(device_id=self.device.id,
+ message=b'\xaa\x01\x00\x93\x07\x00\x9B',
+ have_response=False),
+ self.main_loop)
diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py
index bda6a95..a62758e 100644
--- a/tcp/tcp_client_connector.py
+++ b/tcp/tcp_client_connector.py
@@ -139,7 +139,6 @@
# 可以根据需求选择是否接收响应
if have_response:
data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout)
- print(format_bytes(data))
# if not data:
# raise ConnectionResetError("Connection lost or no data received")
self.parse_response(data)
@@ -159,6 +158,6 @@
# asyncio.run(client.connect()) # Run the asynchronous connect method
# 示例数据
- data = b'\x07 \x00\x01\x00\x01\xaa\x01\x00"0\r`'
+ data = b'\x07\x00\x01\x00\x01\xaa\x01\x00"0\r`'
result = parse_gas_data(data)
print(result)
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
index d0876a7..33f5c97 100644
--- a/.idea/safe-algo-pro.iml
+++ b/.idea/safe-algo-pro.iml
@@ -5,4 +5,7 @@
+
+
+
\ No newline at end of file
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
index d59d93a..40b8d1e 100644
--- a/algo/algo_runner.py
+++ b/algo/algo_runner.py
@@ -5,7 +5,7 @@
from algo.device_detection_task import DeviceDetectionTask
from algo.model_manager import ModelManager
-from common.consts import NotifyChangeType
+from common.consts import NotifyChangeType, DEVICE_MODE
from common.global_thread_pool import GlobalThreadPool
from db.database import get_db
from entity.device import Device
@@ -14,23 +14,25 @@
from services.model_service import ModelService
from common.global_logger import logger
-
class AlgoRunner:
- def __init__(self):
- with next(get_db()) as db:
- self.device_service = DeviceService(db)
- self.model_service = ModelService(db)
- self.relation_service = DeviceModelRelationService(db)
- self.model_manager = ModelManager(db)
+ def __init__(self,
+ device_service : DeviceService,
+ model_service: ModelService,
+ relation_service: DeviceModelRelationService):
+
+ self.device_service = device_service
+ self.model_service = model_service
+ self.relation_service = relation_service
+ self.model_manager = ModelManager(model_service)
self.thread_pool = GlobalThreadPool()
self.device_tasks: Dict[int, DeviceDetectionTask] = {} # 用于存储设备对应的线程
self.device_model_relations = {}
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
- self.model_service.register_change_callback(self.on_model_change)
- self.relation_service.register_change_callback(self.on_relation_change)
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.model_service.register_change_callback(self.on_model_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
async def start(self):
logger.info("Starting AlgoRunner...")
@@ -44,10 +46,6 @@
for device in devices:
self.start_device_thread(device)
- # @staticmethod
- # def get_device_thread_id(device_id):
- # return f'device_{device_id}'
-
def start_device_thread(self, device: Device):
device = copy.deepcopy(device)
"""为单个设备启动检测线程"""
@@ -56,7 +54,10 @@
logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
return
- thread_id = f'device_{device.id}_{uuid.uuid4()}'
+ if not device.mode == DEVICE_MODE.ALGO:
+ return
+
+ thread_id = f'device_{device.id}_algo_{uuid.uuid4()}'
logger.info(f'start thread {thread_id}, device info: {device}')
# if self.thread_pool.check_task_is_running(thread_id=thread_id):
@@ -135,9 +136,10 @@
elif change_type == NotifyChangeType.DEVICE_DELETE:
self.stop_device_thread(device_id)
elif change_type == NotifyChangeType.DEVICE_UPDATE:
- # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
- old_device = self.device_tasks[device_id].device
- new_device = self.device_service.get_device(device_id)
+ if device_id in self.device_tasks:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
self.restart_device_thread(device_id)
diff --git a/algo/algo_runner_manager.py b/algo/algo_runner_manager.py
index d7d8015..9109d75 100644
--- a/algo/algo_runner_manager.py
+++ b/algo/algo_runner_manager.py
@@ -1,7 +1,8 @@
from algo.algo_runner import AlgoRunner
-algo_runner = AlgoRunner()
+# algo_runner = AlgoRunner()
def get_algo_runner():
- return algo_runner
+ pass
+ # return algo_runner
diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py
index b5a2d9e..e222ca7 100644
--- a/algo/device_detection_task.py
+++ b/algo/device_detection_task.py
@@ -9,7 +9,7 @@
from common.device_status_manager import DeviceStatusManager
from common.global_logger import logger
from common.global_thread_pool import GlobalThreadPool
-from common.string_utils import camel_to_snake
+from common.string_utils import camel_to_snake, get_class
from db.database import get_db
from entity.device import Device
from entity.frame_analysis_result import FrameAnalysisResultCreate
@@ -17,11 +17,7 @@
from services.frame_analysis_result_service import FrameAnalysisResultService
-def get_class(module_name, class_name):
- # 动态导入模块
- module = importlib.import_module(module_name)
- # 使用 getattr 从模块中获取类
- return getattr(module, class_name)
+
@dataclass
diff --git a/algo/model_manager.py b/algo/model_manager.py
index 94a058c..e3bc4c9 100644
--- a/algo/model_manager.py
+++ b/algo/model_manager.py
@@ -19,9 +19,9 @@
class ModelManager:
- def __init__(self, db, model_warm_up=5):
- self.db = db
- self.model_service = ModelService(self.db)
+ def __init__(self, model_service: ModelService, model_warm_up=5):
+ # self.db = db
+ self.model_service = model_service
self.models: Dict[int, AlgoModelExec] = {}
self.model_warm_up = model_warm_up
@@ -37,6 +37,8 @@
logger.info('loading models')
self.models = {}
self.query_model_inuse()
+ if not self.models:
+ logger.info("no model in use")
for algo_model_id, algo_model_exec in self.models.items():
self.load_model(algo_model_exec)
diff --git a/algo/scene_runner.py b/algo/scene_runner.py
new file mode 100644
index 0000000..2fb2b2b
--- /dev/null
+++ b/algo/scene_runner.py
@@ -0,0 +1,146 @@
+import concurrent
+import copy
+import uuid
+from typing import Dict
+
+from common.consts import DEVICE_MODE, NotifyChangeType
+from common.global_logger import logger
+from common.global_thread_pool import GlobalThreadPool
+from common.string_utils import get_class, camel_to_snake
+from entity.device import Device
+from scene_handler.base_scene_handler import BaseSceneHandler
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+
+class SceneRunner:
+ def __init__(self,
+ device_service: DeviceService,
+ scene_service: SceneService,
+ relation_service: DeviceSceneRelationService,
+ tcp_manager: TcpManager,
+ main_loop):
+ self.device_service = device_service
+ self.scene_service = scene_service
+ self.relation_service = relation_service
+ self.tcp_manager = tcp_manager
+ self.main_loop = main_loop
+
+ self.thread_pool = GlobalThreadPool()
+ self.device_tasks: Dict[int, BaseSceneHandler] = {}
+ self.device_scene_relations = {}
+
+ # 注册设备和模型的变化回调
+ # self.device_service.register_change_callback(self.on_device_change)
+ # self.scene_service.register_change_callback(self.on_scene_change)
+ # self.relation_service.register_change_callback(self.on_relation_change)
+
+ async def start(self):
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.mode == DEVICE_MODE.SCENE]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def start_device_thread(self, device: Device):
+ device = copy.deepcopy(device)
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ logger.info(f'设备 {device.code} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if not device.mode == DEVICE_MODE.SCENE:
+ return
+
+ thread_id = f'device_{device.id}_scene_{uuid.uuid4()}'
+ scene = self.relation_service.get_device_scene(device.id)
+ if scene:
+ self.device_scene_relations[device.id] = scene
+ try:
+ handle_task_name = scene.scene_handle_task
+ handler_cls = get_class(f'scene_handler.{camel_to_snake(handle_task_name)}', handle_task_name)
+ handler_instance = handler_cls(device, thread_id, self.tcp_manager,self.main_loop)
+ future = self.thread_pool.submit_task(handler_instance.run, thread_id=thread_id)
+ self.device_tasks[device.id] = handler_instance
+ logger.info(f'start thread {thread_id}, device info: {device}')
+ except Exception as e:
+ logger.exception(f"设备 {device.code} 场景 {scene.scene_name} 启动异常 : {e}")
+
+ def stop_device_thread(self, device_id):
+ """停止指定设备的检测线程,并确保其成功停止"""
+ if device_id in self.device_tasks:
+ logger.info(f'stop device {device_id} thread')
+ # 获取线程 ID 和 future 对象
+ thread_id = self.device_tasks[device_id].thread_id
+ future = self.thread_pool.get_task_future(thread_id=thread_id)
+
+ if future:
+ if not future.done():
+ # 任务正在运行,调用 stop_detection_task 停止任务
+ self.device_tasks[device_id].stop_task()
+ try:
+ # 设置超时时间等待任务停止(例如10秒)
+ result = future.result(timeout=30)
+ logger.info(f"Task {thread_id} stopped successfully.")
+ except concurrent.futures.TimeoutError as te:
+ logger.error(f"Task {thread_id} did not stop within the timeout.")
+ except Exception as e:
+ logger.exception(f"Task {thread_id} encountered an error while stopping: {e}")
+ finally:
+ # 确保无论任务是否停止,都将其从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.info(f"Task {thread_id} has already stopped.")
+ # 任务已停止,直接从任务列表中移除
+ del self.device_tasks[device_id]
+ else:
+ logger.warning(f"No task found for {thread_id} .")
+ else:
+ logger.warning(f"No task exists for device {device_id} in device_tasks.")
+
+ def restart_device_thread(self, device_id):
+ try:
+ self.stop_device_thread(device_id)
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ except Exception as e:
+ logger.error(e)
+
+ def on_device_change(self, device_id, change_type):
+ logger.info(f"on device change, device {device_id} {change_type}")
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ # 如果只是设备属性更新,而非视频流或模型变化,考虑是否需要直接重启
+ if device_id in self.device_tasks:
+ old_device = self.device_tasks[device_id].device
+ new_device = self.device_service.get_device(device_id)
+
+ self.restart_device_thread(device_id)
+
+ def on_scene_change(self, scene_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.SCENE_UPDATE:
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_scene_relations.items()
+ if relation_info.scene_id == scene_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ # 判断模型绑定关系是否真的发生了变化,避免不必要的重启
+ if change_type == NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/apis/device.py b/apis/device.py
index f786851..85fb10c 100644
--- a/apis/device.py
+++ b/apis/device.py
@@ -3,15 +3,18 @@
from fastapi import APIRouter, Depends, Query
from sqlmodel import Session
-from algo.algo_runner import AlgoRunner
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
from db.database import get_db
from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
from services.device_service import DeviceService
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_service
@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
@@ -45,15 +48,13 @@
@router.post("/add", response_model=StandardResponse[DeviceInfo])
-def create_device(device_data: DeviceCreate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)):
device = service.create_device(device_data)
return standard_response(data=device)
@router.post("/update", response_model=StandardResponse[DeviceInfo])
-def update_device(device_data: DeviceUpdate, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)):
device = service.update_device(device_data)
if not device:
return standard_error_response(data=device_data, message="Device not found")
@@ -61,8 +62,7 @@
@router.delete("/delete", response_model=StandardResponse[int])
-def delete_device(device_id: int, algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.device_service
+def delete_device(device_id: int, service: DeviceService = Depends(get_service)):
device = service.delete_device(device_id)
if not device:
return standard_error_response(data=device_id, message="Device not found")
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
index 9152eca..443176b 100644
--- a/apis/device_model_realtion.py
+++ b/apis/device_model_realtion.py
@@ -8,10 +8,13 @@
from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
from services.device_model_relation_service import DeviceModelRelationService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
+from app_instance import get_app
router = APIRouter()
+app = get_app()
+
+def get_service():
+ return app.state.device_model_relation_service
@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
@@ -35,12 +38,10 @@
@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
def update_by_device(relation_data: List[DeviceModelRelationCreate],
device_id: int = Query(...),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.relation_service
+ service: DeviceModelRelationService = Depends(get_service)):
relations = service.update_relations_by_device(device_id, relation_data)
return standard_response(data=relations)
-
# @router.delete("/delete_by_device", response_model=StandardResponse[int])
# def delete_device(device_id: int, db: Session = Depends(get_db)):
# service = DeviceModelRelationService(db)
diff --git a/apis/model.py b/apis/model.py
index 264e10a..78c064f 100644
--- a/apis/model.py
+++ b/apis/model.py
@@ -4,15 +4,17 @@
from sqlmodel import Session
from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param
+from app_instance import get_app
from db.database import get_db
from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
from services.model_service import ModelService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
+app = get_app()
+def get_service():
+ return app.state.model_service
@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
def get_model_list(
@@ -60,8 +62,7 @@
@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"),
file: UploadFile = File(None, description="模型文件"),
- algo_runner: AlgoRunner = Depends(get_algo_runner)):
- service = algo_runner.model_service
+ service: ModelService = Depends(get_service)):
model_data = AlgoModelUpdate.parse_raw(json_data)
model = service.update_model(model_data, file)
if not model:
diff --git a/apis/scene.py b/apis/scene.py
index b0a4d22..f84b54b 100644
--- a/apis/scene.py
+++ b/apis/scene.py
@@ -9,8 +9,6 @@
from entity.scene import SceneInfo
from services.scene_service import SceneService
-from algo.algo_runner import AlgoRunner
-from algo.algo_runner_manager import get_algo_runner
router = APIRouter()
diff --git a/app_instance.py b/app_instance.py
new file mode 100644
index 0000000..b72926d
--- /dev/null
+++ b/app_instance.py
@@ -0,0 +1,77 @@
+import asyncio
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException
+
+from algo.algo_runner import AlgoRunner
+from algo.scene_runner import SceneRunner
+
+from common.biz_exception import BizExceptionHandlers
+from common.global_logger import logger
+from db.database import get_db
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_scene_relation_service import DeviceSceneRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+from services.scene_service import SceneService
+from tcp.tcp_manager import TcpManager
+
+_app = None # 创建一个私有变量来存储 app 实例
+
+
+def create_app() -> FastAPI:
+ global _app
+ if _app is None:
+ _app = FastAPI()
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+ main_loop = asyncio.get_running_loop()
+
+ with next(get_db()) as db:
+ device_service = DeviceService(db)
+ model_service = ModelService(db)
+ model_relation_service = DeviceModelRelationService(db)
+ scene_service = SceneService(db)
+ scene_relation_service = DeviceSceneRelationService(db)
+
+ app.state.device_service = device_service
+ app.state.model_service = model_service
+ app.state.model_relation_service = model_relation_service
+ app.state.scene_service = scene_service
+ app.state.scene_relation_service = scene_relation_service
+
+ tcp_manager = TcpManager(device_service=device_service)
+ app.state.tcp_manager = tcp_manager
+ await tcp_manager.start()
+
+ algo_runner = AlgoRunner(
+ device_service=device_service,
+ model_service=model_service,
+ relation_service=model_relation_service,
+ )
+ app.state.algo_runner = algo_runner
+ await algo_runner.start()
+
+ scene_runner = SceneRunner(
+ device_service=device_service,
+ scene_service=scene_service,
+ relation_service=scene_relation_service,
+ tcp_manager=tcp_manager,
+ main_loop=main_loop
+ )
+ app.state.scene_runner = scene_runner
+ await scene_runner.start()
+
+ yield # 允许请求处理
+
+ logger.info("Shutting down application...")
+
+ _app.router.lifespan_context = lifespan
+ _app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
+
+ return _app
+
+
+def get_app() -> FastAPI:
+ return _app or create_app()
diff --git a/common/string_utils.py b/common/string_utils.py
index 42d987b..7ff7f03 100644
--- a/common/string_utils.py
+++ b/common/string_utils.py
@@ -1,5 +1,12 @@
+import importlib
import re
+def get_class(module_name, class_name):
+ # 动态导入模块
+ module = importlib.import_module(module_name)
+ # 使用 getattr 从模块中获取类
+ return getattr(module, class_name)
+
def camel_to_snake(name):
# 将大写字母前加上下划线,并将整个字符串转换为小写
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
index b1d8fab..d1acdd4 100644
--- a/db/safe-algo-pro.db
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/main.py b/main.py
index 61cbe9e..8e5e464 100644
--- a/main.py
+++ b/main.py
@@ -1,58 +1,21 @@
-import asyncio
-from contextlib import asynccontextmanager
-from fastapi import FastAPI, HTTPException
-
-from algo.algo_runner_manager import get_algo_runner
-from apis.router import router
import uvicorn
import logging
-from common.biz_exception import BizExceptionHandlers
+from app_instance import get_app
from common.global_logger import logger
+app = get_app()
-app = FastAPI()
-
-# # 初始化 AlgoRunner
-# algo_runner = AlgoRunner()
-#
-#
-# # 使用 FastAPI 的 startup 事件来启动 AlgoRunner
-# @app.on_event("startup")
-# async def startup_event():
-# algo_runner.start()
-
-algo_runner = get_algo_runner()
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- # 应用启动时的初始化
- await algo_runner.start()
- # app.add_event_handler("startup", lambda: asyncio.create_task(algo_runner.start()))
- # 允许请求处理
- yield
- # 应用关闭时的清理逻辑
- logger.info("Shutting down application...")
-
-
-# 包含所有模块的路由
+# 延迟导入 router 并注册路由
+from apis.router import router
app.include_router(router, prefix="/api")
-app.add_exception_handler(HTTPException, BizExceptionHandlers.biz_exception_handler)
-app.router.lifespan_context = lifespan
-
if __name__ == "__main__":
-
# 重定向 uvicorn 的日志
uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.handlers = logger.handlers
uvicorn_logger.setLevel(logging.DEBUG)
- # 重定向 uvicorn 的 access 日志
- # access_logger = logging.getLogger("uvicorn.access")
- # access_logger.handlers = logger.handlers
- # access_logger.setLevel(logging.DEBUG)
uvicorn.run(app, host="0.0.0.0", port=9000, log_config=None)
diff --git a/requirements.txt b/requirements.txt
index 9a7d88c..af16eda 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,12 @@
sqlmodel
openpyxl
python-multipart
-docker
\ No newline at end of file
+docker
+numpy
+ultralytics
+opencv-python
+pydantic
+pandas
+starlette
+uvicorn
+sqlalchemy
\ No newline at end of file
diff --git a/scene_handler/base_scene_handler.py b/scene_handler/base_scene_handler.py
new file mode 100644
index 0000000..5451a16
--- /dev/null
+++ b/scene_handler/base_scene_handler.py
@@ -0,0 +1,29 @@
+from asyncio import Event
+
+from algo.model_manager import AlgoModelExec
+from algo.stream_loader import OpenCVStreamLoad
+from common.device_status_manager import DeviceStatusManager
+from common.global_logger import logger
+from entity.device import Device
+
+from ultralytics import YOLO
+
+from tcp.tcp_manager import TcpManager
+
+
+class BaseSceneHandler:
+
+ def __init__(self, device: Device,
+ thread_id: str,
+ tcp_manager: TcpManager,
+ main_loop ):
+ self.device = device
+ self.thread_id = thread_id
+ self.tcp_manager =tcp_manager
+ self.main_loop = main_loop
+
+ def stop_task(self):
+ pass
+
+ def run(self):
+ pass
\ No newline at end of file
diff --git a/scene_handler/limit_space_scene_handler.py b/scene_handler/limit_space_scene_handler.py
new file mode 100644
index 0000000..68df7df
--- /dev/null
+++ b/scene_handler/limit_space_scene_handler.py
@@ -0,0 +1,65 @@
+import asyncio
+from asyncio import Event
+
+from algo.model_manager import AlgoModelExec
+from algo.stream_loader import OpenCVStreamLoad
+from common.device_status_manager import DeviceStatusManager
+from common.global_logger import logger
+from entity.device import Device
+
+from ultralytics import YOLO
+
+from scene_handler.base_scene_handler import BaseSceneHandler
+from tcp.tcp_manager import TcpManager
+
+
+class LimitSpaceSceneHandler(BaseSceneHandler):
+
+ def __init__(self, device: Device, thread_id: str, tcp_manager: TcpManager, main_loop):
+ super().__init__(device=device, thread_id=thread_id, tcp_manager=tcp_manager, main_loop=main_loop)
+ # self.device = device
+ # self.thread_id = thread_id
+ self.stream_loader = OpenCVStreamLoad(camera_url=device.input_stream_url, camera_code=device.code,
+ device_thread_id=thread_id)
+ self.device_status_manager = DeviceStatusManager()
+
+ self.person_model = YOLO('weights/yolov8s.pt')
+
+ self.__stop_event = Event(loop=main_loop) # 使用 Event 控制线程的运行状态
+
+ def stop_task(self, **kwargs):
+ logger.info(f'stop detection task {self.device.id}, thread_id: {self.thread_id}')
+ self.__stop_event.set()
+ self.stream_loader.stop() # 停止视频流加载的线程
+
+ def send_tcp_message(self, message: bytes, have_response=False):
+ asyncio.run_coroutine_threadsafe(
+ self.tcp_manager.send_message_to_device(device_id=self.device.id,
+ message=message,
+ have_response=have_response),
+ self.main_loop)
+
+ def run(self):
+ while not self.stream_loader.init:
+ if self.__stop_event.is_set():
+ break # 如果触发了停止事件,则退出循环
+ self.stream_loader.init_cap()
+ for frame in self.stream_loader:
+ if self.__stop_event.is_set():
+ break # 如果触发了停止事件,则退出循环
+ # print('frame')
+ if frame is None:
+ continue
+
+ self.device_status_manager.set_status(device_id=self.device.id)
+ results = self.person_model.predict(source=frame, imgsz=640,
+ save_txt=False,
+ save=False,
+ verbose=False, stream=True)
+ result = (list(results))
+ if len(result[0]) > 0:
+ asyncio.run_coroutine_threadsafe(
+ self.tcp_manager.send_message_to_device(device_id=self.device.id,
+ message=b'\xaa\x01\x00\x93\x07\x00\x9B',
+ have_response=False),
+ self.main_loop)
diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py
index bda6a95..a62758e 100644
--- a/tcp/tcp_client_connector.py
+++ b/tcp/tcp_client_connector.py
@@ -139,7 +139,6 @@
# 可以根据需求选择是否接收响应
if have_response:
data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout)
- print(format_bytes(data))
# if not data:
# raise ConnectionResetError("Connection lost or no data received")
self.parse_response(data)
@@ -159,6 +158,6 @@
# asyncio.run(client.connect()) # Run the asynchronous connect method
# 示例数据
- data = b'\x07 \x00\x01\x00\x01\xaa\x01\x00"0\r`'
+ data = b'\x07\x00\x01\x00\x01\xaa\x01\x00"0\r`'
result = parse_gas_data(data)
print(result)
diff --git a/tcp/tcp_manager.py b/tcp/tcp_manager.py
index 35ab60a..f28446d 100644
--- a/tcp/tcp_manager.py
+++ b/tcp/tcp_manager.py
@@ -1,28 +1,24 @@
import asyncio
from typing import List, Dict
-from algo.algo_runner_manager import get_algo_runner
+
from common.consts import DEVICE_TYPE, NotifyChangeType
-from db.database import get_db
from entity.device import Device
from services.device_service import DeviceService
-
from common.global_logger import logger
from tcp.tcp_client_connector import TcpClientConnector
class TcpManager:
- def __init__(self):
+ def __init__(self, device_service: DeviceService):
self.devices: List[Device] = []
self.connector_map: Dict[int, TcpClientConnector] = {}
- # 从全局algo_runner中获取device_service,确保能收到设备更新通知
- algo_runner = get_algo_runner()
- self.device_service = algo_runner.device_service
+ self.device_service = device_service
# 注册设备和模型的变化回调
- self.device_service.register_change_callback(self.on_device_change)
+ # self.device_service.register_change_callback(self.on_device_change)
async def load_and_connect_devices(self):
"""从数据库加载设备并连接所有设备"""
@@ -55,7 +51,7 @@
await self.start_device_connect(device)
async def on_device_change(self, device_id, change_type):
- """设备变化时的回调处理 todo 线程处理待优化"""
+ """设备变化时的回调处理"""
if change_type == NotifyChangeType.DEVICE_CREATE:
# 新增设备,加载新设备并连接
new_device = self.device_service.get_device(device_id)
@@ -72,6 +68,16 @@
"""断开设备连接"""
await connector.disconnect()
+ async def send_message_to_device(self, device_id, message: bytes, have_response):
+ if device_id not in self.connector_map:
+ device = self.device_service.get_device(device_id)
+ await self.start_device_connect(device)
+ connector = self.connector_map[device_id]
+ if connector:
+ await connector.send_message(message, have_response=have_response)
+
+
+
if __name__ == '__main__':
tcp_manager = TcpManager()
asyncio.run(tcp_manager.start())