diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/entity/device.py b/entity/device.py
new file mode 100644
index 0000000..f09b52d
--- /dev/null
+++ b/entity/device.py
@@ -0,0 +1,33 @@
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceBase(SQLModel):
+ name: str
+ code: str
+ type: Optional[str] = None
+ ip: str
+ input_stream_url: Optional[str] = None
+ output_stream_url: Optional[str] = None
+ image_save_interval: Optional[int] = None
+
+
+class Device(DeviceBase, TimestampMixin, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class DeviceCreate(DeviceBase):
+ pass
+
+
+class DeviceUpdate(DeviceBase):
+ id: int
+ name: Optional[str] = None
+ code: Optional[str] = None
+ ip: Optional[str] = None
+
+
+class DeviceInfo(DeviceBase, TimestampMixin):
+ id: int
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/entity/device.py b/entity/device.py
new file mode 100644
index 0000000..f09b52d
--- /dev/null
+++ b/entity/device.py
@@ -0,0 +1,33 @@
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceBase(SQLModel):
+ name: str
+ code: str
+ type: Optional[str] = None
+ ip: str
+ input_stream_url: Optional[str] = None
+ output_stream_url: Optional[str] = None
+ image_save_interval: Optional[int] = None
+
+
+class Device(DeviceBase, TimestampMixin, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class DeviceCreate(DeviceBase):
+ pass
+
+
+class DeviceUpdate(DeviceBase):
+ id: int
+ name: Optional[str] = None
+ code: Optional[str] = None
+ ip: Optional[str] = None
+
+
+class DeviceInfo(DeviceBase, TimestampMixin):
+ id: int
diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py
new file mode 100644
index 0000000..2e58a39
--- /dev/null
+++ b/entity/device_model_relation.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceModelRelationBase(SQLModel):
+ algo_model_id: int
+ is_use: int
+ threshold: Optional[float] = None
+ alarm_interval: Optional[int] = None
+ alarm_type: Optional[str] = None
+
+
+class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True):
+ __tablename__ = 'device_model_relation'
+ id: Optional[int] = Field(default=None, primary_key=True)
+ device_id: int
+
+class DeviceModelRelationCreate(DeviceModelRelationBase):
+ pass
+
+
+class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin):
+ id: int
+ device_id: int
+ algo_model_name: str
+ algo_model_version: str
+ algo_model_path: Optional[str] = None # 可选字段
+ algo_model_remark: Optional[str] = None
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/entity/device.py b/entity/device.py
new file mode 100644
index 0000000..f09b52d
--- /dev/null
+++ b/entity/device.py
@@ -0,0 +1,33 @@
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceBase(SQLModel):
+ name: str
+ code: str
+ type: Optional[str] = None
+ ip: str
+ input_stream_url: Optional[str] = None
+ output_stream_url: Optional[str] = None
+ image_save_interval: Optional[int] = None
+
+
+class Device(DeviceBase, TimestampMixin, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class DeviceCreate(DeviceBase):
+ pass
+
+
+class DeviceUpdate(DeviceBase):
+ id: int
+ name: Optional[str] = None
+ code: Optional[str] = None
+ ip: Optional[str] = None
+
+
+class DeviceInfo(DeviceBase, TimestampMixin):
+ id: int
diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py
new file mode 100644
index 0000000..2e58a39
--- /dev/null
+++ b/entity/device_model_relation.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceModelRelationBase(SQLModel):
+ algo_model_id: int
+ is_use: int
+ threshold: Optional[float] = None
+ alarm_interval: Optional[int] = None
+ alarm_type: Optional[str] = None
+
+
+class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True):
+ __tablename__ = 'device_model_relation'
+ id: Optional[int] = Field(default=None, primary_key=True)
+ device_id: int
+
+class DeviceModelRelationCreate(DeviceModelRelationBase):
+ pass
+
+
+class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin):
+ id: int
+ device_id: int
+ algo_model_name: str
+ algo_model_version: str
+ algo_model_path: Optional[str] = None # 可选字段
+ algo_model_remark: Optional[str] = None
diff --git a/entity/model.py b/entity/model.py
new file mode 100644
index 0000000..db9133e
--- /dev/null
+++ b/entity/model.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class AlgoModelBase(SQLModel):
+ name: str
+ version: str
+ path: str
+ remark: Optional[str] = None
+
+
+class AlgoModel(AlgoModelBase, TimestampMixin, table=True):
+ __tablename__ = "algo_model" # 显式指定表名
+
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class AlgoModelCreate(AlgoModelBase):
+ pass
+
+
+class AlgoModelUpdate(AlgoModelBase):
+ id: int
+ name: Optional[str] = None
+ version: Optional[str] = None
+ path: Optional[str] = None
+
+
+class AlgoModelInfo(AlgoModelBase, TimestampMixin):
+ id: int
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/entity/device.py b/entity/device.py
new file mode 100644
index 0000000..f09b52d
--- /dev/null
+++ b/entity/device.py
@@ -0,0 +1,33 @@
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceBase(SQLModel):
+ name: str
+ code: str
+ type: Optional[str] = None
+ ip: str
+ input_stream_url: Optional[str] = None
+ output_stream_url: Optional[str] = None
+ image_save_interval: Optional[int] = None
+
+
+class Device(DeviceBase, TimestampMixin, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class DeviceCreate(DeviceBase):
+ pass
+
+
+class DeviceUpdate(DeviceBase):
+ id: int
+ name: Optional[str] = None
+ code: Optional[str] = None
+ ip: Optional[str] = None
+
+
+class DeviceInfo(DeviceBase, TimestampMixin):
+ id: int
diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py
new file mode 100644
index 0000000..2e58a39
--- /dev/null
+++ b/entity/device_model_relation.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceModelRelationBase(SQLModel):
+ algo_model_id: int
+ is_use: int
+ threshold: Optional[float] = None
+ alarm_interval: Optional[int] = None
+ alarm_type: Optional[str] = None
+
+
+class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True):
+ __tablename__ = 'device_model_relation'
+ id: Optional[int] = Field(default=None, primary_key=True)
+ device_id: int
+
+class DeviceModelRelationCreate(DeviceModelRelationBase):
+ pass
+
+
+class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin):
+ id: int
+ device_id: int
+ algo_model_name: str
+ algo_model_version: str
+ algo_model_path: Optional[str] = None # 可选字段
+ algo_model_remark: Optional[str] = None
diff --git a/entity/model.py b/entity/model.py
new file mode 100644
index 0000000..db9133e
--- /dev/null
+++ b/entity/model.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class AlgoModelBase(SQLModel):
+ name: str
+ version: str
+ path: str
+ remark: Optional[str] = None
+
+
+class AlgoModel(AlgoModelBase, TimestampMixin, table=True):
+ __tablename__ = "algo_model" # 显式指定表名
+
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class AlgoModelCreate(AlgoModelBase):
+ pass
+
+
+class AlgoModelUpdate(AlgoModelBase):
+ id: int
+ name: Optional[str] = None
+ version: Optional[str] = None
+ path: Optional[str] = None
+
+
+class AlgoModelInfo(AlgoModelBase, TimestampMixin):
+ id: int
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..75c9b52
--- /dev/null
+++ b/main.py
@@ -0,0 +1,15 @@
+from fastapi import FastAPI, HTTPException
+from apis.router import router
+import uvicorn
+
+from common.biz_exception import BizExceptionHandlers
+
+app = FastAPI()
+
+# 包含所有模块的路由
+app.include_router(router, prefix="/api")
+app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler)
+
+if __name__ == "__main__":
+
+ uvicorn.run(app, host="0.0.0.0", port=9000)
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/entity/device.py b/entity/device.py
new file mode 100644
index 0000000..f09b52d
--- /dev/null
+++ b/entity/device.py
@@ -0,0 +1,33 @@
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceBase(SQLModel):
+ name: str
+ code: str
+ type: Optional[str] = None
+ ip: str
+ input_stream_url: Optional[str] = None
+ output_stream_url: Optional[str] = None
+ image_save_interval: Optional[int] = None
+
+
+class Device(DeviceBase, TimestampMixin, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class DeviceCreate(DeviceBase):
+ pass
+
+
+class DeviceUpdate(DeviceBase):
+ id: int
+ name: Optional[str] = None
+ code: Optional[str] = None
+ ip: Optional[str] = None
+
+
+class DeviceInfo(DeviceBase, TimestampMixin):
+ id: int
diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py
new file mode 100644
index 0000000..2e58a39
--- /dev/null
+++ b/entity/device_model_relation.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceModelRelationBase(SQLModel):
+ algo_model_id: int
+ is_use: int
+ threshold: Optional[float] = None
+ alarm_interval: Optional[int] = None
+ alarm_type: Optional[str] = None
+
+
+class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True):
+ __tablename__ = 'device_model_relation'
+ id: Optional[int] = Field(default=None, primary_key=True)
+ device_id: int
+
+class DeviceModelRelationCreate(DeviceModelRelationBase):
+ pass
+
+
+class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin):
+ id: int
+ device_id: int
+ algo_model_name: str
+ algo_model_version: str
+ algo_model_path: Optional[str] = None # 可选字段
+ algo_model_remark: Optional[str] = None
diff --git a/entity/model.py b/entity/model.py
new file mode 100644
index 0000000..db9133e
--- /dev/null
+++ b/entity/model.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class AlgoModelBase(SQLModel):
+ name: str
+ version: str
+ path: str
+ remark: Optional[str] = None
+
+
+class AlgoModel(AlgoModelBase, TimestampMixin, table=True):
+ __tablename__ = "algo_model" # 显式指定表名
+
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class AlgoModelCreate(AlgoModelBase):
+ pass
+
+
+class AlgoModelUpdate(AlgoModelBase):
+ id: int
+ name: Optional[str] = None
+ version: Optional[str] = None
+ path: Optional[str] = None
+
+
+class AlgoModelInfo(AlgoModelBase, TimestampMixin):
+ id: int
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..75c9b52
--- /dev/null
+++ b/main.py
@@ -0,0 +1,15 @@
+from fastapi import FastAPI, HTTPException
+from apis.router import router
+import uvicorn
+
+from common.biz_exception import BizExceptionHandlers
+
+app = FastAPI()
+
+# 包含所有模块的路由
+app.include_router(router, prefix="/api")
+app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler)
+
+if __name__ == "__main__":
+
+ uvicorn.run(app, host="0.0.0.0", port=9000)
\ No newline at end of file
diff --git a/services/__init__.py b/services/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/__init__.py
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/entity/device.py b/entity/device.py
new file mode 100644
index 0000000..f09b52d
--- /dev/null
+++ b/entity/device.py
@@ -0,0 +1,33 @@
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceBase(SQLModel):
+ name: str
+ code: str
+ type: Optional[str] = None
+ ip: str
+ input_stream_url: Optional[str] = None
+ output_stream_url: Optional[str] = None
+ image_save_interval: Optional[int] = None
+
+
+class Device(DeviceBase, TimestampMixin, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class DeviceCreate(DeviceBase):
+ pass
+
+
+class DeviceUpdate(DeviceBase):
+ id: int
+ name: Optional[str] = None
+ code: Optional[str] = None
+ ip: Optional[str] = None
+
+
+class DeviceInfo(DeviceBase, TimestampMixin):
+ id: int
diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py
new file mode 100644
index 0000000..2e58a39
--- /dev/null
+++ b/entity/device_model_relation.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceModelRelationBase(SQLModel):
+ algo_model_id: int
+ is_use: int
+ threshold: Optional[float] = None
+ alarm_interval: Optional[int] = None
+ alarm_type: Optional[str] = None
+
+
+class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True):
+ __tablename__ = 'device_model_relation'
+ id: Optional[int] = Field(default=None, primary_key=True)
+ device_id: int
+
+class DeviceModelRelationCreate(DeviceModelRelationBase):
+ pass
+
+
+class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin):
+ id: int
+ device_id: int
+ algo_model_name: str
+ algo_model_version: str
+ algo_model_path: Optional[str] = None # 可选字段
+ algo_model_remark: Optional[str] = None
diff --git a/entity/model.py b/entity/model.py
new file mode 100644
index 0000000..db9133e
--- /dev/null
+++ b/entity/model.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class AlgoModelBase(SQLModel):
+ name: str
+ version: str
+ path: str
+ remark: Optional[str] = None
+
+
+class AlgoModel(AlgoModelBase, TimestampMixin, table=True):
+ __tablename__ = "algo_model" # 显式指定表名
+
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class AlgoModelCreate(AlgoModelBase):
+ pass
+
+
+class AlgoModelUpdate(AlgoModelBase):
+ id: int
+ name: Optional[str] = None
+ version: Optional[str] = None
+ path: Optional[str] = None
+
+
+class AlgoModelInfo(AlgoModelBase, TimestampMixin):
+ id: int
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..75c9b52
--- /dev/null
+++ b/main.py
@@ -0,0 +1,15 @@
+from fastapi import FastAPI, HTTPException
+from apis.router import router
+import uvicorn
+
+from common.biz_exception import BizExceptionHandlers
+
+app = FastAPI()
+
+# 包含所有模块的路由
+app.include_router(router, prefix="/api")
+app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler)
+
+if __name__ == "__main__":
+
+ uvicorn.run(app, host="0.0.0.0", port=9000)
\ No newline at end of file
diff --git a/services/__init__.py b/services/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/__init__.py
diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py
new file mode 100644
index 0000000..0f8d9a7
--- /dev/null
+++ b/services/device_model_relation_service.py
@@ -0,0 +1,85 @@
+from datetime import datetime
+from typing import List
+
+from sqlmodel import Session, select, delete
+
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate
+from entity.model import AlgoModel
+
+
+class DeviceModelRelationService:
+ def __init__(self, db: Session):
+ self.db = db
+ self.relation_change_callbacks = [] # 用于存储回调函数
+ self.thread_pool = GlobalThreadPool()
+
+ def register_change_callback(self, callback):
+ """注册设备变化回调函数"""
+ self.relation_change_callbacks.append(callback)
+
+ def notify_change(self, device_id, change_type):
+ """当设备发生变化时,调用回调通知变化"""
+ for callback in self.relation_change_callbacks:
+ self.thread_pool.executor.submit(callback, device_id, change_type)
+
+ def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]:
+ statement = (
+ select(DeviceModelRelation, AlgoModel)
+ .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id)
+ .where(DeviceModelRelation.device_id == device_id)
+ )
+
+ # 执行联表查询
+ result = self.db.exec(statement).all()
+
+ models_info = [
+ DeviceModelRelationInfo(
+ id=relation.id,
+ device_id=relation.id,
+ algo_model_id=relation.algo_model_id,
+ is_use=relation.is_use,
+ threshold=relation.threshold,
+ alarm_interval=relation.alarm_interval,
+ alarm_type=relation.alarm_type,
+ algo_model_name=model.name,
+ algo_model_version=model.version,
+ algo_model_path=model.path,
+ algo_model_remark=model.remark,
+ )
+ for relation, model in result
+ ]
+ return models_info
+
+ def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]):
+ new_relations = [
+ DeviceModelRelation(
+ algo_model_id=relation.algo_model_id,
+ is_use=relation.is_use,
+ threshold=relation.threshold,
+ alarm_interval=relation.alarm_interval,
+ alarm_type=relation.alarm_type,
+ device_id=device_id, # 统一赋值 device_id
+ createtime=datetime.now(),
+ updatetime=datetime.now(),
+ )
+ for relation in relations
+ ]
+ self.db.add_all(new_relations)
+ self.db.commit()
+ for relation in new_relations:
+ self.db.refresh(relation)
+ return new_relations
+
+ def delete_relations_by_device(self, device_id: int):
+ statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id)
+ count = self.db.exec(statement)
+ self.db.commit()
+ return count.rowcount
+
+ def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]):
+ self.delete_relations_by_device(device_id)
+ new_relations = self.add_relations_by_device(device_id, relations)
+ self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE)
+ return new_relations
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/entity/device.py b/entity/device.py
new file mode 100644
index 0000000..f09b52d
--- /dev/null
+++ b/entity/device.py
@@ -0,0 +1,33 @@
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceBase(SQLModel):
+ name: str
+ code: str
+ type: Optional[str] = None
+ ip: str
+ input_stream_url: Optional[str] = None
+ output_stream_url: Optional[str] = None
+ image_save_interval: Optional[int] = None
+
+
+class Device(DeviceBase, TimestampMixin, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class DeviceCreate(DeviceBase):
+ pass
+
+
+class DeviceUpdate(DeviceBase):
+ id: int
+ name: Optional[str] = None
+ code: Optional[str] = None
+ ip: Optional[str] = None
+
+
+class DeviceInfo(DeviceBase, TimestampMixin):
+ id: int
diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py
new file mode 100644
index 0000000..2e58a39
--- /dev/null
+++ b/entity/device_model_relation.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceModelRelationBase(SQLModel):
+ algo_model_id: int
+ is_use: int
+ threshold: Optional[float] = None
+ alarm_interval: Optional[int] = None
+ alarm_type: Optional[str] = None
+
+
+class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True):
+ __tablename__ = 'device_model_relation'
+ id: Optional[int] = Field(default=None, primary_key=True)
+ device_id: int
+
+class DeviceModelRelationCreate(DeviceModelRelationBase):
+ pass
+
+
+class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin):
+ id: int
+ device_id: int
+ algo_model_name: str
+ algo_model_version: str
+ algo_model_path: Optional[str] = None # 可选字段
+ algo_model_remark: Optional[str] = None
diff --git a/entity/model.py b/entity/model.py
new file mode 100644
index 0000000..db9133e
--- /dev/null
+++ b/entity/model.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class AlgoModelBase(SQLModel):
+ name: str
+ version: str
+ path: str
+ remark: Optional[str] = None
+
+
+class AlgoModel(AlgoModelBase, TimestampMixin, table=True):
+ __tablename__ = "algo_model" # 显式指定表名
+
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class AlgoModelCreate(AlgoModelBase):
+ pass
+
+
+class AlgoModelUpdate(AlgoModelBase):
+ id: int
+ name: Optional[str] = None
+ version: Optional[str] = None
+ path: Optional[str] = None
+
+
+class AlgoModelInfo(AlgoModelBase, TimestampMixin):
+ id: int
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..75c9b52
--- /dev/null
+++ b/main.py
@@ -0,0 +1,15 @@
+from fastapi import FastAPI, HTTPException
+from apis.router import router
+import uvicorn
+
+from common.biz_exception import BizExceptionHandlers
+
+app = FastAPI()
+
+# 包含所有模块的路由
+app.include_router(router, prefix="/api")
+app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler)
+
+if __name__ == "__main__":
+
+ uvicorn.run(app, host="0.0.0.0", port=9000)
\ No newline at end of file
diff --git a/services/__init__.py b/services/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/__init__.py
diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py
new file mode 100644
index 0000000..0f8d9a7
--- /dev/null
+++ b/services/device_model_relation_service.py
@@ -0,0 +1,85 @@
+from datetime import datetime
+from typing import List
+
+from sqlmodel import Session, select, delete
+
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate
+from entity.model import AlgoModel
+
+
+class DeviceModelRelationService:
+ def __init__(self, db: Session):
+ self.db = db
+ self.relation_change_callbacks = [] # 用于存储回调函数
+ self.thread_pool = GlobalThreadPool()
+
+ def register_change_callback(self, callback):
+ """注册设备变化回调函数"""
+ self.relation_change_callbacks.append(callback)
+
+ def notify_change(self, device_id, change_type):
+ """当设备发生变化时,调用回调通知变化"""
+ for callback in self.relation_change_callbacks:
+ self.thread_pool.executor.submit(callback, device_id, change_type)
+
+ def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]:
+ statement = (
+ select(DeviceModelRelation, AlgoModel)
+ .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id)
+ .where(DeviceModelRelation.device_id == device_id)
+ )
+
+ # 执行联表查询
+ result = self.db.exec(statement).all()
+
+ models_info = [
+ DeviceModelRelationInfo(
+ id=relation.id,
+ device_id=relation.id,
+ algo_model_id=relation.algo_model_id,
+ is_use=relation.is_use,
+ threshold=relation.threshold,
+ alarm_interval=relation.alarm_interval,
+ alarm_type=relation.alarm_type,
+ algo_model_name=model.name,
+ algo_model_version=model.version,
+ algo_model_path=model.path,
+ algo_model_remark=model.remark,
+ )
+ for relation, model in result
+ ]
+ return models_info
+
+ def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]):
+ new_relations = [
+ DeviceModelRelation(
+ algo_model_id=relation.algo_model_id,
+ is_use=relation.is_use,
+ threshold=relation.threshold,
+ alarm_interval=relation.alarm_interval,
+ alarm_type=relation.alarm_type,
+ device_id=device_id, # 统一赋值 device_id
+ createtime=datetime.now(),
+ updatetime=datetime.now(),
+ )
+ for relation in relations
+ ]
+ self.db.add_all(new_relations)
+ self.db.commit()
+ for relation in new_relations:
+ self.db.refresh(relation)
+ return new_relations
+
+ def delete_relations_by_device(self, device_id: int):
+ statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id)
+ count = self.db.exec(statement)
+ self.db.commit()
+ return count.rowcount
+
+ def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]):
+ self.delete_relations_by_device(device_id)
+ new_relations = self.add_relations_by_device(device_id, relations)
+ self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE)
+ return new_relations
diff --git a/services/device_service.py b/services/device_service.py
new file mode 100644
index 0000000..0284c4e
--- /dev/null
+++ b/services/device_service.py
@@ -0,0 +1,108 @@
+from datetime import datetime
+from typing import Sequence, Optional, Tuple
+
+from sqlalchemy import func
+from sqlmodel import Session, select
+
+from common.global_thread_pool import GlobalThreadPool
+from common.consts import NotifyChangeType
+from entity.device import Device, DeviceCreate, DeviceUpdate
+from services.device_model_relation_service import DeviceModelRelationService
+
+
+class DeviceService:
+ def __init__(self, db: Session):
+ self.db = db
+ self.device_change_callbacks = [] # 用于存储回调函数
+ self.thread_pool = GlobalThreadPool()
+
+ def register_change_callback(self, callback):
+ """注册设备变化回调函数"""
+ self.device_change_callbacks.append(callback)
+
+ def notify_change(self, device_id, change_type):
+ """当设备发生变化时,调用回调通知变化"""
+ for callback in self.device_change_callbacks:
+ self.thread_pool.executor.submit(callback, device_id, change_type)
+
+ def get_device_list(self,
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ ) -> Sequence[Device]:
+ statement = self.device_query(code, device_type, name)
+ results = self.db.exec(statement)
+ return results.all()
+
+ def get_device_page(self,
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = 0,
+ limit: int = 10
+ ) -> Tuple[Sequence[Device], int]:
+ statement = self.device_query(code, device_type, name)
+
+ # 查询总记录数
+ total_statement = select(func.count()).select_from(statement.subquery())
+ total = self.db.exec(total_statement).one()
+
+ # 添加分页限制
+ statement = statement.offset(offset).limit(limit)
+
+ # 执行查询并返回结果
+ results = self.db.exec(statement)
+ return results.all(), total # 返回分页数据和总数
+
+ def device_query(self, code, device_type, name):
+ # 构建查询语句
+ statement = select(Device)
+ if name:
+ statement = statement.where(Device.name.like(f"%{name}%"))
+ if code:
+ statement = statement.where(Device.code.like(f"%{code}%"))
+ if device_type:
+ statement = statement.where(Device.type == device_type)
+ return statement
+
+ def create_device(self, device_data: DeviceCreate):
+ device = Device.model_validate(device_data)
+ device.create_time = datetime.now()
+ device.update_time = datetime.now()
+ self.db.add(device)
+ self.db.commit()
+ self.db.refresh(device)
+ self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE)
+ return device
+
+ def update_device(self, device_data: DeviceUpdate):
+ device = self.db.get(Device, device_data.id)
+ if not device:
+ return None
+
+ update_data = device_data.dict(exclude_unset=True)
+ for key, value in update_data.items():
+ setattr(device, key, value)
+
+ device.update_time = datetime.now()
+ self.db.add(device)
+ self.db.commit()
+ self.db.refresh(device)
+ self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE)
+ return device
+
+ def delete_device(self, device_id: int):
+ device = self.db.get(Device, device_id)
+ if not device:
+ return None
+
+ self.db.delete(device)
+ self.db.commit()
+ self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE)
+
+ relation_service = DeviceModelRelationService(self.db)
+ relation_service.delete_relations_by_device(device_id)
+ return device
+
+ def get_device(self, device_id: int):
+ return self.db.get(Device, device_id)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/entity/device.py b/entity/device.py
new file mode 100644
index 0000000..f09b52d
--- /dev/null
+++ b/entity/device.py
@@ -0,0 +1,33 @@
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceBase(SQLModel):
+ name: str
+ code: str
+ type: Optional[str] = None
+ ip: str
+ input_stream_url: Optional[str] = None
+ output_stream_url: Optional[str] = None
+ image_save_interval: Optional[int] = None
+
+
+class Device(DeviceBase, TimestampMixin, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class DeviceCreate(DeviceBase):
+ pass
+
+
+class DeviceUpdate(DeviceBase):
+ id: int
+ name: Optional[str] = None
+ code: Optional[str] = None
+ ip: Optional[str] = None
+
+
+class DeviceInfo(DeviceBase, TimestampMixin):
+ id: int
diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py
new file mode 100644
index 0000000..2e58a39
--- /dev/null
+++ b/entity/device_model_relation.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceModelRelationBase(SQLModel):
+ algo_model_id: int
+ is_use: int
+ threshold: Optional[float] = None
+ alarm_interval: Optional[int] = None
+ alarm_type: Optional[str] = None
+
+
+class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True):
+ __tablename__ = 'device_model_relation'
+ id: Optional[int] = Field(default=None, primary_key=True)
+ device_id: int
+
+class DeviceModelRelationCreate(DeviceModelRelationBase):
+ pass
+
+
+class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin):
+ id: int
+ device_id: int
+ algo_model_name: str
+ algo_model_version: str
+ algo_model_path: Optional[str] = None # 可选字段
+ algo_model_remark: Optional[str] = None
diff --git a/entity/model.py b/entity/model.py
new file mode 100644
index 0000000..db9133e
--- /dev/null
+++ b/entity/model.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class AlgoModelBase(SQLModel):
+ name: str
+ version: str
+ path: str
+ remark: Optional[str] = None
+
+
+class AlgoModel(AlgoModelBase, TimestampMixin, table=True):
+ __tablename__ = "algo_model" # 显式指定表名
+
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class AlgoModelCreate(AlgoModelBase):
+ pass
+
+
+class AlgoModelUpdate(AlgoModelBase):
+ id: int
+ name: Optional[str] = None
+ version: Optional[str] = None
+ path: Optional[str] = None
+
+
+class AlgoModelInfo(AlgoModelBase, TimestampMixin):
+ id: int
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..75c9b52
--- /dev/null
+++ b/main.py
@@ -0,0 +1,15 @@
+from fastapi import FastAPI, HTTPException
+from apis.router import router
+import uvicorn
+
+from common.biz_exception import BizExceptionHandlers
+
+app = FastAPI()
+
+# 包含所有模块的路由
+app.include_router(router, prefix="/api")
+app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler)
+
+if __name__ == "__main__":
+
+ uvicorn.run(app, host="0.0.0.0", port=9000)
\ No newline at end of file
diff --git a/services/__init__.py b/services/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/__init__.py
diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py
new file mode 100644
index 0000000..0f8d9a7
--- /dev/null
+++ b/services/device_model_relation_service.py
@@ -0,0 +1,85 @@
+from datetime import datetime
+from typing import List
+
+from sqlmodel import Session, select, delete
+
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate
+from entity.model import AlgoModel
+
+
+class DeviceModelRelationService:
+ def __init__(self, db: Session):
+ self.db = db
+ self.relation_change_callbacks = [] # 用于存储回调函数
+ self.thread_pool = GlobalThreadPool()
+
+ def register_change_callback(self, callback):
+ """注册设备变化回调函数"""
+ self.relation_change_callbacks.append(callback)
+
+ def notify_change(self, device_id, change_type):
+ """当设备发生变化时,调用回调通知变化"""
+ for callback in self.relation_change_callbacks:
+ self.thread_pool.executor.submit(callback, device_id, change_type)
+
+ def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]:
+ statement = (
+ select(DeviceModelRelation, AlgoModel)
+ .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id)
+ .where(DeviceModelRelation.device_id == device_id)
+ )
+
+ # 执行联表查询
+ result = self.db.exec(statement).all()
+
+ models_info = [
+ DeviceModelRelationInfo(
+ id=relation.id,
+ device_id=relation.id,
+ algo_model_id=relation.algo_model_id,
+ is_use=relation.is_use,
+ threshold=relation.threshold,
+ alarm_interval=relation.alarm_interval,
+ alarm_type=relation.alarm_type,
+ algo_model_name=model.name,
+ algo_model_version=model.version,
+ algo_model_path=model.path,
+ algo_model_remark=model.remark,
+ )
+ for relation, model in result
+ ]
+ return models_info
+
+ def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]):
+ new_relations = [
+ DeviceModelRelation(
+ algo_model_id=relation.algo_model_id,
+ is_use=relation.is_use,
+ threshold=relation.threshold,
+ alarm_interval=relation.alarm_interval,
+ alarm_type=relation.alarm_type,
+ device_id=device_id, # 统一赋值 device_id
+ createtime=datetime.now(),
+ updatetime=datetime.now(),
+ )
+ for relation in relations
+ ]
+ self.db.add_all(new_relations)
+ self.db.commit()
+ for relation in new_relations:
+ self.db.refresh(relation)
+ return new_relations
+
+ def delete_relations_by_device(self, device_id: int):
+ statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id)
+ count = self.db.exec(statement)
+ self.db.commit()
+ return count.rowcount
+
+ def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]):
+ self.delete_relations_by_device(device_id)
+ new_relations = self.add_relations_by_device(device_id, relations)
+ self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE)
+ return new_relations
diff --git a/services/device_service.py b/services/device_service.py
new file mode 100644
index 0000000..0284c4e
--- /dev/null
+++ b/services/device_service.py
@@ -0,0 +1,108 @@
+from datetime import datetime
+from typing import Sequence, Optional, Tuple
+
+from sqlalchemy import func
+from sqlmodel import Session, select
+
+from common.global_thread_pool import GlobalThreadPool
+from common.consts import NotifyChangeType
+from entity.device import Device, DeviceCreate, DeviceUpdate
+from services.device_model_relation_service import DeviceModelRelationService
+
+
+class DeviceService:
+ def __init__(self, db: Session):
+ self.db = db
+ self.device_change_callbacks = [] # 用于存储回调函数
+ self.thread_pool = GlobalThreadPool()
+
+ def register_change_callback(self, callback):
+ """注册设备变化回调函数"""
+ self.device_change_callbacks.append(callback)
+
+ def notify_change(self, device_id, change_type):
+ """当设备发生变化时,调用回调通知变化"""
+ for callback in self.device_change_callbacks:
+ self.thread_pool.executor.submit(callback, device_id, change_type)
+
+ def get_device_list(self,
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ ) -> Sequence[Device]:
+ statement = self.device_query(code, device_type, name)
+ results = self.db.exec(statement)
+ return results.all()
+
+ def get_device_page(self,
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = 0,
+ limit: int = 10
+ ) -> Tuple[Sequence[Device], int]:
+ statement = self.device_query(code, device_type, name)
+
+ # 查询总记录数
+ total_statement = select(func.count()).select_from(statement.subquery())
+ total = self.db.exec(total_statement).one()
+
+ # 添加分页限制
+ statement = statement.offset(offset).limit(limit)
+
+ # 执行查询并返回结果
+ results = self.db.exec(statement)
+ return results.all(), total # 返回分页数据和总数
+
+ def device_query(self, code, device_type, name):
+ # 构建查询语句
+ statement = select(Device)
+ if name:
+ statement = statement.where(Device.name.like(f"%{name}%"))
+ if code:
+ statement = statement.where(Device.code.like(f"%{code}%"))
+ if device_type:
+ statement = statement.where(Device.type == device_type)
+ return statement
+
+ def create_device(self, device_data: DeviceCreate):
+ device = Device.model_validate(device_data)
+ device.create_time = datetime.now()
+ device.update_time = datetime.now()
+ self.db.add(device)
+ self.db.commit()
+ self.db.refresh(device)
+ self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE)
+ return device
+
+ def update_device(self, device_data: DeviceUpdate):
+ device = self.db.get(Device, device_data.id)
+ if not device:
+ return None
+
+ update_data = device_data.dict(exclude_unset=True)
+ for key, value in update_data.items():
+ setattr(device, key, value)
+
+ device.update_time = datetime.now()
+ self.db.add(device)
+ self.db.commit()
+ self.db.refresh(device)
+ self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE)
+ return device
+
+ def delete_device(self, device_id: int):
+ device = self.db.get(Device, device_id)
+ if not device:
+ return None
+
+ self.db.delete(device)
+ self.db.commit()
+ self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE)
+
+ relation_service = DeviceModelRelationService(self.db)
+ relation_service.delete_relations_by_device(device_id)
+ return device
+
+ def get_device(self, device_id: int):
+ return self.db.get(Device, device_id)
diff --git a/services/model_service.py b/services/model_service.py
new file mode 100644
index 0000000..568895e
--- /dev/null
+++ b/services/model_service.py
@@ -0,0 +1,120 @@
+from datetime import datetime
+from typing import List, Sequence, Optional, Tuple
+
+from sqlalchemy import func
+from sqlmodel import Session, select
+
+from common.biz_exception import BizException
+from entity.device_model_relation import DeviceModelRelation
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate
+from common.global_thread_pool import GlobalThreadPool
+from common.consts import NotifyChangeType
+
+
+class ModelService:
+ def __init__(self, db: Session):
+ self.db = db
+ self.model_change_callbacks = [] # 用于存储回调函数
+ self.thread_pool = GlobalThreadPool()
+
+ def register_change_callback(self, callback):
+ """注册设备变化回调函数"""
+ self.model_change_callbacks.append(callback)
+
+ def notify_change(self, algo_model_id, change_type):
+ """当设备发生变化时,调用回调通知变化"""
+ for callback in self.model_change_callbacks:
+ self.thread_pool.executor.submit(callback, algo_model_id, change_type)
+
+ def get_model_list(self,
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ ) -> Sequence[AlgoModel]:
+ statement = self.model_query(name, remark)
+ results = self.db.exec(statement)
+ return results.all()
+
+ def get_model_page(self,
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = 0,
+ limit: int = 10
+ ) -> Tuple[Sequence[AlgoModel], int]:
+ statement = self.model_query(name, remark)
+
+ # 查询总记录数
+ total_statement = select(func.count()).select_from(statement.subquery())
+ total = self.db.exec(total_statement).one()
+
+ # 添加分页限制
+ statement = statement.offset(offset).limit(limit)
+
+ # 执行查询并返回结果
+ results = self.db.exec(statement)
+ return results.all(), total # 返回分页数据和总数
+
+ def model_query(self, name, remark):
+ # 构建查询语句
+ statement = select(AlgoModel)
+ if name:
+ statement = statement.where(AlgoModel.name.like(f"%{name}%"))
+ if remark:
+ statement = statement.where(AlgoModel.remark.like(f"%{remark}%"))
+ return statement
+
+ def create_model(self, model_data: AlgoModelCreate):
+ model = AlgoModel.model_validate(model_data)
+ model.create_time = datetime.now()
+ model.update_time = datetime.now()
+ self.db.add(model)
+ self.db.commit()
+ self.db.refresh(model)
+ return model
+
+ def update_model(self, model_data: AlgoModelUpdate):
+ model = self.db.get(AlgoModel, model_data.id)
+ if not model:
+ return None
+
+ update_data = model_data.dict(exclude_unset=True)
+ for key, value in update_data.items():
+ setattr(model, key, value)
+
+ model.update_time = datetime.now()
+ self.db.add(model)
+ self.db.commit()
+ self.db.refresh(model)
+ self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE)
+
+ return model
+
+ def delete_model(self, model_id: int):
+ model = self.db.get(AlgoModel, model_id)
+ if not model:
+ return None
+ # 查询 device_model_relation 中是否存在启用的绑定关系
+ statement = (
+ select(DeviceModelRelation)
+ .where(DeviceModelRelation.algo_model_id == model_id)
+ .where(DeviceModelRelation.is_use == 1)
+ )
+ relation_in_use = self.db.exec(statement).first()
+
+ # 如果存在启用的绑定关系,提示无法删除
+ if relation_in_use:
+ raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除")
+
+ self.db.delete(model)
+ self.db.commit()
+ return model
+
+ def get_models_in_use(self) -> Sequence[AlgoModel]:
+ """获取所有在 device_model_relation 表里有启用绑定关系的模型信息"""
+ statement = (
+ select(AlgoModel)
+ .join(DeviceModelRelation, DeviceModelRelation.algo_model_id == AlgoModel.id)
+ .where(DeviceModelRelation.is_use == 1)
+ .group_by(AlgoModel.id)
+ )
+ results = self.db.exec(statement).all()
+ return results
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e530a86
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+test.py
+/logs/*
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..901fe74
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b68d82f
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..3deddad
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/safe-algo-pro.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/algo/__init__.py b/algo/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/algo/__init__.py
diff --git a/algo/algo_runner.py b/algo/algo_runner.py
new file mode 100644
index 0000000..ee07eaf
--- /dev/null
+++ b/algo/algo_runner.py
@@ -0,0 +1,117 @@
+from algo.model_manager import ModelManager
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from db.database import get_db
+from entity.device import Device
+from services.device_model_relation_service import DeviceModelRelationService
+from services.device_service import DeviceService
+from services.model_service import ModelService
+
+
+class AlgoRunner:
+ def __init__(self):
+ self.db = get_db()
+ self.device_service = DeviceService(self.db)
+ self.model_service = ModelService(self.db)
+ self.relation_service = DeviceModelRelationService(self.db)
+ self.model_manager = ModelManager(self.db)
+ self.thread_pool = GlobalThreadPool()
+ self.threads = {} # 用于存储设备对应的线程
+ self.device_model_relations = {}
+
+ # 注册设备和模型的变化回调
+ self.device_service.register_change_callback(self.on_device_change)
+ self.model_service.register_change_callback(self.on_model_change)
+ self.relation_service.relation_change_callbacks(self.on_relation_change)
+
+ def start(self):
+ """在程序启动时调用,读取设备和模型,启动检测线程"""
+ self.model_manager.load_models()
+ self.load_and_start_devices()
+
+ def load_and_start_devices(self):
+ """从数据库读取设备列表并启动线程"""
+ devices = self.device_service.get_device_list()
+ devices = [device for device in devices if device.input_stream_url]
+ for device in devices:
+ self.start_device_thread(device)
+
+ def run_detection(self, device: Device):
+ """todo 设备目标检测主逻辑"""
+ print(f"设备 {device.device_id} 的检测线程启动")
+ video_stream = self.device_service.get_video_stream(device.device_id)
+
+ while True:
+ try:
+ frame = video_stream.read_frame()
+ for model in self.model_manager.models:
+ # 调用目标检测模型
+ result = self.model_service.run_inference(model, frame)
+ # 在此处进行后处理,例如标记视频、生成告警等
+ self.process_result(result, device.device_id)
+ except Exception as e:
+ print(f"设备 {device.device_id} 处理时出错: {e}")
+ break
+ video_stream.close()
+ print(f"设备 {device.device_id} 的检测线程结束")
+
+ def start_device_thread(self, device: Device):
+ """为单个设备启动检测线程"""
+
+ if not device.input_stream_url:
+ print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程')
+ return
+
+ if device.device_id in self.threads:
+ print(f"设备 {device.device_id} 已经在运行中")
+ return
+
+ # 获取设备绑定的模型列表
+ relations = self.relation_service.get_device_models(device.device_id)
+ self.device_model_relations[device.device_id] = relations
+
+ if not relations:
+ print(f"设备 {device.code} 未绑定模型,无法启动检测")
+ return
+
+ future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}')
+
+ def stop_device_thread(self, device_id):
+ """todo 控制标志位 停止指定设备的检测线程"""
+ if device_id in self.threads:
+ self.threads[device_id].cancel() # 尝试取消线程
+ del self.threads[device_id]
+ print(f"设备 {device_id} 的检测线程已停止")
+
+ def restart_device_thread(self, device_id):
+ self.stop_device_thread(device_id)
+ # todo 需要留个关闭时间吗
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+
+ def on_device_change(self, device_id, change_type):
+ """设备变化回调"""
+ if change_type == NotifyChangeType.DEVICE_CREATE:
+ device = self.device_service.get_device(device_id)
+ self.start_device_thread(device)
+ elif change_type == NotifyChangeType.DEVICE_DELETE:
+ self.stop_device_thread(device_id)
+ elif change_type == NotifyChangeType.DEVICE_UPDATE:
+ self.restart_device_thread(device_id)
+
+ def on_model_change(self, model_id, change_type):
+ """模型变化回调"""
+ if change_type == NotifyChangeType.MODEL_UPDATE:
+ self.model_manager.reload_model(model_id)
+
+ devices_to_reload = [
+ device_id for device_id, relation_info in self.device_model_relations.items()
+ if relation_info.model_id == model_id
+ ]
+ for device_id in devices_to_reload:
+ self.restart_device_thread(device_id)
+
+ def on_relation_change(self, device_id, change_type):
+ """设备模型关系变化回调"""
+ if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE:
+ self.restart_device_thread(device_id)
diff --git a/algo/model_manager.py b/algo/model_manager.py
new file mode 100644
index 0000000..3ab0709
--- /dev/null
+++ b/algo/model_manager.py
@@ -0,0 +1,64 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from ultralytics import YOLO
+
+from entity.model import AlgoModel
+from services.model_service import ModelService
+from common.global_logger import logger
+
+
+@dataclass
+class AlgoModelExec:
+ algo_model_id: int
+ algo_model_info: AlgoModel
+ algo_model_exec: Optional[object] = None
+ input_size: int = 640
+
+
+class ModelManager:
+ def __init__(self, db, model_warm_up=5):
+ self.db = db
+ self.model_service = ModelService(self.db)
+ self.models = {}
+ self.model_warm_up = model_warm_up
+
+ self.load_models()
+
+ def query_model_inuse(self):
+ algo_model_list = list(self.model_service.get_models_in_use())
+ for algo_model in algo_model_list:
+ self.models[algo_model.id] = AlgoModelExec(
+ algo_model_id=algo_model.id,
+ algo_model_info=algo_model
+ )
+
+ def load_models(self):
+ self.models = {}
+ self.query_model_inuse()
+ for algo_model_id, algo_model_exec in self.models.items():
+ self.load_model(algo_model_exec)
+
+ def load_model(self, algo_model_exec: AlgoModelExec):
+ model_name = algo_model_exec.algo_model_info.name
+ model_path = algo_model_exec.algo_model_info.path
+ logger.info(f'loading model {model_name}: {model_path}')
+
+ algo_model_exec.algo_model_exec = YOLO(model_path, task='detect')
+ if self.model_warm_up > 0:
+ logger.info(f'warming up model {model_name}')
+ imgsz = algo_model_exec.input_size
+ if not isinstance(imgsz, list):
+ imgsz = [imgsz, imgsz]
+ dummy_input = np.zeros((imgsz[0], imgsz[1], 3))
+ for i in range(self.model_warm_up):
+ algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False)
+ logger.info(f'warm up model {model_name} success!')
+ logger.info(f'load model {model_name} success!')
+
+ def reload_model(self, model_id):
+ algo_model_exec = self.models.get(model_id)
+ if algo_model_exec:
+ algo_model_exec.algo_model_exec = None
+ self.load_model(algo_model_exec)
diff --git a/apis/__init__.py b/apis/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/apis/__init__.py
diff --git a/apis/base.py b/apis/base.py
new file mode 100644
index 0000000..a7c4ede
--- /dev/null
+++ b/apis/base.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel, Field
+from typing import Optional, Generic, TypeVar, Any, List
+
+# 定义一个泛型类型变量 T
+T = TypeVar("T")
+
+
+# 使用泛型 T 定义 data 的类型
+class StandardResponse(BaseModel, Generic[T]):
+ code: int = 200
+ data: Optional[T] = None
+ message: str = "请求成功"
+ success: bool = True
+
+
+class PageResponse(BaseModel, Generic[T]):
+ total: int
+ items: List[T] = Field(default_factory=list)
+
+
+def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True):
+ return StandardResponse(data=data, code=code, message=message, success=success)
+
+
+def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False):
+ return StandardResponse(data=data, code=code, message=message, success=success)
diff --git a/apis/device.py b/apis/device.py
new file mode 100644
index 0000000..7eeaa78
--- /dev/null
+++ b/apis/device.py
@@ -0,0 +1,65 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from db.database import get_db
+from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo
+from services.device_service import DeviceService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[DeviceInfo]])
+def get_device_list(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ devices = list(service.get_device_list(name, code, device_type))
+ return standard_response(data=devices)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[Device]])
+def get_device_page(
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = DeviceService(db)
+
+ # 获取分页后的设备列表和总数
+ devices, total = service.get_device_page(name, code, device_type, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=devices)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[DeviceInfo])
+def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.create_device(device_data)
+ return standard_response(data=device)
+
+
+@router.post("/update", response_model=StandardResponse[DeviceInfo])
+def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.update_device(device_data)
+ if not device:
+ return standard_error_response(data=device_data, message="Device not found")
+ return standard_response(data=device)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_device(device_id: int, db: Session = Depends(get_db)):
+ service = DeviceService(db)
+ device = service.delete_device(device_id)
+ if not device:
+ return standard_error_response(data=device_id, message="Device not found")
+ return standard_response(data=device_id)
diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py
new file mode 100644
index 0000000..7108571
--- /dev/null
+++ b/apis/device_model_realtion.py
@@ -0,0 +1,45 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, Query
+from sqlmodel import Session
+
+from apis.base import StandardResponse, standard_response
+from db.database import get_db
+from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation
+from services.device_model_relation_service import DeviceModelRelationService
+
+router = APIRouter()
+
+
+@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]])
+def list_by_device(
+ device_id: int,
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ models = list(service.get_device_models(device_id))
+ return standard_response(data=models)
+
+
+# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+# def add_by_device(relation_data: List[DeviceModelRelationCreate],
+# device_id: int = Query(...),
+# db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# relations = service.add_relations_by_device(device_id, relation_data)
+# return standard_response(data=relations)
+
+
+@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]])
+def update_by_device(relation_data: List[DeviceModelRelationCreate],
+ device_id: int = Query(...),
+ db: Session = Depends(get_db)):
+ service = DeviceModelRelationService(db)
+ relations = service.update_relations_by_device(device_id, relation_data)
+ return standard_response(data=relations)
+
+
+# @router.delete("/delete_by_device", response_model=StandardResponse[int])
+# def delete_device(device_id: int, db: Session = Depends(get_db)):
+# service = DeviceModelRelationService(db)
+# count = service.delete_relations_by_device(device_id)
+# return standard_response(data=count)
diff --git a/apis/model.py b/apis/model.py
new file mode 100644
index 0000000..8adcc14
--- /dev/null
+++ b/apis/model.py
@@ -0,0 +1,64 @@
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlmodel import Session
+
+from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response
+from common.biz_exception import BizException
+from db.database import get_db
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo
+from services.model_service import ModelService
+
+router = APIRouter()
+
+
+@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]])
+def get_model_list(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+ models = list(service.get_model_list(name, remark))
+ return standard_response(data=models)
+
+
+@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]])
+def get_model_page(
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = Query(0, ge=0), # 从第几条开始
+ limit: int = Query(10, ge=1), # 每页显示多少条记录
+ db: Session = Depends(get_db)):
+ service = ModelService(db)
+
+ # 获取分页后的设备列表和总数
+ models, total = service.get_model_page(name, remark, offset, limit)
+
+ return standard_response(
+ data=PageResponse(total=total, items=models)
+ )
+
+
+@router.post("/add", response_model=StandardResponse[AlgoModelInfo])
+def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.create_model(model_data)
+ return standard_response(data=model)
+
+
+@router.post("/update", response_model=StandardResponse[AlgoModelInfo])
+def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.update_model(model_data)
+ if not model:
+ return standard_error_response(data=model_data, message="Model not found")
+ return standard_response(data=model)
+
+
+@router.delete("/delete", response_model=StandardResponse[int])
+def delete_model(model_id: int, db: Session = Depends(get_db)):
+ service = ModelService(db)
+ model = service.delete_model(model_id)
+ if not model:
+ return standard_error_response(data=model_id, message="Model not found")
+ return standard_response(data=model_id)
diff --git a/apis/router.py b/apis/router.py
new file mode 100644
index 0000000..22b5ed6
--- /dev/null
+++ b/apis/router.py
@@ -0,0 +1,15 @@
+# api/router.py
+from fastapi import APIRouter
+from .device import router as devices_router
+from .model import router as models_router
+from .device_model_realtion import router as device_model_relation_router
+
+
+# 创建一个全局的 router
+router = APIRouter()
+
+# 将各个模块的 router 注册到全局 router 中
+router.include_router(devices_router, prefix="/device", tags=["Devices"])
+router.include_router(models_router, prefix="/model", tags=["Models"])
+router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"])
+
diff --git a/common/__init__.py b/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/common/__init__.py
diff --git a/common/biz_exception.py b/common/biz_exception.py
new file mode 100644
index 0000000..946a794
--- /dev/null
+++ b/common/biz_exception.py
@@ -0,0 +1,14 @@
+from fastapi import Request, HTTPException
+
+from apis.base import standard_error_response
+
+
+class BizException(HTTPException):
+ def __init__(self, message: str = "请求异常", status_code: int = 500):
+ super().__init__(status_code=status_code, detail=message)
+
+
+class BizExceptionHandlers:
+ @staticmethod
+ async def biz_exception_handler(request: Request, exc: HTTPException):
+ standard_error_response(code=exc.status_code, message=exc.detail)
diff --git a/common/consts.py b/common/consts.py
new file mode 100644
index 0000000..bba41e4
--- /dev/null
+++ b/common/consts.py
@@ -0,0 +1,16 @@
+class Constants:
+ def __setattr__(self, name, value):
+ raise AttributeError(f"Cannot modify constant {name}")
+
+
+class NotifyChangeType(Constants):
+ # 定义常量成员
+ DEVICE_CREATE = "device_create"
+ DEVICE_UPDATE = "device_update"
+ DEVICE_DELETE = "device_delete"
+ MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效
+ MODEL_UPDATE = "model_update"
+ MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除
+ DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create"
+ DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧??
+ DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete"
diff --git a/common/global_logger.py b/common/global_logger.py
new file mode 100644
index 0000000..1d65173
--- /dev/null
+++ b/common/global_logger.py
@@ -0,0 +1,30 @@
+# logger.py
+import logging.handlers
+import os
+
+# 确保日志目录存在
+log_dir = '../logs'
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+# 实例化并导出全局日志记录器
+logger = logging.getLogger("casic_safe_logger")
+logger.setLevel(logging.DEBUG) # 设置日志级别
+
+# 创建一个TimedRotatingFileHandler
+handler = logging.handlers.TimedRotatingFileHandler(
+ os.path.join(log_dir, 'app.log'), # 日志文件名
+ when='midnight', # 每天午夜滚动
+ interval=1 # 滚动间隔为1天
+)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+handler.setFormatter(formatter)
+
+# 将handler添加到日志器
+logger.addHandler(handler)
+
+# 创建控制台处理器
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py
new file mode 100644
index 0000000..21d7fd5
--- /dev/null
+++ b/common/global_thread_pool.py
@@ -0,0 +1,60 @@
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+import threading
+
+
+def generate_thread_id():
+ """生成唯一的线程 ID"""
+ return str(uuid.uuid4())
+
+
+class GlobalThreadPool:
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls, *args, **kwargs):
+ with cls._lock:
+ if cls._instance is None:
+ # 第一次创建实例时调用父类的 __new__ 来创建实例
+ cls._instance = super(GlobalThreadPool, cls).__new__(cls)
+ # 在此进行一次性的初始化,比如线程池的创建
+ max_workers = kwargs.get('max_workers', 10)
+ cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers)
+ cls._instance.task_map = {} # 初始化任务映射
+ return cls._instance
+
+ def submit_task(self, fn, *args, thread_id=None, **kwargs):
+ """提交任务到线程池,并记录线程 ID"""
+ if thread_id is None:
+ thread_id = generate_thread_id()
+ if self.check_task_is_running(thread_id):
+ raise ValueError(f"线程 ID {thread_id} 已存在")
+ future = self.executor.submit(fn, *args, **kwargs)
+ self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射
+ return thread_id
+
+ def check_task_is_running(self, thread_id):
+ future = self.task_map.get(thread_id)
+ if future:
+ if future.running():
+ return True
+ else:
+ del self.task_map[thread_id]
+ return False
+ else:
+ return False
+
+ def stop_task(self, thread_id):
+ """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务"""
+ future = self.task_map.get(thread_id)
+ if future:
+ future.cancel() # 尝试取消任务
+ print(f"任务 {thread_id} 已取消")
+ del self.task_map[thread_id] # 从任务映射中删除
+ else:
+ print(f"未找到线程 ID {thread_id}")
+
+ def shutdown(self, wait=True):
+ """关闭线程池"""
+ self.executor.shutdown(wait=wait)
+ GlobalThreadPool._instance = None
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/db/__init__.py
diff --git a/db/database.py b/db/database.py
new file mode 100644
index 0000000..4f44d2d
--- /dev/null
+++ b/db/database.py
@@ -0,0 +1,22 @@
+from sqlmodel import SQLModel, create_engine, Session
+from contextlib import contextmanager
+
+sqlite_file_name = "./db/safe-algo-pro.db"
+sqlite_url = f"sqlite:///{sqlite_file_name}"
+
+connect_args = {"check_same_thread": False}
+engine = create_engine(sqlite_url, connect_args=connect_args)
+
+
+# 初始化数据库表
+def init_db():
+ SQLModel.metadata.create_all(engine)
+
+
+# 数据库会话管理
+def get_db():
+ session = Session(engine)
+ try:
+ yield session
+ finally:
+ session.close()
diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db
new file mode 100644
index 0000000..86c8d64
--- /dev/null
+++ b/db/safe-algo-pro.db
Binary files differ
diff --git a/entity/__init__.py b/entity/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/entity/__init__.py
diff --git a/entity/base.py b/entity/base.py
new file mode 100644
index 0000000..6aa862c
--- /dev/null
+++ b/entity/base.py
@@ -0,0 +1,13 @@
+from sqlmodel import SQLModel, Field
+from typing import Optional
+from datetime import datetime
+
+
+class TimestampMixin(SQLModel):
+ create_time: Optional[datetime] = Field(default_factory=datetime.now)
+ update_time: Optional[datetime] = Field(default_factory=datetime.now)
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S')
+ }
diff --git a/entity/device.py b/entity/device.py
new file mode 100644
index 0000000..f09b52d
--- /dev/null
+++ b/entity/device.py
@@ -0,0 +1,33 @@
+from typing import Optional
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceBase(SQLModel):
+ name: str
+ code: str
+ type: Optional[str] = None
+ ip: str
+ input_stream_url: Optional[str] = None
+ output_stream_url: Optional[str] = None
+ image_save_interval: Optional[int] = None
+
+
+class Device(DeviceBase, TimestampMixin, table=True):
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class DeviceCreate(DeviceBase):
+ pass
+
+
+class DeviceUpdate(DeviceBase):
+ id: int
+ name: Optional[str] = None
+ code: Optional[str] = None
+ ip: Optional[str] = None
+
+
+class DeviceInfo(DeviceBase, TimestampMixin):
+ id: int
diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py
new file mode 100644
index 0000000..2e58a39
--- /dev/null
+++ b/entity/device_model_relation.py
@@ -0,0 +1,31 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class DeviceModelRelationBase(SQLModel):
+ algo_model_id: int
+ is_use: int
+ threshold: Optional[float] = None
+ alarm_interval: Optional[int] = None
+ alarm_type: Optional[str] = None
+
+
+class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True):
+ __tablename__ = 'device_model_relation'
+ id: Optional[int] = Field(default=None, primary_key=True)
+ device_id: int
+
+class DeviceModelRelationCreate(DeviceModelRelationBase):
+ pass
+
+
+class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin):
+ id: int
+ device_id: int
+ algo_model_name: str
+ algo_model_version: str
+ algo_model_path: Optional[str] = None # 可选字段
+ algo_model_remark: Optional[str] = None
diff --git a/entity/model.py b/entity/model.py
new file mode 100644
index 0000000..db9133e
--- /dev/null
+++ b/entity/model.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+from sqlmodel import SQLModel, Field
+
+from entity.base import TimestampMixin
+
+
+class AlgoModelBase(SQLModel):
+ name: str
+ version: str
+ path: str
+ remark: Optional[str] = None
+
+
+class AlgoModel(AlgoModelBase, TimestampMixin, table=True):
+ __tablename__ = "algo_model" # 显式指定表名
+
+ id: Optional[int] = Field(default=None, primary_key=True)
+
+
+class AlgoModelCreate(AlgoModelBase):
+ pass
+
+
+class AlgoModelUpdate(AlgoModelBase):
+ id: int
+ name: Optional[str] = None
+ version: Optional[str] = None
+ path: Optional[str] = None
+
+
+class AlgoModelInfo(AlgoModelBase, TimestampMixin):
+ id: int
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..75c9b52
--- /dev/null
+++ b/main.py
@@ -0,0 +1,15 @@
+from fastapi import FastAPI, HTTPException
+from apis.router import router
+import uvicorn
+
+from common.biz_exception import BizExceptionHandlers
+
+app = FastAPI()
+
+# 包含所有模块的路由
+app.include_router(router, prefix="/api")
+app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler)
+
+if __name__ == "__main__":
+
+ uvicorn.run(app, host="0.0.0.0", port=9000)
\ No newline at end of file
diff --git a/services/__init__.py b/services/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/services/__init__.py
diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py
new file mode 100644
index 0000000..0f8d9a7
--- /dev/null
+++ b/services/device_model_relation_service.py
@@ -0,0 +1,85 @@
+from datetime import datetime
+from typing import List
+
+from sqlmodel import Session, select, delete
+
+from common.consts import NotifyChangeType
+from common.global_thread_pool import GlobalThreadPool
+from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate
+from entity.model import AlgoModel
+
+
+class DeviceModelRelationService:
+ def __init__(self, db: Session):
+ self.db = db
+ self.relation_change_callbacks = [] # 用于存储回调函数
+ self.thread_pool = GlobalThreadPool()
+
+ def register_change_callback(self, callback):
+ """注册设备变化回调函数"""
+ self.relation_change_callbacks.append(callback)
+
+ def notify_change(self, device_id, change_type):
+ """当设备发生变化时,调用回调通知变化"""
+ for callback in self.relation_change_callbacks:
+ self.thread_pool.executor.submit(callback, device_id, change_type)
+
+ def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]:
+ statement = (
+ select(DeviceModelRelation, AlgoModel)
+ .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id)
+ .where(DeviceModelRelation.device_id == device_id)
+ )
+
+ # 执行联表查询
+ result = self.db.exec(statement).all()
+
+ models_info = [
+ DeviceModelRelationInfo(
+ id=relation.id,
+ device_id=relation.id,
+ algo_model_id=relation.algo_model_id,
+ is_use=relation.is_use,
+ threshold=relation.threshold,
+ alarm_interval=relation.alarm_interval,
+ alarm_type=relation.alarm_type,
+ algo_model_name=model.name,
+ algo_model_version=model.version,
+ algo_model_path=model.path,
+ algo_model_remark=model.remark,
+ )
+ for relation, model in result
+ ]
+ return models_info
+
+ def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]):
+ new_relations = [
+ DeviceModelRelation(
+ algo_model_id=relation.algo_model_id,
+ is_use=relation.is_use,
+ threshold=relation.threshold,
+ alarm_interval=relation.alarm_interval,
+ alarm_type=relation.alarm_type,
+ device_id=device_id, # 统一赋值 device_id
+ createtime=datetime.now(),
+ updatetime=datetime.now(),
+ )
+ for relation in relations
+ ]
+ self.db.add_all(new_relations)
+ self.db.commit()
+ for relation in new_relations:
+ self.db.refresh(relation)
+ return new_relations
+
+ def delete_relations_by_device(self, device_id: int):
+ statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id)
+ count = self.db.exec(statement)
+ self.db.commit()
+ return count.rowcount
+
+ def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]):
+ self.delete_relations_by_device(device_id)
+ new_relations = self.add_relations_by_device(device_id, relations)
+ self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE)
+ return new_relations
diff --git a/services/device_service.py b/services/device_service.py
new file mode 100644
index 0000000..0284c4e
--- /dev/null
+++ b/services/device_service.py
@@ -0,0 +1,108 @@
+from datetime import datetime
+from typing import Sequence, Optional, Tuple
+
+from sqlalchemy import func
+from sqlmodel import Session, select
+
+from common.global_thread_pool import GlobalThreadPool
+from common.consts import NotifyChangeType
+from entity.device import Device, DeviceCreate, DeviceUpdate
+from services.device_model_relation_service import DeviceModelRelationService
+
+
+class DeviceService:
+ def __init__(self, db: Session):
+ self.db = db
+ self.device_change_callbacks = [] # 用于存储回调函数
+ self.thread_pool = GlobalThreadPool()
+
+ def register_change_callback(self, callback):
+ """注册设备变化回调函数"""
+ self.device_change_callbacks.append(callback)
+
+ def notify_change(self, device_id, change_type):
+ """当设备发生变化时,调用回调通知变化"""
+ for callback in self.device_change_callbacks:
+ self.thread_pool.executor.submit(callback, device_id, change_type)
+
+ def get_device_list(self,
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ ) -> Sequence[Device]:
+ statement = self.device_query(code, device_type, name)
+ results = self.db.exec(statement)
+ return results.all()
+
+ def get_device_page(self,
+ name: Optional[str] = None,
+ code: Optional[str] = None,
+ device_type: Optional[str] = None,
+ offset: int = 0,
+ limit: int = 10
+ ) -> Tuple[Sequence[Device], int]:
+ statement = self.device_query(code, device_type, name)
+
+ # 查询总记录数
+ total_statement = select(func.count()).select_from(statement.subquery())
+ total = self.db.exec(total_statement).one()
+
+ # 添加分页限制
+ statement = statement.offset(offset).limit(limit)
+
+ # 执行查询并返回结果
+ results = self.db.exec(statement)
+ return results.all(), total # 返回分页数据和总数
+
+ def device_query(self, code, device_type, name):
+ # 构建查询语句
+ statement = select(Device)
+ if name:
+ statement = statement.where(Device.name.like(f"%{name}%"))
+ if code:
+ statement = statement.where(Device.code.like(f"%{code}%"))
+ if device_type:
+ statement = statement.where(Device.type == device_type)
+ return statement
+
+ def create_device(self, device_data: DeviceCreate):
+ device = Device.model_validate(device_data)
+ device.create_time = datetime.now()
+ device.update_time = datetime.now()
+ self.db.add(device)
+ self.db.commit()
+ self.db.refresh(device)
+ self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE)
+ return device
+
+ def update_device(self, device_data: DeviceUpdate):
+ device = self.db.get(Device, device_data.id)
+ if not device:
+ return None
+
+ update_data = device_data.dict(exclude_unset=True)
+ for key, value in update_data.items():
+ setattr(device, key, value)
+
+ device.update_time = datetime.now()
+ self.db.add(device)
+ self.db.commit()
+ self.db.refresh(device)
+ self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE)
+ return device
+
+ def delete_device(self, device_id: int):
+ device = self.db.get(Device, device_id)
+ if not device:
+ return None
+
+ self.db.delete(device)
+ self.db.commit()
+ self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE)
+
+ relation_service = DeviceModelRelationService(self.db)
+ relation_service.delete_relations_by_device(device_id)
+ return device
+
+ def get_device(self, device_id: int):
+ return self.db.get(Device, device_id)
diff --git a/services/model_service.py b/services/model_service.py
new file mode 100644
index 0000000..568895e
--- /dev/null
+++ b/services/model_service.py
@@ -0,0 +1,120 @@
+from datetime import datetime
+from typing import List, Sequence, Optional, Tuple
+
+from sqlalchemy import func
+from sqlmodel import Session, select
+
+from common.biz_exception import BizException
+from entity.device_model_relation import DeviceModelRelation
+from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate
+from common.global_thread_pool import GlobalThreadPool
+from common.consts import NotifyChangeType
+
+
+class ModelService:
+ def __init__(self, db: Session):
+ self.db = db
+ self.model_change_callbacks = [] # 用于存储回调函数
+ self.thread_pool = GlobalThreadPool()
+
+ def register_change_callback(self, callback):
+ """注册设备变化回调函数"""
+ self.model_change_callbacks.append(callback)
+
+ def notify_change(self, algo_model_id, change_type):
+ """当设备发生变化时,调用回调通知变化"""
+ for callback in self.model_change_callbacks:
+ self.thread_pool.executor.submit(callback, algo_model_id, change_type)
+
+ def get_model_list(self,
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ ) -> Sequence[AlgoModel]:
+ statement = self.model_query(name, remark)
+ results = self.db.exec(statement)
+ return results.all()
+
+ def get_model_page(self,
+ name: Optional[str] = None,
+ remark: Optional[str] = None,
+ offset: int = 0,
+ limit: int = 10
+ ) -> Tuple[Sequence[AlgoModel], int]:
+ statement = self.model_query(name, remark)
+
+ # 查询总记录数
+ total_statement = select(func.count()).select_from(statement.subquery())
+ total = self.db.exec(total_statement).one()
+
+ # 添加分页限制
+ statement = statement.offset(offset).limit(limit)
+
+ # 执行查询并返回结果
+ results = self.db.exec(statement)
+ return results.all(), total # 返回分页数据和总数
+
+ def model_query(self, name, remark):
+ # 构建查询语句
+ statement = select(AlgoModel)
+ if name:
+ statement = statement.where(AlgoModel.name.like(f"%{name}%"))
+ if remark:
+ statement = statement.where(AlgoModel.remark.like(f"%{remark}%"))
+ return statement
+
+ def create_model(self, model_data: AlgoModelCreate):
+ model = AlgoModel.model_validate(model_data)
+ model.create_time = datetime.now()
+ model.update_time = datetime.now()
+ self.db.add(model)
+ self.db.commit()
+ self.db.refresh(model)
+ return model
+
+ def update_model(self, model_data: AlgoModelUpdate):
+ model = self.db.get(AlgoModel, model_data.id)
+ if not model:
+ return None
+
+ update_data = model_data.dict(exclude_unset=True)
+ for key, value in update_data.items():
+ setattr(model, key, value)
+
+ model.update_time = datetime.now()
+ self.db.add(model)
+ self.db.commit()
+ self.db.refresh(model)
+ self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE)
+
+ return model
+
+ def delete_model(self, model_id: int):
+ model = self.db.get(AlgoModel, model_id)
+ if not model:
+ return None
+ # 查询 device_model_relation 中是否存在启用的绑定关系
+ statement = (
+ select(DeviceModelRelation)
+ .where(DeviceModelRelation.algo_model_id == model_id)
+ .where(DeviceModelRelation.is_use == 1)
+ )
+ relation_in_use = self.db.exec(statement).first()
+
+ # 如果存在启用的绑定关系,提示无法删除
+ if relation_in_use:
+ raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除")
+
+ self.db.delete(model)
+ self.db.commit()
+ return model
+
+ def get_models_in_use(self) -> Sequence[AlgoModel]:
+ """获取所有在 device_model_relation 表里有启用绑定关系的模型信息"""
+ statement = (
+ select(AlgoModel)
+ .join(DeviceModelRelation, DeviceModelRelation.algo_model_id == AlgoModel.id)
+ .where(DeviceModelRelation.is_use == 1)
+ .group_by(AlgoModel.id)
+ )
+ results = self.db.exec(statement).all()
+ return results
diff --git a/tcp/__init__.py b/tcp/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tcp/__init__.py