diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/entity/device.py b/entity/device.py new file mode 100644 index 0000000..f09b52d --- /dev/null +++ b/entity/device.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceBase(SQLModel): + name: str + code: str + type: Optional[str] = None + ip: str + input_stream_url: Optional[str] = None + output_stream_url: Optional[str] = None + image_save_interval: Optional[int] = None + + +class Device(DeviceBase, TimestampMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class DeviceCreate(DeviceBase): + pass + + +class DeviceUpdate(DeviceBase): + id: int + name: Optional[str] = None + code: Optional[str] = None + ip: Optional[str] = None + + +class DeviceInfo(DeviceBase, TimestampMixin): + id: int diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/entity/device.py b/entity/device.py new file mode 100644 index 0000000..f09b52d --- /dev/null +++ b/entity/device.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceBase(SQLModel): + name: str + code: str + type: Optional[str] = None + ip: str + input_stream_url: Optional[str] = None + output_stream_url: Optional[str] = None + image_save_interval: Optional[int] = None + + +class Device(DeviceBase, TimestampMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class DeviceCreate(DeviceBase): + pass + + +class DeviceUpdate(DeviceBase): + id: int + name: Optional[str] = None + code: Optional[str] = None + ip: Optional[str] = None + + +class DeviceInfo(DeviceBase, TimestampMixin): + id: int diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py new file mode 100644 index 0000000..2e58a39 --- /dev/null +++ b/entity/device_model_relation.py @@ -0,0 +1,31 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceModelRelationBase(SQLModel): + algo_model_id: int + is_use: int + threshold: Optional[float] = None + alarm_interval: Optional[int] = None + alarm_type: Optional[str] = None + + +class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True): + __tablename__ = 'device_model_relation' + id: Optional[int] = Field(default=None, primary_key=True) + device_id: int + +class DeviceModelRelationCreate(DeviceModelRelationBase): + pass + + +class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin): + id: int + device_id: int + algo_model_name: str + algo_model_version: str + algo_model_path: Optional[str] = None # 可选字段 + algo_model_remark: Optional[str] = None diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/entity/device.py b/entity/device.py new file mode 100644 index 0000000..f09b52d --- /dev/null +++ b/entity/device.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceBase(SQLModel): + name: str + code: str + type: Optional[str] = None + ip: str + input_stream_url: Optional[str] = None + output_stream_url: Optional[str] = None + image_save_interval: Optional[int] = None + + +class Device(DeviceBase, TimestampMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class DeviceCreate(DeviceBase): + pass + + +class DeviceUpdate(DeviceBase): + id: int + name: Optional[str] = None + code: Optional[str] = None + ip: Optional[str] = None + + +class DeviceInfo(DeviceBase, TimestampMixin): + id: int diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py new file mode 100644 index 0000000..2e58a39 --- /dev/null +++ b/entity/device_model_relation.py @@ -0,0 +1,31 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceModelRelationBase(SQLModel): + algo_model_id: int + is_use: int + threshold: Optional[float] = None + alarm_interval: Optional[int] = None + alarm_type: Optional[str] = None + + +class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True): + __tablename__ = 'device_model_relation' + id: Optional[int] = Field(default=None, primary_key=True) + device_id: int + +class DeviceModelRelationCreate(DeviceModelRelationBase): + pass + + +class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin): + id: int + device_id: int + algo_model_name: str + algo_model_version: str + algo_model_path: Optional[str] = None # 可选字段 + algo_model_remark: Optional[str] = None diff --git a/entity/model.py b/entity/model.py new file mode 100644 index 0000000..db9133e --- /dev/null +++ b/entity/model.py @@ -0,0 +1,33 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class AlgoModelBase(SQLModel): + name: str + version: str + path: str + remark: Optional[str] = None + + +class AlgoModel(AlgoModelBase, TimestampMixin, table=True): + __tablename__ = "algo_model" # 显式指定表名 + + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlgoModelCreate(AlgoModelBase): + pass + + +class AlgoModelUpdate(AlgoModelBase): + id: int + name: Optional[str] = None + version: Optional[str] = None + path: Optional[str] = None + + +class AlgoModelInfo(AlgoModelBase, TimestampMixin): + id: int diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/entity/device.py b/entity/device.py new file mode 100644 index 0000000..f09b52d --- /dev/null +++ b/entity/device.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceBase(SQLModel): + name: str + code: str + type: Optional[str] = None + ip: str + input_stream_url: Optional[str] = None + output_stream_url: Optional[str] = None + image_save_interval: Optional[int] = None + + +class Device(DeviceBase, TimestampMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class DeviceCreate(DeviceBase): + pass + + +class DeviceUpdate(DeviceBase): + id: int + name: Optional[str] = None + code: Optional[str] = None + ip: Optional[str] = None + + +class DeviceInfo(DeviceBase, TimestampMixin): + id: int diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py new file mode 100644 index 0000000..2e58a39 --- /dev/null +++ b/entity/device_model_relation.py @@ -0,0 +1,31 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceModelRelationBase(SQLModel): + algo_model_id: int + is_use: int + threshold: Optional[float] = None + alarm_interval: Optional[int] = None + alarm_type: Optional[str] = None + + +class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True): + __tablename__ = 'device_model_relation' + id: Optional[int] = Field(default=None, primary_key=True) + device_id: int + +class DeviceModelRelationCreate(DeviceModelRelationBase): + pass + + +class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin): + id: int + device_id: int + algo_model_name: str + algo_model_version: str + algo_model_path: Optional[str] = None # 可选字段 + algo_model_remark: Optional[str] = None diff --git a/entity/model.py b/entity/model.py new file mode 100644 index 0000000..db9133e --- /dev/null +++ b/entity/model.py @@ -0,0 +1,33 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class AlgoModelBase(SQLModel): + name: str + version: str + path: str + remark: Optional[str] = None + + +class AlgoModel(AlgoModelBase, TimestampMixin, table=True): + __tablename__ = "algo_model" # 显式指定表名 + + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlgoModelCreate(AlgoModelBase): + pass + + +class AlgoModelUpdate(AlgoModelBase): + id: int + name: Optional[str] = None + version: Optional[str] = None + path: Optional[str] = None + + +class AlgoModelInfo(AlgoModelBase, TimestampMixin): + id: int diff --git a/main.py b/main.py new file mode 100644 index 0000000..75c9b52 --- /dev/null +++ b/main.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI, HTTPException +from apis.router import router +import uvicorn + +from common.biz_exception import BizExceptionHandlers + +app = FastAPI() + +# 包含所有模块的路由 +app.include_router(router, prefix="/api") +app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler) + +if __name__ == "__main__": + + uvicorn.run(app, host="0.0.0.0", port=9000) \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/entity/device.py b/entity/device.py new file mode 100644 index 0000000..f09b52d --- /dev/null +++ b/entity/device.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceBase(SQLModel): + name: str + code: str + type: Optional[str] = None + ip: str + input_stream_url: Optional[str] = None + output_stream_url: Optional[str] = None + image_save_interval: Optional[int] = None + + +class Device(DeviceBase, TimestampMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class DeviceCreate(DeviceBase): + pass + + +class DeviceUpdate(DeviceBase): + id: int + name: Optional[str] = None + code: Optional[str] = None + ip: Optional[str] = None + + +class DeviceInfo(DeviceBase, TimestampMixin): + id: int diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py new file mode 100644 index 0000000..2e58a39 --- /dev/null +++ b/entity/device_model_relation.py @@ -0,0 +1,31 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceModelRelationBase(SQLModel): + algo_model_id: int + is_use: int + threshold: Optional[float] = None + alarm_interval: Optional[int] = None + alarm_type: Optional[str] = None + + +class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True): + __tablename__ = 'device_model_relation' + id: Optional[int] = Field(default=None, primary_key=True) + device_id: int + +class DeviceModelRelationCreate(DeviceModelRelationBase): + pass + + +class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin): + id: int + device_id: int + algo_model_name: str + algo_model_version: str + algo_model_path: Optional[str] = None # 可选字段 + algo_model_remark: Optional[str] = None diff --git a/entity/model.py b/entity/model.py new file mode 100644 index 0000000..db9133e --- /dev/null +++ b/entity/model.py @@ -0,0 +1,33 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class AlgoModelBase(SQLModel): + name: str + version: str + path: str + remark: Optional[str] = None + + +class AlgoModel(AlgoModelBase, TimestampMixin, table=True): + __tablename__ = "algo_model" # 显式指定表名 + + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlgoModelCreate(AlgoModelBase): + pass + + +class AlgoModelUpdate(AlgoModelBase): + id: int + name: Optional[str] = None + version: Optional[str] = None + path: Optional[str] = None + + +class AlgoModelInfo(AlgoModelBase, TimestampMixin): + id: int diff --git a/main.py b/main.py new file mode 100644 index 0000000..75c9b52 --- /dev/null +++ b/main.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI, HTTPException +from apis.router import router +import uvicorn + +from common.biz_exception import BizExceptionHandlers + +app = FastAPI() + +# 包含所有模块的路由 +app.include_router(router, prefix="/api") +app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler) + +if __name__ == "__main__": + + uvicorn.run(app, host="0.0.0.0", port=9000) \ No newline at end of file diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/entity/device.py b/entity/device.py new file mode 100644 index 0000000..f09b52d --- /dev/null +++ b/entity/device.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceBase(SQLModel): + name: str + code: str + type: Optional[str] = None + ip: str + input_stream_url: Optional[str] = None + output_stream_url: Optional[str] = None + image_save_interval: Optional[int] = None + + +class Device(DeviceBase, TimestampMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class DeviceCreate(DeviceBase): + pass + + +class DeviceUpdate(DeviceBase): + id: int + name: Optional[str] = None + code: Optional[str] = None + ip: Optional[str] = None + + +class DeviceInfo(DeviceBase, TimestampMixin): + id: int diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py new file mode 100644 index 0000000..2e58a39 --- /dev/null +++ b/entity/device_model_relation.py @@ -0,0 +1,31 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceModelRelationBase(SQLModel): + algo_model_id: int + is_use: int + threshold: Optional[float] = None + alarm_interval: Optional[int] = None + alarm_type: Optional[str] = None + + +class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True): + __tablename__ = 'device_model_relation' + id: Optional[int] = Field(default=None, primary_key=True) + device_id: int + +class DeviceModelRelationCreate(DeviceModelRelationBase): + pass + + +class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin): + id: int + device_id: int + algo_model_name: str + algo_model_version: str + algo_model_path: Optional[str] = None # 可选字段 + algo_model_remark: Optional[str] = None diff --git a/entity/model.py b/entity/model.py new file mode 100644 index 0000000..db9133e --- /dev/null +++ b/entity/model.py @@ -0,0 +1,33 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class AlgoModelBase(SQLModel): + name: str + version: str + path: str + remark: Optional[str] = None + + +class AlgoModel(AlgoModelBase, TimestampMixin, table=True): + __tablename__ = "algo_model" # 显式指定表名 + + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlgoModelCreate(AlgoModelBase): + pass + + +class AlgoModelUpdate(AlgoModelBase): + id: int + name: Optional[str] = None + version: Optional[str] = None + path: Optional[str] = None + + +class AlgoModelInfo(AlgoModelBase, TimestampMixin): + id: int diff --git a/main.py b/main.py new file mode 100644 index 0000000..75c9b52 --- /dev/null +++ b/main.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI, HTTPException +from apis.router import router +import uvicorn + +from common.biz_exception import BizExceptionHandlers + +app = FastAPI() + +# 包含所有模块的路由 +app.include_router(router, prefix="/api") +app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler) + +if __name__ == "__main__": + + uvicorn.run(app, host="0.0.0.0", port=9000) \ No newline at end of file diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/__init__.py diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py new file mode 100644 index 0000000..0f8d9a7 --- /dev/null +++ b/services/device_model_relation_service.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import List + +from sqlmodel import Session, select, delete + +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate +from entity.model import AlgoModel + + +class DeviceModelRelationService: + def __init__(self, db: Session): + self.db = db + self.relation_change_callbacks = [] # 用于存储回调函数 + self.thread_pool = GlobalThreadPool() + + def register_change_callback(self, callback): + """注册设备变化回调函数""" + self.relation_change_callbacks.append(callback) + + def notify_change(self, device_id, change_type): + """当设备发生变化时,调用回调通知变化""" + for callback in self.relation_change_callbacks: + self.thread_pool.executor.submit(callback, device_id, change_type) + + def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + statement = ( + select(DeviceModelRelation, AlgoModel) + .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) + .where(DeviceModelRelation.device_id == device_id) + ) + + # 执行联表查询 + result = self.db.exec(statement).all() + + models_info = [ + DeviceModelRelationInfo( + id=relation.id, + device_id=relation.id, + algo_model_id=relation.algo_model_id, + is_use=relation.is_use, + threshold=relation.threshold, + alarm_interval=relation.alarm_interval, + alarm_type=relation.alarm_type, + algo_model_name=model.name, + algo_model_version=model.version, + algo_model_path=model.path, + algo_model_remark=model.remark, + ) + for relation, model in result + ] + return models_info + + def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + new_relations = [ + DeviceModelRelation( + algo_model_id=relation.algo_model_id, + is_use=relation.is_use, + threshold=relation.threshold, + alarm_interval=relation.alarm_interval, + alarm_type=relation.alarm_type, + device_id=device_id, # 统一赋值 device_id + createtime=datetime.now(), + updatetime=datetime.now(), + ) + for relation in relations + ] + self.db.add_all(new_relations) + self.db.commit() + for relation in new_relations: + self.db.refresh(relation) + return new_relations + + def delete_relations_by_device(self, device_id: int): + statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) + count = self.db.exec(statement) + self.db.commit() + return count.rowcount + + def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + self.delete_relations_by_device(device_id) + new_relations = self.add_relations_by_device(device_id, relations) + self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) + return new_relations diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/entity/device.py b/entity/device.py new file mode 100644 index 0000000..f09b52d --- /dev/null +++ b/entity/device.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceBase(SQLModel): + name: str + code: str + type: Optional[str] = None + ip: str + input_stream_url: Optional[str] = None + output_stream_url: Optional[str] = None + image_save_interval: Optional[int] = None + + +class Device(DeviceBase, TimestampMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class DeviceCreate(DeviceBase): + pass + + +class DeviceUpdate(DeviceBase): + id: int + name: Optional[str] = None + code: Optional[str] = None + ip: Optional[str] = None + + +class DeviceInfo(DeviceBase, TimestampMixin): + id: int diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py new file mode 100644 index 0000000..2e58a39 --- /dev/null +++ b/entity/device_model_relation.py @@ -0,0 +1,31 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceModelRelationBase(SQLModel): + algo_model_id: int + is_use: int + threshold: Optional[float] = None + alarm_interval: Optional[int] = None + alarm_type: Optional[str] = None + + +class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True): + __tablename__ = 'device_model_relation' + id: Optional[int] = Field(default=None, primary_key=True) + device_id: int + +class DeviceModelRelationCreate(DeviceModelRelationBase): + pass + + +class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin): + id: int + device_id: int + algo_model_name: str + algo_model_version: str + algo_model_path: Optional[str] = None # 可选字段 + algo_model_remark: Optional[str] = None diff --git a/entity/model.py b/entity/model.py new file mode 100644 index 0000000..db9133e --- /dev/null +++ b/entity/model.py @@ -0,0 +1,33 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class AlgoModelBase(SQLModel): + name: str + version: str + path: str + remark: Optional[str] = None + + +class AlgoModel(AlgoModelBase, TimestampMixin, table=True): + __tablename__ = "algo_model" # 显式指定表名 + + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlgoModelCreate(AlgoModelBase): + pass + + +class AlgoModelUpdate(AlgoModelBase): + id: int + name: Optional[str] = None + version: Optional[str] = None + path: Optional[str] = None + + +class AlgoModelInfo(AlgoModelBase, TimestampMixin): + id: int diff --git a/main.py b/main.py new file mode 100644 index 0000000..75c9b52 --- /dev/null +++ b/main.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI, HTTPException +from apis.router import router +import uvicorn + +from common.biz_exception import BizExceptionHandlers + +app = FastAPI() + +# 包含所有模块的路由 +app.include_router(router, prefix="/api") +app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler) + +if __name__ == "__main__": + + uvicorn.run(app, host="0.0.0.0", port=9000) \ No newline at end of file diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/__init__.py diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py new file mode 100644 index 0000000..0f8d9a7 --- /dev/null +++ b/services/device_model_relation_service.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import List + +from sqlmodel import Session, select, delete + +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate +from entity.model import AlgoModel + + +class DeviceModelRelationService: + def __init__(self, db: Session): + self.db = db + self.relation_change_callbacks = [] # 用于存储回调函数 + self.thread_pool = GlobalThreadPool() + + def register_change_callback(self, callback): + """注册设备变化回调函数""" + self.relation_change_callbacks.append(callback) + + def notify_change(self, device_id, change_type): + """当设备发生变化时,调用回调通知变化""" + for callback in self.relation_change_callbacks: + self.thread_pool.executor.submit(callback, device_id, change_type) + + def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + statement = ( + select(DeviceModelRelation, AlgoModel) + .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) + .where(DeviceModelRelation.device_id == device_id) + ) + + # 执行联表查询 + result = self.db.exec(statement).all() + + models_info = [ + DeviceModelRelationInfo( + id=relation.id, + device_id=relation.id, + algo_model_id=relation.algo_model_id, + is_use=relation.is_use, + threshold=relation.threshold, + alarm_interval=relation.alarm_interval, + alarm_type=relation.alarm_type, + algo_model_name=model.name, + algo_model_version=model.version, + algo_model_path=model.path, + algo_model_remark=model.remark, + ) + for relation, model in result + ] + return models_info + + def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + new_relations = [ + DeviceModelRelation( + algo_model_id=relation.algo_model_id, + is_use=relation.is_use, + threshold=relation.threshold, + alarm_interval=relation.alarm_interval, + alarm_type=relation.alarm_type, + device_id=device_id, # 统一赋值 device_id + createtime=datetime.now(), + updatetime=datetime.now(), + ) + for relation in relations + ] + self.db.add_all(new_relations) + self.db.commit() + for relation in new_relations: + self.db.refresh(relation) + return new_relations + + def delete_relations_by_device(self, device_id: int): + statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) + count = self.db.exec(statement) + self.db.commit() + return count.rowcount + + def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + self.delete_relations_by_device(device_id) + new_relations = self.add_relations_by_device(device_id, relations) + self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) + return new_relations diff --git a/services/device_service.py b/services/device_service.py new file mode 100644 index 0000000..0284c4e --- /dev/null +++ b/services/device_service.py @@ -0,0 +1,108 @@ +from datetime import datetime +from typing import Sequence, Optional, Tuple + +from sqlalchemy import func +from sqlmodel import Session, select + +from common.global_thread_pool import GlobalThreadPool +from common.consts import NotifyChangeType +from entity.device import Device, DeviceCreate, DeviceUpdate +from services.device_model_relation_service import DeviceModelRelationService + + +class DeviceService: + def __init__(self, db: Session): + self.db = db + self.device_change_callbacks = [] # 用于存储回调函数 + self.thread_pool = GlobalThreadPool() + + def register_change_callback(self, callback): + """注册设备变化回调函数""" + self.device_change_callbacks.append(callback) + + def notify_change(self, device_id, change_type): + """当设备发生变化时,调用回调通知变化""" + for callback in self.device_change_callbacks: + self.thread_pool.executor.submit(callback, device_id, change_type) + + def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: + statement = self.device_query(code, device_type, name) + results = self.db.exec(statement) + return results.all() + + def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[Device], int]: + statement = self.device_query(code, device_type, name) + + # 查询总记录数 + total_statement = select(func.count()).select_from(statement.subquery()) + total = self.db.exec(total_statement).one() + + # 添加分页限制 + statement = statement.offset(offset).limit(limit) + + # 执行查询并返回结果 + results = self.db.exec(statement) + return results.all(), total # 返回分页数据和总数 + + def device_query(self, code, device_type, name): + # 构建查询语句 + statement = select(Device) + if name: + statement = statement.where(Device.name.like(f"%{name}%")) + if code: + statement = statement.where(Device.code.like(f"%{code}%")) + if device_type: + statement = statement.where(Device.type == device_type) + return statement + + def create_device(self, device_data: DeviceCreate): + device = Device.model_validate(device_data) + device.create_time = datetime.now() + device.update_time = datetime.now() + self.db.add(device) + self.db.commit() + self.db.refresh(device) + self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) + return device + + def update_device(self, device_data: DeviceUpdate): + device = self.db.get(Device, device_data.id) + if not device: + return None + + update_data = device_data.dict(exclude_unset=True) + for key, value in update_data.items(): + setattr(device, key, value) + + device.update_time = datetime.now() + self.db.add(device) + self.db.commit() + self.db.refresh(device) + self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) + return device + + def delete_device(self, device_id: int): + device = self.db.get(Device, device_id) + if not device: + return None + + self.db.delete(device) + self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) + + relation_service = DeviceModelRelationService(self.db) + relation_service.delete_relations_by_device(device_id) + return device + + def get_device(self, device_id: int): + return self.db.get(Device, device_id) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/entity/device.py b/entity/device.py new file mode 100644 index 0000000..f09b52d --- /dev/null +++ b/entity/device.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceBase(SQLModel): + name: str + code: str + type: Optional[str] = None + ip: str + input_stream_url: Optional[str] = None + output_stream_url: Optional[str] = None + image_save_interval: Optional[int] = None + + +class Device(DeviceBase, TimestampMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class DeviceCreate(DeviceBase): + pass + + +class DeviceUpdate(DeviceBase): + id: int + name: Optional[str] = None + code: Optional[str] = None + ip: Optional[str] = None + + +class DeviceInfo(DeviceBase, TimestampMixin): + id: int diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py new file mode 100644 index 0000000..2e58a39 --- /dev/null +++ b/entity/device_model_relation.py @@ -0,0 +1,31 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceModelRelationBase(SQLModel): + algo_model_id: int + is_use: int + threshold: Optional[float] = None + alarm_interval: Optional[int] = None + alarm_type: Optional[str] = None + + +class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True): + __tablename__ = 'device_model_relation' + id: Optional[int] = Field(default=None, primary_key=True) + device_id: int + +class DeviceModelRelationCreate(DeviceModelRelationBase): + pass + + +class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin): + id: int + device_id: int + algo_model_name: str + algo_model_version: str + algo_model_path: Optional[str] = None # 可选字段 + algo_model_remark: Optional[str] = None diff --git a/entity/model.py b/entity/model.py new file mode 100644 index 0000000..db9133e --- /dev/null +++ b/entity/model.py @@ -0,0 +1,33 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class AlgoModelBase(SQLModel): + name: str + version: str + path: str + remark: Optional[str] = None + + +class AlgoModel(AlgoModelBase, TimestampMixin, table=True): + __tablename__ = "algo_model" # 显式指定表名 + + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlgoModelCreate(AlgoModelBase): + pass + + +class AlgoModelUpdate(AlgoModelBase): + id: int + name: Optional[str] = None + version: Optional[str] = None + path: Optional[str] = None + + +class AlgoModelInfo(AlgoModelBase, TimestampMixin): + id: int diff --git a/main.py b/main.py new file mode 100644 index 0000000..75c9b52 --- /dev/null +++ b/main.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI, HTTPException +from apis.router import router +import uvicorn + +from common.biz_exception import BizExceptionHandlers + +app = FastAPI() + +# 包含所有模块的路由 +app.include_router(router, prefix="/api") +app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler) + +if __name__ == "__main__": + + uvicorn.run(app, host="0.0.0.0", port=9000) \ No newline at end of file diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/__init__.py diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py new file mode 100644 index 0000000..0f8d9a7 --- /dev/null +++ b/services/device_model_relation_service.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import List + +from sqlmodel import Session, select, delete + +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate +from entity.model import AlgoModel + + +class DeviceModelRelationService: + def __init__(self, db: Session): + self.db = db + self.relation_change_callbacks = [] # 用于存储回调函数 + self.thread_pool = GlobalThreadPool() + + def register_change_callback(self, callback): + """注册设备变化回调函数""" + self.relation_change_callbacks.append(callback) + + def notify_change(self, device_id, change_type): + """当设备发生变化时,调用回调通知变化""" + for callback in self.relation_change_callbacks: + self.thread_pool.executor.submit(callback, device_id, change_type) + + def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + statement = ( + select(DeviceModelRelation, AlgoModel) + .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) + .where(DeviceModelRelation.device_id == device_id) + ) + + # 执行联表查询 + result = self.db.exec(statement).all() + + models_info = [ + DeviceModelRelationInfo( + id=relation.id, + device_id=relation.id, + algo_model_id=relation.algo_model_id, + is_use=relation.is_use, + threshold=relation.threshold, + alarm_interval=relation.alarm_interval, + alarm_type=relation.alarm_type, + algo_model_name=model.name, + algo_model_version=model.version, + algo_model_path=model.path, + algo_model_remark=model.remark, + ) + for relation, model in result + ] + return models_info + + def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + new_relations = [ + DeviceModelRelation( + algo_model_id=relation.algo_model_id, + is_use=relation.is_use, + threshold=relation.threshold, + alarm_interval=relation.alarm_interval, + alarm_type=relation.alarm_type, + device_id=device_id, # 统一赋值 device_id + createtime=datetime.now(), + updatetime=datetime.now(), + ) + for relation in relations + ] + self.db.add_all(new_relations) + self.db.commit() + for relation in new_relations: + self.db.refresh(relation) + return new_relations + + def delete_relations_by_device(self, device_id: int): + statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) + count = self.db.exec(statement) + self.db.commit() + return count.rowcount + + def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + self.delete_relations_by_device(device_id) + new_relations = self.add_relations_by_device(device_id, relations) + self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) + return new_relations diff --git a/services/device_service.py b/services/device_service.py new file mode 100644 index 0000000..0284c4e --- /dev/null +++ b/services/device_service.py @@ -0,0 +1,108 @@ +from datetime import datetime +from typing import Sequence, Optional, Tuple + +from sqlalchemy import func +from sqlmodel import Session, select + +from common.global_thread_pool import GlobalThreadPool +from common.consts import NotifyChangeType +from entity.device import Device, DeviceCreate, DeviceUpdate +from services.device_model_relation_service import DeviceModelRelationService + + +class DeviceService: + def __init__(self, db: Session): + self.db = db + self.device_change_callbacks = [] # 用于存储回调函数 + self.thread_pool = GlobalThreadPool() + + def register_change_callback(self, callback): + """注册设备变化回调函数""" + self.device_change_callbacks.append(callback) + + def notify_change(self, device_id, change_type): + """当设备发生变化时,调用回调通知变化""" + for callback in self.device_change_callbacks: + self.thread_pool.executor.submit(callback, device_id, change_type) + + def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: + statement = self.device_query(code, device_type, name) + results = self.db.exec(statement) + return results.all() + + def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[Device], int]: + statement = self.device_query(code, device_type, name) + + # 查询总记录数 + total_statement = select(func.count()).select_from(statement.subquery()) + total = self.db.exec(total_statement).one() + + # 添加分页限制 + statement = statement.offset(offset).limit(limit) + + # 执行查询并返回结果 + results = self.db.exec(statement) + return results.all(), total # 返回分页数据和总数 + + def device_query(self, code, device_type, name): + # 构建查询语句 + statement = select(Device) + if name: + statement = statement.where(Device.name.like(f"%{name}%")) + if code: + statement = statement.where(Device.code.like(f"%{code}%")) + if device_type: + statement = statement.where(Device.type == device_type) + return statement + + def create_device(self, device_data: DeviceCreate): + device = Device.model_validate(device_data) + device.create_time = datetime.now() + device.update_time = datetime.now() + self.db.add(device) + self.db.commit() + self.db.refresh(device) + self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) + return device + + def update_device(self, device_data: DeviceUpdate): + device = self.db.get(Device, device_data.id) + if not device: + return None + + update_data = device_data.dict(exclude_unset=True) + for key, value in update_data.items(): + setattr(device, key, value) + + device.update_time = datetime.now() + self.db.add(device) + self.db.commit() + self.db.refresh(device) + self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) + return device + + def delete_device(self, device_id: int): + device = self.db.get(Device, device_id) + if not device: + return None + + self.db.delete(device) + self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) + + relation_service = DeviceModelRelationService(self.db) + relation_service.delete_relations_by_device(device_id) + return device + + def get_device(self, device_id: int): + return self.db.get(Device, device_id) diff --git a/services/model_service.py b/services/model_service.py new file mode 100644 index 0000000..568895e --- /dev/null +++ b/services/model_service.py @@ -0,0 +1,120 @@ +from datetime import datetime +from typing import List, Sequence, Optional, Tuple + +from sqlalchemy import func +from sqlmodel import Session, select + +from common.biz_exception import BizException +from entity.device_model_relation import DeviceModelRelation +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate +from common.global_thread_pool import GlobalThreadPool +from common.consts import NotifyChangeType + + +class ModelService: + def __init__(self, db: Session): + self.db = db + self.model_change_callbacks = [] # 用于存储回调函数 + self.thread_pool = GlobalThreadPool() + + def register_change_callback(self, callback): + """注册设备变化回调函数""" + self.model_change_callbacks.append(callback) + + def notify_change(self, algo_model_id, change_type): + """当设备发生变化时,调用回调通知变化""" + for callback in self.model_change_callbacks: + self.thread_pool.executor.submit(callback, algo_model_id, change_type) + + def get_model_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[AlgoModel]: + statement = self.model_query(name, remark) + results = self.db.exec(statement) + return results.all() + + def get_model_page(self, + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[AlgoModel], int]: + statement = self.model_query(name, remark) + + # 查询总记录数 + total_statement = select(func.count()).select_from(statement.subquery()) + total = self.db.exec(total_statement).one() + + # 添加分页限制 + statement = statement.offset(offset).limit(limit) + + # 执行查询并返回结果 + results = self.db.exec(statement) + return results.all(), total # 返回分页数据和总数 + + def model_query(self, name, remark): + # 构建查询语句 + statement = select(AlgoModel) + if name: + statement = statement.where(AlgoModel.name.like(f"%{name}%")) + if remark: + statement = statement.where(AlgoModel.remark.like(f"%{remark}%")) + return statement + + def create_model(self, model_data: AlgoModelCreate): + model = AlgoModel.model_validate(model_data) + model.create_time = datetime.now() + model.update_time = datetime.now() + self.db.add(model) + self.db.commit() + self.db.refresh(model) + return model + + def update_model(self, model_data: AlgoModelUpdate): + model = self.db.get(AlgoModel, model_data.id) + if not model: + return None + + update_data = model_data.dict(exclude_unset=True) + for key, value in update_data.items(): + setattr(model, key, value) + + model.update_time = datetime.now() + self.db.add(model) + self.db.commit() + self.db.refresh(model) + self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE) + + return model + + def delete_model(self, model_id: int): + model = self.db.get(AlgoModel, model_id) + if not model: + return None + # 查询 device_model_relation 中是否存在启用的绑定关系 + statement = ( + select(DeviceModelRelation) + .where(DeviceModelRelation.algo_model_id == model_id) + .where(DeviceModelRelation.is_use == 1) + ) + relation_in_use = self.db.exec(statement).first() + + # 如果存在启用的绑定关系,提示无法删除 + if relation_in_use: + raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除") + + self.db.delete(model) + self.db.commit() + return model + + def get_models_in_use(self) -> Sequence[AlgoModel]: + """获取所有在 device_model_relation 表里有启用绑定关系的模型信息""" + statement = ( + select(AlgoModel) + .join(DeviceModelRelation, DeviceModelRelation.algo_model_id == AlgoModel.id) + .where(DeviceModelRelation.is_use == 1) + .group_by(AlgoModel.id) + ) + results = self.db.exec(statement).all() + return results diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e530a86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +test.py +/logs/* \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..901fe74 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,20 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b68d82f --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..3deddad --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/safe-algo-pro.iml b/.idea/safe-algo-pro.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/safe-algo-pro.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/algo/__init__.py b/algo/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/algo/__init__.py diff --git a/algo/algo_runner.py b/algo/algo_runner.py new file mode 100644 index 0000000..ee07eaf --- /dev/null +++ b/algo/algo_runner.py @@ -0,0 +1,117 @@ +from algo.model_manager import ModelManager +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from db.database import get_db +from entity.device import Device +from services.device_model_relation_service import DeviceModelRelationService +from services.device_service import DeviceService +from services.model_service import ModelService + + +class AlgoRunner: + def __init__(self): + self.db = get_db() + self.device_service = DeviceService(self.db) + self.model_service = ModelService(self.db) + self.relation_service = DeviceModelRelationService(self.db) + self.model_manager = ModelManager(self.db) + self.thread_pool = GlobalThreadPool() + self.threads = {} # 用于存储设备对应的线程 + self.device_model_relations = {} + + # 注册设备和模型的变化回调 + self.device_service.register_change_callback(self.on_device_change) + self.model_service.register_change_callback(self.on_model_change) + self.relation_service.relation_change_callbacks(self.on_relation_change) + + def start(self): + """在程序启动时调用,读取设备和模型,启动检测线程""" + self.model_manager.load_models() + self.load_and_start_devices() + + def load_and_start_devices(self): + """从数据库读取设备列表并启动线程""" + devices = self.device_service.get_device_list() + devices = [device for device in devices if device.input_stream_url] + for device in devices: + self.start_device_thread(device) + + def run_detection(self, device: Device): + """todo 设备目标检测主逻辑""" + print(f"设备 {device.device_id} 的检测线程启动") + video_stream = self.device_service.get_video_stream(device.device_id) + + while True: + try: + frame = video_stream.read_frame() + for model in self.model_manager.models: + # 调用目标检测模型 + result = self.model_service.run_inference(model, frame) + # 在此处进行后处理,例如标记视频、生成告警等 + self.process_result(result, device.device_id) + except Exception as e: + print(f"设备 {device.device_id} 处理时出错: {e}") + break + video_stream.close() + print(f"设备 {device.device_id} 的检测线程结束") + + def start_device_thread(self, device: Device): + """为单个设备启动检测线程""" + + if not device.input_stream_url: + print(f'设备 {device.device_id} 未配置视频流地址,无法启动检测线程') + return + + if device.device_id in self.threads: + print(f"设备 {device.device_id} 已经在运行中") + return + + # 获取设备绑定的模型列表 + relations = self.relation_service.get_device_models(device.device_id) + self.device_model_relations[device.device_id] = relations + + if not relations: + print(f"设备 {device.code} 未绑定模型,无法启动检测") + return + + future = self.thread_pool.submit_task(self.run_detection, device, thread_id=f'device_{device.id}') + + def stop_device_thread(self, device_id): + """todo 控制标志位 停止指定设备的检测线程""" + if device_id in self.threads: + self.threads[device_id].cancel() # 尝试取消线程 + del self.threads[device_id] + print(f"设备 {device_id} 的检测线程已停止") + + def restart_device_thread(self, device_id): + self.stop_device_thread(device_id) + # todo 需要留个关闭时间吗 + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + + def on_device_change(self, device_id, change_type): + """设备变化回调""" + if change_type == NotifyChangeType.DEVICE_CREATE: + device = self.device_service.get_device(device_id) + self.start_device_thread(device) + elif change_type == NotifyChangeType.DEVICE_DELETE: + self.stop_device_thread(device_id) + elif change_type == NotifyChangeType.DEVICE_UPDATE: + self.restart_device_thread(device_id) + + def on_model_change(self, model_id, change_type): + """模型变化回调""" + if change_type == NotifyChangeType.MODEL_UPDATE: + self.model_manager.reload_model(model_id) + + devices_to_reload = [ + device_id for device_id, relation_info in self.device_model_relations.items() + if relation_info.model_id == model_id + ] + for device_id in devices_to_reload: + self.restart_device_thread(device_id) + + def on_relation_change(self, device_id, change_type): + """设备模型关系变化回调""" + if change_type == NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE: + self.restart_device_thread(device_id) diff --git a/algo/model_manager.py b/algo/model_manager.py new file mode 100644 index 0000000..3ab0709 --- /dev/null +++ b/algo/model_manager.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from ultralytics import YOLO + +from entity.model import AlgoModel +from services.model_service import ModelService +from common.global_logger import logger + + +@dataclass +class AlgoModelExec: + algo_model_id: int + algo_model_info: AlgoModel + algo_model_exec: Optional[object] = None + input_size: int = 640 + + +class ModelManager: + def __init__(self, db, model_warm_up=5): + self.db = db + self.model_service = ModelService(self.db) + self.models = {} + self.model_warm_up = model_warm_up + + self.load_models() + + def query_model_inuse(self): + algo_model_list = list(self.model_service.get_models_in_use()) + for algo_model in algo_model_list: + self.models[algo_model.id] = AlgoModelExec( + algo_model_id=algo_model.id, + algo_model_info=algo_model + ) + + def load_models(self): + self.models = {} + self.query_model_inuse() + for algo_model_id, algo_model_exec in self.models.items(): + self.load_model(algo_model_exec) + + def load_model(self, algo_model_exec: AlgoModelExec): + model_name = algo_model_exec.algo_model_info.name + model_path = algo_model_exec.algo_model_info.path + logger.info(f'loading model {model_name}: {model_path}') + + algo_model_exec.algo_model_exec = YOLO(model_path, task='detect') + if self.model_warm_up > 0: + logger.info(f'warming up model {model_name}') + imgsz = algo_model_exec.input_size + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] + dummy_input = np.zeros((imgsz[0], imgsz[1], 3)) + for i in range(self.model_warm_up): + algo_model_exec.algo_model_exec.predict(source=dummy_input, imgsz=imgsz, verbose=False) + logger.info(f'warm up model {model_name} success!') + logger.info(f'load model {model_name} success!') + + def reload_model(self, model_id): + algo_model_exec = self.models.get(model_id) + if algo_model_exec: + algo_model_exec.algo_model_exec = None + self.load_model(algo_model_exec) diff --git a/apis/__init__.py b/apis/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/apis/__init__.py diff --git a/apis/base.py b/apis/base.py new file mode 100644 index 0000000..a7c4ede --- /dev/null +++ b/apis/base.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Optional, Generic, TypeVar, Any, List + +# 定义一个泛型类型变量 T +T = TypeVar("T") + + +# 使用泛型 T 定义 data 的类型 +class StandardResponse(BaseModel, Generic[T]): + code: int = 200 + data: Optional[T] = None + message: str = "请求成功" + success: bool = True + + +class PageResponse(BaseModel, Generic[T]): + total: int + items: List[T] = Field(default_factory=list) + + +def standard_response(data: Any = None, code: int = 200, message: str = "请求成功", success: bool = True): + return StandardResponse(data=data, code=code, message=message, success=success) + + +def standard_error_response(data: Any = None, code: int = 500, message: str = "请求异常", success: bool = False): + return StandardResponse(data=data, code=code, message=message, success=success) diff --git a/apis/device.py b/apis/device.py new file mode 100644 index 0000000..7eeaa78 --- /dev/null +++ b/apis/device.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from db.database import get_db +from entity.device import Device, DeviceCreate, DeviceUpdate, DeviceInfo +from services.device_service import DeviceService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) +def get_device_list( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + db: Session = Depends(get_db)): + service = DeviceService(db) + devices = list(service.get_device_list(name, code, device_type)) + return standard_response(data=devices) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[Device]]) +def get_device_page( + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = DeviceService(db) + + # 获取分页后的设备列表和总数 + devices, total = service.get_device_page(name, code, device_type, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=devices) + ) + + +@router.post("/add", response_model=StandardResponse[DeviceInfo]) +def create_device(device_data: DeviceCreate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.create_device(device_data) + return standard_response(data=device) + + +@router.post("/update", response_model=StandardResponse[DeviceInfo]) +def update_device(device_data: DeviceUpdate, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.update_device(device_data) + if not device: + return standard_error_response(data=device_data, message="Device not found") + return standard_response(data=device) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_device(device_id: int, db: Session = Depends(get_db)): + service = DeviceService(db) + device = service.delete_device(device_id) + if not device: + return standard_error_response(data=device_id, message="Device not found") + return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py new file mode 100644 index 0000000..7108571 --- /dev/null +++ b/apis/device_model_realtion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query +from sqlmodel import Session + +from apis.base import StandardResponse, standard_response +from db.database import get_db +from entity.device_model_relation import DeviceModelRelationInfo, DeviceModelRelationCreate, DeviceModelRelation +from services.device_model_relation_service import DeviceModelRelationService + +router = APIRouter() + + +@router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) +def list_by_device( + device_id: int, + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + models = list(service.get_device_models(device_id)) + return standard_response(data=models) + + +# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +# def add_by_device(relation_data: List[DeviceModelRelationCreate], +# device_id: int = Query(...), +# db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# relations = service.add_relations_by_device(device_id, relation_data) +# return standard_response(data=relations) + + +@router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) +def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + db: Session = Depends(get_db)): + service = DeviceModelRelationService(db) + relations = service.update_relations_by_device(device_id, relation_data) + return standard_response(data=relations) + + +# @router.delete("/delete_by_device", response_model=StandardResponse[int]) +# def delete_device(device_id: int, db: Session = Depends(get_db)): +# service = DeviceModelRelationService(db) +# count = service.delete_relations_by_device(device_id) +# return standard_response(data=count) diff --git a/apis/model.py b/apis/model.py new file mode 100644 index 0000000..8adcc14 --- /dev/null +++ b/apis/model.py @@ -0,0 +1,64 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session + +from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response +from common.biz_exception import BizException +from db.database import get_db +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate, AlgoModelInfo +from services.model_service import ModelService + +router = APIRouter() + + +@router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) +def get_model_list( + name: Optional[str] = None, + remark: Optional[str] = None, + db: Session = Depends(get_db)): + service = ModelService(db) + models = list(service.get_model_list(name, remark)) + return standard_response(data=models) + + +@router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModel]]) +def get_model_page( + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = Query(0, ge=0), # 从第几条开始 + limit: int = Query(10, ge=1), # 每页显示多少条记录 + db: Session = Depends(get_db)): + service = ModelService(db) + + # 获取分页后的设备列表和总数 + models, total = service.get_model_page(name, remark, offset, limit) + + return standard_response( + data=PageResponse(total=total, items=models) + ) + + +@router.post("/add", response_model=StandardResponse[AlgoModelInfo]) +def create_model(model_data: AlgoModelCreate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.create_model(model_data) + return standard_response(data=model) + + +@router.post("/update", response_model=StandardResponse[AlgoModelInfo]) +def update_model(model_data: AlgoModelUpdate, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.update_model(model_data) + if not model: + return standard_error_response(data=model_data, message="Model not found") + return standard_response(data=model) + + +@router.delete("/delete", response_model=StandardResponse[int]) +def delete_model(model_id: int, db: Session = Depends(get_db)): + service = ModelService(db) + model = service.delete_model(model_id) + if not model: + return standard_error_response(data=model_id, message="Model not found") + return standard_response(data=model_id) diff --git a/apis/router.py b/apis/router.py new file mode 100644 index 0000000..22b5ed6 --- /dev/null +++ b/apis/router.py @@ -0,0 +1,15 @@ +# api/router.py +from fastapi import APIRouter +from .device import router as devices_router +from .model import router as models_router +from .device_model_realtion import router as device_model_relation_router + + +# 创建一个全局的 router +router = APIRouter() + +# 将各个模块的 router 注册到全局 router 中 +router.include_router(devices_router, prefix="/device", tags=["Devices"]) +router.include_router(models_router, prefix="/model", tags=["Models"]) +router.include_router(device_model_relation_router, prefix="/device_model_relation", tags=["DeviceModelRelations"]) + diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/common/__init__.py diff --git a/common/biz_exception.py b/common/biz_exception.py new file mode 100644 index 0000000..946a794 --- /dev/null +++ b/common/biz_exception.py @@ -0,0 +1,14 @@ +from fastapi import Request, HTTPException + +from apis.base import standard_error_response + + +class BizException(HTTPException): + def __init__(self, message: str = "请求异常", status_code: int = 500): + super().__init__(status_code=status_code, detail=message) + + +class BizExceptionHandlers: + @staticmethod + async def biz_exception_handler(request: Request, exc: HTTPException): + standard_error_response(code=exc.status_code, message=exc.detail) diff --git a/common/consts.py b/common/consts.py new file mode 100644 index 0000000..bba41e4 --- /dev/null +++ b/common/consts.py @@ -0,0 +1,16 @@ +class Constants: + def __setattr__(self, name, value): + raise AttributeError(f"Cannot modify constant {name}") + + +class NotifyChangeType(Constants): + # 定义常量成员 + DEVICE_CREATE = "device_create" + DEVICE_UPDATE = "device_update" + DEVICE_DELETE = "device_delete" + MODEL_CREATE = "model_create" # 模型新增不用通知,与设备绑定时才生效 + MODEL_UPDATE = "model_update" + MODEL_DELETE = "model_delete" # 正在使用的模型不能直接删除 + DEVICE_MODEL_RELATION_CREATE = "device_model_relation_create" + DEVICE_MODEL_RELATION_UPDATE = "device_model_relation_update" # 绑定关系变化 应该只用这个吧?? + DEVICE_MODEL_RELATION_DELETE = "device_model_relation_delete" diff --git a/common/global_logger.py b/common/global_logger.py new file mode 100644 index 0000000..1d65173 --- /dev/null +++ b/common/global_logger.py @@ -0,0 +1,30 @@ +# logger.py +import logging.handlers +import os + +# 确保日志目录存在 +log_dir = '../logs' +if not os.path.exists(log_dir): + os.makedirs(log_dir) + +# 实例化并导出全局日志记录器 +logger = logging.getLogger("casic_safe_logger") +logger.setLevel(logging.DEBUG) # 设置日志级别 + +# 创建一个TimedRotatingFileHandler +handler = logging.handlers.TimedRotatingFileHandler( + os.path.join(log_dir, 'app.log'), # 日志文件名 + when='midnight', # 每天午夜滚动 + interval=1 # 滚动间隔为1天 +) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) + +# 将handler添加到日志器 +logger.addHandler(handler) + +# 创建控制台处理器 +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/common/global_thread_pool.py b/common/global_thread_pool.py new file mode 100644 index 0000000..21d7fd5 --- /dev/null +++ b/common/global_thread_pool.py @@ -0,0 +1,60 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +import threading + + +def generate_thread_id(): + """生成唯一的线程 ID""" + return str(uuid.uuid4()) + + +class GlobalThreadPool: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + with cls._lock: + if cls._instance is None: + # 第一次创建实例时调用父类的 __new__ 来创建实例 + cls._instance = super(GlobalThreadPool, cls).__new__(cls) + # 在此进行一次性的初始化,比如线程池的创建 + max_workers = kwargs.get('max_workers', 10) + cls._instance.executor = ThreadPoolExecutor(max_workers=max_workers) + cls._instance.task_map = {} # 初始化任务映射 + return cls._instance + + def submit_task(self, fn, *args, thread_id=None, **kwargs): + """提交任务到线程池,并记录线程 ID""" + if thread_id is None: + thread_id = generate_thread_id() + if self.check_task_is_running(thread_id): + raise ValueError(f"线程 ID {thread_id} 已存在") + future = self.executor.submit(fn, *args, **kwargs) + self.task_map[thread_id] = future # 记录线程 ID 和 Future 对象的映射 + return thread_id + + def check_task_is_running(self, thread_id): + future = self.task_map.get(thread_id) + if future: + if future.running(): + return True + else: + del self.task_map[thread_id] + return False + else: + return False + + def stop_task(self, thread_id): + """todo [可能不生效,需要控制线程任务里的标志位让它停止] 停止指定线程 ID 的任务""" + future = self.task_map.get(thread_id) + if future: + future.cancel() # 尝试取消任务 + print(f"任务 {thread_id} 已取消") + del self.task_map[thread_id] # 从任务映射中删除 + else: + print(f"未找到线程 ID {thread_id}") + + def shutdown(self, wait=True): + """关闭线程池""" + self.executor.shutdown(wait=wait) + GlobalThreadPool._instance = None diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/db/__init__.py diff --git a/db/database.py b/db/database.py new file mode 100644 index 0000000..4f44d2d --- /dev/null +++ b/db/database.py @@ -0,0 +1,22 @@ +from sqlmodel import SQLModel, create_engine, Session +from contextlib import contextmanager + +sqlite_file_name = "./db/safe-algo-pro.db" +sqlite_url = f"sqlite:///{sqlite_file_name}" + +connect_args = {"check_same_thread": False} +engine = create_engine(sqlite_url, connect_args=connect_args) + + +# 初始化数据库表 +def init_db(): + SQLModel.metadata.create_all(engine) + + +# 数据库会话管理 +def get_db(): + session = Session(engine) + try: + yield session + finally: + session.close() diff --git a/db/safe-algo-pro.db b/db/safe-algo-pro.db new file mode 100644 index 0000000..86c8d64 --- /dev/null +++ b/db/safe-algo-pro.db Binary files differ diff --git a/entity/__init__.py b/entity/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/entity/__init__.py diff --git a/entity/base.py b/entity/base.py new file mode 100644 index 0000000..6aa862c --- /dev/null +++ b/entity/base.py @@ -0,0 +1,13 @@ +from sqlmodel import SQLModel, Field +from typing import Optional +from datetime import datetime + + +class TimestampMixin(SQLModel): + create_time: Optional[datetime] = Field(default_factory=datetime.now) + update_time: Optional[datetime] = Field(default_factory=datetime.now) + + class Config: + json_encoders = { + datetime: lambda v: v.strftime('%Y-%m-%d %H:%M:%S') + } diff --git a/entity/device.py b/entity/device.py new file mode 100644 index 0000000..f09b52d --- /dev/null +++ b/entity/device.py @@ -0,0 +1,33 @@ +from typing import Optional +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceBase(SQLModel): + name: str + code: str + type: Optional[str] = None + ip: str + input_stream_url: Optional[str] = None + output_stream_url: Optional[str] = None + image_save_interval: Optional[int] = None + + +class Device(DeviceBase, TimestampMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + + +class DeviceCreate(DeviceBase): + pass + + +class DeviceUpdate(DeviceBase): + id: int + name: Optional[str] = None + code: Optional[str] = None + ip: Optional[str] = None + + +class DeviceInfo(DeviceBase, TimestampMixin): + id: int diff --git a/entity/device_model_relation.py b/entity/device_model_relation.py new file mode 100644 index 0000000..2e58a39 --- /dev/null +++ b/entity/device_model_relation.py @@ -0,0 +1,31 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class DeviceModelRelationBase(SQLModel): + algo_model_id: int + is_use: int + threshold: Optional[float] = None + alarm_interval: Optional[int] = None + alarm_type: Optional[str] = None + + +class DeviceModelRelation(DeviceModelRelationBase, TimestampMixin, table=True): + __tablename__ = 'device_model_relation' + id: Optional[int] = Field(default=None, primary_key=True) + device_id: int + +class DeviceModelRelationCreate(DeviceModelRelationBase): + pass + + +class DeviceModelRelationInfo(DeviceModelRelationBase, TimestampMixin): + id: int + device_id: int + algo_model_name: str + algo_model_version: str + algo_model_path: Optional[str] = None # 可选字段 + algo_model_remark: Optional[str] = None diff --git a/entity/model.py b/entity/model.py new file mode 100644 index 0000000..db9133e --- /dev/null +++ b/entity/model.py @@ -0,0 +1,33 @@ +from typing import Optional + +from sqlmodel import SQLModel, Field + +from entity.base import TimestampMixin + + +class AlgoModelBase(SQLModel): + name: str + version: str + path: str + remark: Optional[str] = None + + +class AlgoModel(AlgoModelBase, TimestampMixin, table=True): + __tablename__ = "algo_model" # 显式指定表名 + + id: Optional[int] = Field(default=None, primary_key=True) + + +class AlgoModelCreate(AlgoModelBase): + pass + + +class AlgoModelUpdate(AlgoModelBase): + id: int + name: Optional[str] = None + version: Optional[str] = None + path: Optional[str] = None + + +class AlgoModelInfo(AlgoModelBase, TimestampMixin): + id: int diff --git a/main.py b/main.py new file mode 100644 index 0000000..75c9b52 --- /dev/null +++ b/main.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI, HTTPException +from apis.router import router +import uvicorn + +from common.biz_exception import BizExceptionHandlers + +app = FastAPI() + +# 包含所有模块的路由 +app.include_router(router, prefix="/api") +app.add_exception_handler(HTTPException, BizExceptionHandlers.http_exception_handler) + +if __name__ == "__main__": + + uvicorn.run(app, host="0.0.0.0", port=9000) \ No newline at end of file diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/services/__init__.py diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py new file mode 100644 index 0000000..0f8d9a7 --- /dev/null +++ b/services/device_model_relation_service.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import List + +from sqlmodel import Session, select, delete + +from common.consts import NotifyChangeType +from common.global_thread_pool import GlobalThreadPool +from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate +from entity.model import AlgoModel + + +class DeviceModelRelationService: + def __init__(self, db: Session): + self.db = db + self.relation_change_callbacks = [] # 用于存储回调函数 + self.thread_pool = GlobalThreadPool() + + def register_change_callback(self, callback): + """注册设备变化回调函数""" + self.relation_change_callbacks.append(callback) + + def notify_change(self, device_id, change_type): + """当设备发生变化时,调用回调通知变化""" + for callback in self.relation_change_callbacks: + self.thread_pool.executor.submit(callback, device_id, change_type) + + def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + statement = ( + select(DeviceModelRelation, AlgoModel) + .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) + .where(DeviceModelRelation.device_id == device_id) + ) + + # 执行联表查询 + result = self.db.exec(statement).all() + + models_info = [ + DeviceModelRelationInfo( + id=relation.id, + device_id=relation.id, + algo_model_id=relation.algo_model_id, + is_use=relation.is_use, + threshold=relation.threshold, + alarm_interval=relation.alarm_interval, + alarm_type=relation.alarm_type, + algo_model_name=model.name, + algo_model_version=model.version, + algo_model_path=model.path, + algo_model_remark=model.remark, + ) + for relation, model in result + ] + return models_info + + def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + new_relations = [ + DeviceModelRelation( + algo_model_id=relation.algo_model_id, + is_use=relation.is_use, + threshold=relation.threshold, + alarm_interval=relation.alarm_interval, + alarm_type=relation.alarm_type, + device_id=device_id, # 统一赋值 device_id + createtime=datetime.now(), + updatetime=datetime.now(), + ) + for relation in relations + ] + self.db.add_all(new_relations) + self.db.commit() + for relation in new_relations: + self.db.refresh(relation) + return new_relations + + def delete_relations_by_device(self, device_id: int): + statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) + count = self.db.exec(statement) + self.db.commit() + return count.rowcount + + def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + self.delete_relations_by_device(device_id) + new_relations = self.add_relations_by_device(device_id, relations) + self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) + return new_relations diff --git a/services/device_service.py b/services/device_service.py new file mode 100644 index 0000000..0284c4e --- /dev/null +++ b/services/device_service.py @@ -0,0 +1,108 @@ +from datetime import datetime +from typing import Sequence, Optional, Tuple + +from sqlalchemy import func +from sqlmodel import Session, select + +from common.global_thread_pool import GlobalThreadPool +from common.consts import NotifyChangeType +from entity.device import Device, DeviceCreate, DeviceUpdate +from services.device_model_relation_service import DeviceModelRelationService + + +class DeviceService: + def __init__(self, db: Session): + self.db = db + self.device_change_callbacks = [] # 用于存储回调函数 + self.thread_pool = GlobalThreadPool() + + def register_change_callback(self, callback): + """注册设备变化回调函数""" + self.device_change_callbacks.append(callback) + + def notify_change(self, device_id, change_type): + """当设备发生变化时,调用回调通知变化""" + for callback in self.device_change_callbacks: + self.thread_pool.executor.submit(callback, device_id, change_type) + + def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: + statement = self.device_query(code, device_type, name) + results = self.db.exec(statement) + return results.all() + + def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[Device], int]: + statement = self.device_query(code, device_type, name) + + # 查询总记录数 + total_statement = select(func.count()).select_from(statement.subquery()) + total = self.db.exec(total_statement).one() + + # 添加分页限制 + statement = statement.offset(offset).limit(limit) + + # 执行查询并返回结果 + results = self.db.exec(statement) + return results.all(), total # 返回分页数据和总数 + + def device_query(self, code, device_type, name): + # 构建查询语句 + statement = select(Device) + if name: + statement = statement.where(Device.name.like(f"%{name}%")) + if code: + statement = statement.where(Device.code.like(f"%{code}%")) + if device_type: + statement = statement.where(Device.type == device_type) + return statement + + def create_device(self, device_data: DeviceCreate): + device = Device.model_validate(device_data) + device.create_time = datetime.now() + device.update_time = datetime.now() + self.db.add(device) + self.db.commit() + self.db.refresh(device) + self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) + return device + + def update_device(self, device_data: DeviceUpdate): + device = self.db.get(Device, device_data.id) + if not device: + return None + + update_data = device_data.dict(exclude_unset=True) + for key, value in update_data.items(): + setattr(device, key, value) + + device.update_time = datetime.now() + self.db.add(device) + self.db.commit() + self.db.refresh(device) + self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) + return device + + def delete_device(self, device_id: int): + device = self.db.get(Device, device_id) + if not device: + return None + + self.db.delete(device) + self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) + + relation_service = DeviceModelRelationService(self.db) + relation_service.delete_relations_by_device(device_id) + return device + + def get_device(self, device_id: int): + return self.db.get(Device, device_id) diff --git a/services/model_service.py b/services/model_service.py new file mode 100644 index 0000000..568895e --- /dev/null +++ b/services/model_service.py @@ -0,0 +1,120 @@ +from datetime import datetime +from typing import List, Sequence, Optional, Tuple + +from sqlalchemy import func +from sqlmodel import Session, select + +from common.biz_exception import BizException +from entity.device_model_relation import DeviceModelRelation +from entity.model import AlgoModel, AlgoModelCreate, AlgoModelUpdate +from common.global_thread_pool import GlobalThreadPool +from common.consts import NotifyChangeType + + +class ModelService: + def __init__(self, db: Session): + self.db = db + self.model_change_callbacks = [] # 用于存储回调函数 + self.thread_pool = GlobalThreadPool() + + def register_change_callback(self, callback): + """注册设备变化回调函数""" + self.model_change_callbacks.append(callback) + + def notify_change(self, algo_model_id, change_type): + """当设备发生变化时,调用回调通知变化""" + for callback in self.model_change_callbacks: + self.thread_pool.executor.submit(callback, algo_model_id, change_type) + + def get_model_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[AlgoModel]: + statement = self.model_query(name, remark) + results = self.db.exec(statement) + return results.all() + + def get_model_page(self, + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[AlgoModel], int]: + statement = self.model_query(name, remark) + + # 查询总记录数 + total_statement = select(func.count()).select_from(statement.subquery()) + total = self.db.exec(total_statement).one() + + # 添加分页限制 + statement = statement.offset(offset).limit(limit) + + # 执行查询并返回结果 + results = self.db.exec(statement) + return results.all(), total # 返回分页数据和总数 + + def model_query(self, name, remark): + # 构建查询语句 + statement = select(AlgoModel) + if name: + statement = statement.where(AlgoModel.name.like(f"%{name}%")) + if remark: + statement = statement.where(AlgoModel.remark.like(f"%{remark}%")) + return statement + + def create_model(self, model_data: AlgoModelCreate): + model = AlgoModel.model_validate(model_data) + model.create_time = datetime.now() + model.update_time = datetime.now() + self.db.add(model) + self.db.commit() + self.db.refresh(model) + return model + + def update_model(self, model_data: AlgoModelUpdate): + model = self.db.get(AlgoModel, model_data.id) + if not model: + return None + + update_data = model_data.dict(exclude_unset=True) + for key, value in update_data.items(): + setattr(model, key, value) + + model.update_time = datetime.now() + self.db.add(model) + self.db.commit() + self.db.refresh(model) + self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE) + + return model + + def delete_model(self, model_id: int): + model = self.db.get(AlgoModel, model_id) + if not model: + return None + # 查询 device_model_relation 中是否存在启用的绑定关系 + statement = ( + select(DeviceModelRelation) + .where(DeviceModelRelation.algo_model_id == model_id) + .where(DeviceModelRelation.is_use == 1) + ) + relation_in_use = self.db.exec(statement).first() + + # 如果存在启用的绑定关系,提示无法删除 + if relation_in_use: + raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除") + + self.db.delete(model) + self.db.commit() + return model + + def get_models_in_use(self) -> Sequence[AlgoModel]: + """获取所有在 device_model_relation 表里有启用绑定关系的模型信息""" + statement = ( + select(AlgoModel) + .join(DeviceModelRelation, DeviceModelRelation.algo_model_id == AlgoModel.id) + .where(DeviceModelRelation.is_use == 1) + .group_by(AlgoModel.id) + ) + results = self.db.exec(statement).all() + return results diff --git a/tcp/__init__.py b/tcp/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tcp/__init__.py