diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 592ee98..3416831 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -10,7 +10,7 @@ class DeviceSceneRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: + async def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: statement = ( select(DeviceSceneRelation, Scene) .join(Scene, DeviceSceneRelation.scene_id == Scene.id) @@ -32,11 +32,12 @@ ) # 执行联表查询 - result = self.db.exec(statement).first() + result = await self.db.execute(statement) + result_row = result.first() scene_info = None - if result: - relation, scene = result[0], result[1] + if result_row: + relation, scene = result_row scene_info = DeviceSceneRelationInfo( id=relation.id, device_id=relation.device_id, @@ -48,23 +49,23 @@ ) return scene_info - def add_relation_by_device(self, device_id: int, scene_id: int): + async def add_relation_by_device(self, device_id: int, scene_id: int): new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) - self.db.commit() - self.db.refresh(new_relation) + await self.db.commit() + await self.db.refresh(new_relation) return new_relation - def delete_relation_by_device(self, device_id: int): + async def delete_relation_by_device(self, device_id: int): statement = delete(DeviceSceneRelation).where(DeviceSceneRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relation_by_device(self, device_id: int, scene_id: int): - self.delete_relation_by_device(device_id) - new_relation = self.add_relation_by_device(device_id, scene_id) + async def update_relation_by_device(self, device_id: int, scene_id: int): + await self.delete_relation_by_device(device_id) + new_relation = await self.add_relation_by_device(device_id, scene_id) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 592ee98..3416831 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -10,7 +10,7 @@ class DeviceSceneRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: + async def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: statement = ( select(DeviceSceneRelation, Scene) .join(Scene, DeviceSceneRelation.scene_id == Scene.id) @@ -32,11 +32,12 @@ ) # 执行联表查询 - result = self.db.exec(statement).first() + result = await self.db.execute(statement) + result_row = result.first() scene_info = None - if result: - relation, scene = result[0], result[1] + if result_row: + relation, scene = result_row scene_info = DeviceSceneRelationInfo( id=relation.id, device_id=relation.device_id, @@ -48,23 +49,23 @@ ) return scene_info - def add_relation_by_device(self, device_id: int, scene_id: int): + async def add_relation_by_device(self, device_id: int, scene_id: int): new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) - self.db.commit() - self.db.refresh(new_relation) + await self.db.commit() + await self.db.refresh(new_relation) return new_relation - def delete_relation_by_device(self, device_id: int): + async def delete_relation_by_device(self, device_id: int): statement = delete(DeviceSceneRelation).where(DeviceSceneRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relation_by_device(self, device_id: int, scene_id: int): - self.delete_relation_by_device(device_id) - new_relation = self.add_relation_by_device(device_id, scene_id) + async def update_relation_by_device(self, device_id: int, scene_id: int): + await self.delete_relation_by_device(device_id) + new_relation = await self.add_relation_by_device(device_id, scene_id) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/device_service.py b/services/device_service.py index cad46d3..2488fd5 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -4,8 +4,8 @@ from typing import Sequence, Optional, Tuple from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.device_status_manager import DeviceStatusManager from common.global_thread_pool import GlobalThreadPool from common.consts import NotifyChangeType, DEVICE_MODE @@ -15,25 +15,25 @@ class DeviceService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() # 创建一个独立的事件循环并启动线程 - self.loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) - self.loop_thread.start() + # self.loop = asyncio.new_event_loop() + # self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) + # self.loop_thread.start() - def _start_loop(self): - """后台线程运行事件循环""" - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def shutdown(self): - """清理事件循环和线程""" - self.loop.call_soon_threadsafe(self.loop.stop) - self.loop_thread.join() + # def _start_loop(self): + # """后台线程运行事件循环""" + # asyncio.set_event_loop(self.loop) + # self.loop.run_forever() + # + # def shutdown(self): + # """清理事件循环和线程""" + # self.loop.call_soon_threadsafe(self.loop.stop) + # self.loop_thread.join() def register_change_callback(self, callback): """注册设备变化回调函数""" @@ -42,52 +42,53 @@ def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" # loop = asyncio.get_event_loop() # 获取当前的事件循环 - for callback in self.__device_change_callbacks: - if asyncio.iscoroutinefunction(callback): - # 如果是协程函数,使用事件循环运行它 - # loop = asyncio.new_event_loop() - # asyncio.set_event_loop(loop) - asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) - else: - # 如果是普通函数,直接提交到线程池 - self.thread_pool.executor.submit(callback, device_id, change_type) + # for callback in self.__device_change_callbacks: + # if asyncio.iscoroutinefunction(callback): + # # 如果是协程函数,使用事件循环运行它 + # # loop = asyncio.new_event_loop() + # # asyncio.set_event_loop(loop) + # asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) + # else: + # # 如果是普通函数,直接提交到线程池 + # self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_list(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - ) -> Sequence[Device]: + async def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: statement = self.device_query(code, device_type, name) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_device_page(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceInfo], int]: + async def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceInfo], int]: statement = self.device_query(code, device_type, name) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - device_list = results.all() + results = await self.db.execute(statement) + device_list = results.scalars().all() device_info_list = [] if device_list: device_model_relation_service = DeviceModelRelationService(self.db) device_scene_relation_service = DeviceSceneRelationService(self.db) device_status_manager = DeviceStatusManager() for device in device_list: - model_relations = device_model_relation_service.get_device_models(device.id) - scene_relation = device_scene_relation_service.get_device_scene(device.id) + model_relations = await device_model_relation_service.get_device_models(device.id) + scene_relation = await device_scene_relation_service.get_device_scene(device.id) device_info_list.append(DeviceInfo( id=device.id, @@ -121,20 +122,20 @@ statement = statement.where(Device.type == device_type) return statement - def create_device(self, device_data: DeviceCreate): + async def create_device(self, device_data: DeviceCreate): device = Device.model_validate(device_data) device.create_time = datetime.now() device.update_time = datetime.now() - self.handle_device_mode(device) + await self.handle_device_mode(device) self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) return device - def update_device(self, device_data: DeviceUpdate): - device_old = self.db.get(Device, device_data.id) - device = self.db.get(Device, device_data.id) + async def update_device(self, device_data: DeviceUpdate): + device_old = await self.get_device(device_data.id) + device = await self.get_device(device_data.id) if not device: return None @@ -142,41 +143,45 @@ for key, value in update_data.items(): setattr(device, key, value) - self.handle_device_mode(device) + await self.handle_device_mode(device) device.update_time = datetime.now() self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) return device - def delete_device(self, device_id: int): - device = self.db.get(Device, device_id) + async def delete_device(self, device_id: int): + device = await self.get_device(device_id) if not device: return None - self.db.delete(device) - self.db.commit() + statement = delete(Device).where(Device.id == device_id) + await self.db.execute(statement) + await self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device_id) + await model_relation_service.delete_relations_by_device(device_id) scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) return device - def handle_device_mode(self, device): + async def handle_device_mode(self, device): if device.mode == DEVICE_MODE.ALGO: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) elif device.mode == DEVICE_MODE.SCENE: model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) else: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) - def get_device(self, device_id: int): - return self.db.get(Device, device_id) + async def get_device(self, device_id: int): + result = await self.db.execute(select(Device).where(Device.id == device_id)) + frame = result.scalar_one_or_none() + return frame diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 592ee98..3416831 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -10,7 +10,7 @@ class DeviceSceneRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: + async def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: statement = ( select(DeviceSceneRelation, Scene) .join(Scene, DeviceSceneRelation.scene_id == Scene.id) @@ -32,11 +32,12 @@ ) # 执行联表查询 - result = self.db.exec(statement).first() + result = await self.db.execute(statement) + result_row = result.first() scene_info = None - if result: - relation, scene = result[0], result[1] + if result_row: + relation, scene = result_row scene_info = DeviceSceneRelationInfo( id=relation.id, device_id=relation.device_id, @@ -48,23 +49,23 @@ ) return scene_info - def add_relation_by_device(self, device_id: int, scene_id: int): + async def add_relation_by_device(self, device_id: int, scene_id: int): new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) - self.db.commit() - self.db.refresh(new_relation) + await self.db.commit() + await self.db.refresh(new_relation) return new_relation - def delete_relation_by_device(self, device_id: int): + async def delete_relation_by_device(self, device_id: int): statement = delete(DeviceSceneRelation).where(DeviceSceneRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relation_by_device(self, device_id: int, scene_id: int): - self.delete_relation_by_device(device_id) - new_relation = self.add_relation_by_device(device_id, scene_id) + async def update_relation_by_device(self, device_id: int, scene_id: int): + await self.delete_relation_by_device(device_id) + new_relation = await self.add_relation_by_device(device_id, scene_id) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/device_service.py b/services/device_service.py index cad46d3..2488fd5 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -4,8 +4,8 @@ from typing import Sequence, Optional, Tuple from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.device_status_manager import DeviceStatusManager from common.global_thread_pool import GlobalThreadPool from common.consts import NotifyChangeType, DEVICE_MODE @@ -15,25 +15,25 @@ class DeviceService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() # 创建一个独立的事件循环并启动线程 - self.loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) - self.loop_thread.start() + # self.loop = asyncio.new_event_loop() + # self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) + # self.loop_thread.start() - def _start_loop(self): - """后台线程运行事件循环""" - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def shutdown(self): - """清理事件循环和线程""" - self.loop.call_soon_threadsafe(self.loop.stop) - self.loop_thread.join() + # def _start_loop(self): + # """后台线程运行事件循环""" + # asyncio.set_event_loop(self.loop) + # self.loop.run_forever() + # + # def shutdown(self): + # """清理事件循环和线程""" + # self.loop.call_soon_threadsafe(self.loop.stop) + # self.loop_thread.join() def register_change_callback(self, callback): """注册设备变化回调函数""" @@ -42,52 +42,53 @@ def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" # loop = asyncio.get_event_loop() # 获取当前的事件循环 - for callback in self.__device_change_callbacks: - if asyncio.iscoroutinefunction(callback): - # 如果是协程函数,使用事件循环运行它 - # loop = asyncio.new_event_loop() - # asyncio.set_event_loop(loop) - asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) - else: - # 如果是普通函数,直接提交到线程池 - self.thread_pool.executor.submit(callback, device_id, change_type) + # for callback in self.__device_change_callbacks: + # if asyncio.iscoroutinefunction(callback): + # # 如果是协程函数,使用事件循环运行它 + # # loop = asyncio.new_event_loop() + # # asyncio.set_event_loop(loop) + # asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) + # else: + # # 如果是普通函数,直接提交到线程池 + # self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_list(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - ) -> Sequence[Device]: + async def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: statement = self.device_query(code, device_type, name) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_device_page(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceInfo], int]: + async def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceInfo], int]: statement = self.device_query(code, device_type, name) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - device_list = results.all() + results = await self.db.execute(statement) + device_list = results.scalars().all() device_info_list = [] if device_list: device_model_relation_service = DeviceModelRelationService(self.db) device_scene_relation_service = DeviceSceneRelationService(self.db) device_status_manager = DeviceStatusManager() for device in device_list: - model_relations = device_model_relation_service.get_device_models(device.id) - scene_relation = device_scene_relation_service.get_device_scene(device.id) + model_relations = await device_model_relation_service.get_device_models(device.id) + scene_relation = await device_scene_relation_service.get_device_scene(device.id) device_info_list.append(DeviceInfo( id=device.id, @@ -121,20 +122,20 @@ statement = statement.where(Device.type == device_type) return statement - def create_device(self, device_data: DeviceCreate): + async def create_device(self, device_data: DeviceCreate): device = Device.model_validate(device_data) device.create_time = datetime.now() device.update_time = datetime.now() - self.handle_device_mode(device) + await self.handle_device_mode(device) self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) return device - def update_device(self, device_data: DeviceUpdate): - device_old = self.db.get(Device, device_data.id) - device = self.db.get(Device, device_data.id) + async def update_device(self, device_data: DeviceUpdate): + device_old = await self.get_device(device_data.id) + device = await self.get_device(device_data.id) if not device: return None @@ -142,41 +143,45 @@ for key, value in update_data.items(): setattr(device, key, value) - self.handle_device_mode(device) + await self.handle_device_mode(device) device.update_time = datetime.now() self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) return device - def delete_device(self, device_id: int): - device = self.db.get(Device, device_id) + async def delete_device(self, device_id: int): + device = await self.get_device(device_id) if not device: return None - self.db.delete(device) - self.db.commit() + statement = delete(Device).where(Device.id == device_id) + await self.db.execute(statement) + await self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device_id) + await model_relation_service.delete_relations_by_device(device_id) scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) return device - def handle_device_mode(self, device): + async def handle_device_mode(self, device): if device.mode == DEVICE_MODE.ALGO: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) elif device.mode == DEVICE_MODE.SCENE: model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) else: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) - def get_device(self, device_id: int): - return self.db.get(Device, device_id) + async def get_device(self, device_id: int): + result = await self.db.execute(select(Device).where(Device.id == device_id)) + frame = result.scalar_one_or_none() + return frame diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index dbd272d..4e6c145 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,23 +1,25 @@ from typing import List -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult class FrameAnalysisResultService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): + async def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): new_results = [FrameAnalysisResult.model_validate(result) for result in results] - self.db.add_all(new_results) - self.db.commit() for result in new_results: - self.db.refresh(result) + self.db.add(result) + await self.db.commit() + for result in new_results: + await self.db.refresh(result) return new_results - def get_results_by_frame(self, frame_id): + async def get_results_by_frame(self, frame_id): statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 592ee98..3416831 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -10,7 +10,7 @@ class DeviceSceneRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: + async def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: statement = ( select(DeviceSceneRelation, Scene) .join(Scene, DeviceSceneRelation.scene_id == Scene.id) @@ -32,11 +32,12 @@ ) # 执行联表查询 - result = self.db.exec(statement).first() + result = await self.db.execute(statement) + result_row = result.first() scene_info = None - if result: - relation, scene = result[0], result[1] + if result_row: + relation, scene = result_row scene_info = DeviceSceneRelationInfo( id=relation.id, device_id=relation.device_id, @@ -48,23 +49,23 @@ ) return scene_info - def add_relation_by_device(self, device_id: int, scene_id: int): + async def add_relation_by_device(self, device_id: int, scene_id: int): new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) - self.db.commit() - self.db.refresh(new_relation) + await self.db.commit() + await self.db.refresh(new_relation) return new_relation - def delete_relation_by_device(self, device_id: int): + async def delete_relation_by_device(self, device_id: int): statement = delete(DeviceSceneRelation).where(DeviceSceneRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relation_by_device(self, device_id: int, scene_id: int): - self.delete_relation_by_device(device_id) - new_relation = self.add_relation_by_device(device_id, scene_id) + async def update_relation_by_device(self, device_id: int, scene_id: int): + await self.delete_relation_by_device(device_id) + new_relation = await self.add_relation_by_device(device_id, scene_id) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/device_service.py b/services/device_service.py index cad46d3..2488fd5 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -4,8 +4,8 @@ from typing import Sequence, Optional, Tuple from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.device_status_manager import DeviceStatusManager from common.global_thread_pool import GlobalThreadPool from common.consts import NotifyChangeType, DEVICE_MODE @@ -15,25 +15,25 @@ class DeviceService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() # 创建一个独立的事件循环并启动线程 - self.loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) - self.loop_thread.start() + # self.loop = asyncio.new_event_loop() + # self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) + # self.loop_thread.start() - def _start_loop(self): - """后台线程运行事件循环""" - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def shutdown(self): - """清理事件循环和线程""" - self.loop.call_soon_threadsafe(self.loop.stop) - self.loop_thread.join() + # def _start_loop(self): + # """后台线程运行事件循环""" + # asyncio.set_event_loop(self.loop) + # self.loop.run_forever() + # + # def shutdown(self): + # """清理事件循环和线程""" + # self.loop.call_soon_threadsafe(self.loop.stop) + # self.loop_thread.join() def register_change_callback(self, callback): """注册设备变化回调函数""" @@ -42,52 +42,53 @@ def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" # loop = asyncio.get_event_loop() # 获取当前的事件循环 - for callback in self.__device_change_callbacks: - if asyncio.iscoroutinefunction(callback): - # 如果是协程函数,使用事件循环运行它 - # loop = asyncio.new_event_loop() - # asyncio.set_event_loop(loop) - asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) - else: - # 如果是普通函数,直接提交到线程池 - self.thread_pool.executor.submit(callback, device_id, change_type) + # for callback in self.__device_change_callbacks: + # if asyncio.iscoroutinefunction(callback): + # # 如果是协程函数,使用事件循环运行它 + # # loop = asyncio.new_event_loop() + # # asyncio.set_event_loop(loop) + # asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) + # else: + # # 如果是普通函数,直接提交到线程池 + # self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_list(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - ) -> Sequence[Device]: + async def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: statement = self.device_query(code, device_type, name) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_device_page(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceInfo], int]: + async def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceInfo], int]: statement = self.device_query(code, device_type, name) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - device_list = results.all() + results = await self.db.execute(statement) + device_list = results.scalars().all() device_info_list = [] if device_list: device_model_relation_service = DeviceModelRelationService(self.db) device_scene_relation_service = DeviceSceneRelationService(self.db) device_status_manager = DeviceStatusManager() for device in device_list: - model_relations = device_model_relation_service.get_device_models(device.id) - scene_relation = device_scene_relation_service.get_device_scene(device.id) + model_relations = await device_model_relation_service.get_device_models(device.id) + scene_relation = await device_scene_relation_service.get_device_scene(device.id) device_info_list.append(DeviceInfo( id=device.id, @@ -121,20 +122,20 @@ statement = statement.where(Device.type == device_type) return statement - def create_device(self, device_data: DeviceCreate): + async def create_device(self, device_data: DeviceCreate): device = Device.model_validate(device_data) device.create_time = datetime.now() device.update_time = datetime.now() - self.handle_device_mode(device) + await self.handle_device_mode(device) self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) return device - def update_device(self, device_data: DeviceUpdate): - device_old = self.db.get(Device, device_data.id) - device = self.db.get(Device, device_data.id) + async def update_device(self, device_data: DeviceUpdate): + device_old = await self.get_device(device_data.id) + device = await self.get_device(device_data.id) if not device: return None @@ -142,41 +143,45 @@ for key, value in update_data.items(): setattr(device, key, value) - self.handle_device_mode(device) + await self.handle_device_mode(device) device.update_time = datetime.now() self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) return device - def delete_device(self, device_id: int): - device = self.db.get(Device, device_id) + async def delete_device(self, device_id: int): + device = await self.get_device(device_id) if not device: return None - self.db.delete(device) - self.db.commit() + statement = delete(Device).where(Device.id == device_id) + await self.db.execute(statement) + await self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device_id) + await model_relation_service.delete_relations_by_device(device_id) scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) return device - def handle_device_mode(self, device): + async def handle_device_mode(self, device): if device.mode == DEVICE_MODE.ALGO: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) elif device.mode == DEVICE_MODE.SCENE: model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) else: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) - def get_device(self, device_id: int): - return self.db.get(Device, device_id) + async def get_device(self, device_id: int): + result = await self.db.execute(select(Device).where(Device.id == device_id)) + frame = result.scalar_one_or_none() + return frame diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index dbd272d..4e6c145 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,23 +1,25 @@ from typing import List -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult class FrameAnalysisResultService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): + async def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): new_results = [FrameAnalysisResult.model_validate(result) for result in results] - self.db.add_all(new_results) - self.db.commit() for result in new_results: - self.db.refresh(result) + self.db.add(result) + await self.db.commit() + for result in new_results: + await self.db.refresh(result) return new_results - def get_results_by_frame(self, frame_id): + async def get_results_by_frame(self, frame_id): statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/services/global_config.py b/services/global_config.py index 7dae73b..fa53b2b 100644 --- a/services/global_config.py +++ b/services/global_config.py @@ -1,3 +1,5 @@ +import asyncio + from common.consts import PUSH_TYPE from db.database import get_db from entity.push_config import PushConfig @@ -6,10 +8,12 @@ class GlobalConfig: _instance = None + _lock = asyncio.Lock() def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) + return cls._instance def __init__(self): @@ -19,49 +23,58 @@ self.gas_push_config = None self.algo_result_push_config = None self.alarm_push_config = None - self.init_config() # 进行初始化 + self._init_done = False - def init_config(self): - # 初始化配置逻辑 - with next(get_db()) as db: - self.config_service = PushConfigService(db) - self.set_gas_push_config(self.config_service.get_push_config(PUSH_TYPE.GAS)) - self.set_algo_result_push_config(self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT)) - self.set_alarm_push_config(self.config_service.get_push_config(PUSH_TYPE.ALARM)) + async def _initialize(self): + await self.init_config() # 调用异步初始化 - self.config_service.register_change_callback(self.on_config_change) + async def init_config(self): + # 确保只初始化一次 + if not self._init_done: + async with self._lock: + if not self._init_done: # 双重检查锁,避免多次初始化 + async for db in get_db(): + self.config_service = PushConfigService(db) + self.config_service.register_change_callback(self.on_config_change) + self.gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.GAS) + self.algo_result_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT) + self.alarm_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALARM) + self._init_done = True - def on_config_change(self, config: PushConfig): + async def on_config_change(self, config: PushConfig): if config.push_type == PUSH_TYPE.GAS: - self.set_gas_push_config(config) + await self.set_gas_push_config(config) elif config.push_type == PUSH_TYPE.ALGO_RESULT: - self.set_algo_result_push_config(config) + await self.set_algo_result_push_config(config) elif config.push_type == PUSH_TYPE.ALARM: - self.set_alarm_push_config(config) + await self.set_alarm_push_config(config) def get_gas_push_config(self): """获取 gas_push_config 配置""" return self.gas_push_config - def set_gas_push_config(self, config): + async def set_gas_push_config(self, config): """设置 gas_push_config 配置""" if config: - self.gas_push_config = config + async with self._lock: + self.gas_push_config = config def get_algo_result_push_config(self): """获取 algo_result_push_config 配置""" return self.algo_result_push_config - def set_algo_result_push_config(self, config): + async def set_algo_result_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.algo_result_push_config = config + async with self._lock: + self.algo_result_push_config = config def get_alarm_push_config(self): """获取 algo_result_push_config 配置""" return self.alarm_push_config - def set_alarm_push_config(self, config): + async def set_alarm_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.alarm_push_config = config + async with self._lock: + self.alarm_push_config = config diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 592ee98..3416831 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -10,7 +10,7 @@ class DeviceSceneRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: + async def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: statement = ( select(DeviceSceneRelation, Scene) .join(Scene, DeviceSceneRelation.scene_id == Scene.id) @@ -32,11 +32,12 @@ ) # 执行联表查询 - result = self.db.exec(statement).first() + result = await self.db.execute(statement) + result_row = result.first() scene_info = None - if result: - relation, scene = result[0], result[1] + if result_row: + relation, scene = result_row scene_info = DeviceSceneRelationInfo( id=relation.id, device_id=relation.device_id, @@ -48,23 +49,23 @@ ) return scene_info - def add_relation_by_device(self, device_id: int, scene_id: int): + async def add_relation_by_device(self, device_id: int, scene_id: int): new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) - self.db.commit() - self.db.refresh(new_relation) + await self.db.commit() + await self.db.refresh(new_relation) return new_relation - def delete_relation_by_device(self, device_id: int): + async def delete_relation_by_device(self, device_id: int): statement = delete(DeviceSceneRelation).where(DeviceSceneRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relation_by_device(self, device_id: int, scene_id: int): - self.delete_relation_by_device(device_id) - new_relation = self.add_relation_by_device(device_id, scene_id) + async def update_relation_by_device(self, device_id: int, scene_id: int): + await self.delete_relation_by_device(device_id) + new_relation = await self.add_relation_by_device(device_id, scene_id) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/device_service.py b/services/device_service.py index cad46d3..2488fd5 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -4,8 +4,8 @@ from typing import Sequence, Optional, Tuple from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.device_status_manager import DeviceStatusManager from common.global_thread_pool import GlobalThreadPool from common.consts import NotifyChangeType, DEVICE_MODE @@ -15,25 +15,25 @@ class DeviceService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() # 创建一个独立的事件循环并启动线程 - self.loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) - self.loop_thread.start() + # self.loop = asyncio.new_event_loop() + # self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) + # self.loop_thread.start() - def _start_loop(self): - """后台线程运行事件循环""" - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def shutdown(self): - """清理事件循环和线程""" - self.loop.call_soon_threadsafe(self.loop.stop) - self.loop_thread.join() + # def _start_loop(self): + # """后台线程运行事件循环""" + # asyncio.set_event_loop(self.loop) + # self.loop.run_forever() + # + # def shutdown(self): + # """清理事件循环和线程""" + # self.loop.call_soon_threadsafe(self.loop.stop) + # self.loop_thread.join() def register_change_callback(self, callback): """注册设备变化回调函数""" @@ -42,52 +42,53 @@ def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" # loop = asyncio.get_event_loop() # 获取当前的事件循环 - for callback in self.__device_change_callbacks: - if asyncio.iscoroutinefunction(callback): - # 如果是协程函数,使用事件循环运行它 - # loop = asyncio.new_event_loop() - # asyncio.set_event_loop(loop) - asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) - else: - # 如果是普通函数,直接提交到线程池 - self.thread_pool.executor.submit(callback, device_id, change_type) + # for callback in self.__device_change_callbacks: + # if asyncio.iscoroutinefunction(callback): + # # 如果是协程函数,使用事件循环运行它 + # # loop = asyncio.new_event_loop() + # # asyncio.set_event_loop(loop) + # asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) + # else: + # # 如果是普通函数,直接提交到线程池 + # self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_list(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - ) -> Sequence[Device]: + async def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: statement = self.device_query(code, device_type, name) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_device_page(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceInfo], int]: + async def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceInfo], int]: statement = self.device_query(code, device_type, name) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - device_list = results.all() + results = await self.db.execute(statement) + device_list = results.scalars().all() device_info_list = [] if device_list: device_model_relation_service = DeviceModelRelationService(self.db) device_scene_relation_service = DeviceSceneRelationService(self.db) device_status_manager = DeviceStatusManager() for device in device_list: - model_relations = device_model_relation_service.get_device_models(device.id) - scene_relation = device_scene_relation_service.get_device_scene(device.id) + model_relations = await device_model_relation_service.get_device_models(device.id) + scene_relation = await device_scene_relation_service.get_device_scene(device.id) device_info_list.append(DeviceInfo( id=device.id, @@ -121,20 +122,20 @@ statement = statement.where(Device.type == device_type) return statement - def create_device(self, device_data: DeviceCreate): + async def create_device(self, device_data: DeviceCreate): device = Device.model_validate(device_data) device.create_time = datetime.now() device.update_time = datetime.now() - self.handle_device_mode(device) + await self.handle_device_mode(device) self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) return device - def update_device(self, device_data: DeviceUpdate): - device_old = self.db.get(Device, device_data.id) - device = self.db.get(Device, device_data.id) + async def update_device(self, device_data: DeviceUpdate): + device_old = await self.get_device(device_data.id) + device = await self.get_device(device_data.id) if not device: return None @@ -142,41 +143,45 @@ for key, value in update_data.items(): setattr(device, key, value) - self.handle_device_mode(device) + await self.handle_device_mode(device) device.update_time = datetime.now() self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) return device - def delete_device(self, device_id: int): - device = self.db.get(Device, device_id) + async def delete_device(self, device_id: int): + device = await self.get_device(device_id) if not device: return None - self.db.delete(device) - self.db.commit() + statement = delete(Device).where(Device.id == device_id) + await self.db.execute(statement) + await self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device_id) + await model_relation_service.delete_relations_by_device(device_id) scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) return device - def handle_device_mode(self, device): + async def handle_device_mode(self, device): if device.mode == DEVICE_MODE.ALGO: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) elif device.mode == DEVICE_MODE.SCENE: model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) else: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) - def get_device(self, device_id: int): - return self.db.get(Device, device_id) + async def get_device(self, device_id: int): + result = await self.db.execute(select(Device).where(Device.id == device_id)) + frame = result.scalar_one_or_none() + return frame diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index dbd272d..4e6c145 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,23 +1,25 @@ from typing import List -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult class FrameAnalysisResultService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): + async def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): new_results = [FrameAnalysisResult.model_validate(result) for result in results] - self.db.add_all(new_results) - self.db.commit() for result in new_results: - self.db.refresh(result) + self.db.add(result) + await self.db.commit() + for result in new_results: + await self.db.refresh(result) return new_results - def get_results_by_frame(self, frame_id): + async def get_results_by_frame(self, frame_id): statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/services/global_config.py b/services/global_config.py index 7dae73b..fa53b2b 100644 --- a/services/global_config.py +++ b/services/global_config.py @@ -1,3 +1,5 @@ +import asyncio + from common.consts import PUSH_TYPE from db.database import get_db from entity.push_config import PushConfig @@ -6,10 +8,12 @@ class GlobalConfig: _instance = None + _lock = asyncio.Lock() def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) + return cls._instance def __init__(self): @@ -19,49 +23,58 @@ self.gas_push_config = None self.algo_result_push_config = None self.alarm_push_config = None - self.init_config() # 进行初始化 + self._init_done = False - def init_config(self): - # 初始化配置逻辑 - with next(get_db()) as db: - self.config_service = PushConfigService(db) - self.set_gas_push_config(self.config_service.get_push_config(PUSH_TYPE.GAS)) - self.set_algo_result_push_config(self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT)) - self.set_alarm_push_config(self.config_service.get_push_config(PUSH_TYPE.ALARM)) + async def _initialize(self): + await self.init_config() # 调用异步初始化 - self.config_service.register_change_callback(self.on_config_change) + async def init_config(self): + # 确保只初始化一次 + if not self._init_done: + async with self._lock: + if not self._init_done: # 双重检查锁,避免多次初始化 + async for db in get_db(): + self.config_service = PushConfigService(db) + self.config_service.register_change_callback(self.on_config_change) + self.gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.GAS) + self.algo_result_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT) + self.alarm_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALARM) + self._init_done = True - def on_config_change(self, config: PushConfig): + async def on_config_change(self, config: PushConfig): if config.push_type == PUSH_TYPE.GAS: - self.set_gas_push_config(config) + await self.set_gas_push_config(config) elif config.push_type == PUSH_TYPE.ALGO_RESULT: - self.set_algo_result_push_config(config) + await self.set_algo_result_push_config(config) elif config.push_type == PUSH_TYPE.ALARM: - self.set_alarm_push_config(config) + await self.set_alarm_push_config(config) def get_gas_push_config(self): """获取 gas_push_config 配置""" return self.gas_push_config - def set_gas_push_config(self, config): + async def set_gas_push_config(self, config): """设置 gas_push_config 配置""" if config: - self.gas_push_config = config + async with self._lock: + self.gas_push_config = config def get_algo_result_push_config(self): """获取 algo_result_push_config 配置""" return self.algo_result_push_config - def set_algo_result_push_config(self, config): + async def set_algo_result_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.algo_result_push_config = config + async with self._lock: + self.algo_result_push_config = config def get_alarm_push_config(self): """获取 algo_result_push_config 配置""" return self.alarm_push_config - def set_alarm_push_config(self, config): + async def set_alarm_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.alarm_push_config = config + async with self._lock: + self.alarm_push_config = config diff --git a/services/model_service.py b/services/model_service.py index f9531a5..afc61d9 100644 --- a/services/model_service.py +++ b/services/model_service.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid import zipfile @@ -5,10 +6,11 @@ from pathlib import Path from typing import List, Sequence, Optional, Tuple, Type +import aiofiles from fastapi import UploadFile from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.biz_exception import BizException from common.string_utils import snake_to_camel from entity.device_model_relation import DeviceModelRelation @@ -18,7 +20,7 @@ class ModelService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__model_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -32,15 +34,15 @@ for callback in self.__model_change_callbacks: self.thread_pool.executor.submit(callback, algo_model_id, change_type) - def get_model_list(self, - name: Optional[str] = None, - remark: Optional[str] = None, - ) -> Sequence[AlgoModel]: + async def get_model_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[AlgoModel]: statement = self.model_query(name, remark) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_page(self, + async def get_model_page(self, name: Optional[str] = None, remark: Optional[str] = None, offset: int = 0, @@ -50,19 +52,21 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - model_list = self.db.exec(statement) + model_list = await self.db.execute(statement) + rows = model_list.scalars().all() model_info_list: List[AlgoModelInfo] = [] - if model_list: - for model in model_list: + if rows: + for model in rows: model_info_list.append(AlgoModelInfo( **model.dict(), - usage_status="使用中" if self.get_model_usage(model.id) else "未使用" + usage_status="使用中" if await self.get_model_usage(model.id) else "未使用" )) return model_info_list, total # 返回分页数据和总数 @@ -76,7 +80,7 @@ statement = statement.where(AlgoModel.remark.like(f"%{remark}%")) return statement - def process_zip(self, file: UploadFile): + async def process_zip(self, file: UploadFile): model_dir = Path('weights/') model_handle_dir = Path('model_handler/') model_dir.mkdir(parents=True, exist_ok=True) @@ -86,13 +90,16 @@ # 临时保存上传文件 temp_path = Path(f"temp_upload_{uuid.uuid4()}.zip") - with open(temp_path, "wb") as temp_file: - temp_file.write(file.file.read()) + async with aiofiles.open(temp_path, "wb") as temp_file: + content = await file.read() # 异步读取上传文件内容 + await temp_file.write(content) model_file_path = None handle_file_path = None try: + # 使用异步方法读取压缩文件的内容 + loop = asyncio.get_event_loop() with zipfile.ZipFile(temp_path, 'r') as zip_ref: # 获取压缩包文件列表 file_list = zip_ref.namelist() @@ -105,47 +112,47 @@ message=f"Model weight file ({', '.join(SUPPORTED_MODEL_EXTENSIONS)}) is required in the zip." ) - # 解压模型文件到模型目录 - zip_ref.extract(model_file, model_dir) + # 异步解压模型文件到模型目录 model_file_path = model_dir / model_file + await loop.run_in_executor(None, zip_ref.extract, model_file, model_dir) # 检查是否有可选的 Python 脚本 handle_file = next((f for f in file_list if f.endswith(".py")), None) if handle_file: - zip_ref.extract(handle_file, model_handle_dir) handle_file_path = model_handle_dir / handle_file - + await loop.run_in_executor(None, zip_ref.extract, handle_file, model_handle_dir) except zipfile.BadZipFile: raise BizException(status_code=400, message="Invalid zip file.") finally: # 删除临时文件 - temp_path.unlink() + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, temp_path.unlink) return str(model_file_path), str(handle_file_path) if handle_file_path else None - def create_model(self, model_data: AlgoModelCreate, file: UploadFile): - self.process_model_file(file, model_data) + async def create_model(self, model_data: AlgoModelCreate, file: UploadFile): + await self.process_model_file(file, model_data) model = AlgoModel.model_validate(model_data) model.create_time = datetime.now() model.update_time = datetime.now() self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) return model - def process_model_file(self, file, model): - model_file_path, handle_file_path = self.process_zip(file) + async def process_model_file(self, file, model): + model_file_path, handle_file_path = await self.process_zip(file) model.path = model_file_path if handle_file_path: model.handle_task = snake_to_camel(os.path.splitext(os.path.basename(handle_file_path))[0]) else: model.handle_task = 'BaseModelHandler' - def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): - model = self.db.get(AlgoModel, model_data.id) + async def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): + model = await self.get_model_by_id(model_data.id) if not model: return None @@ -155,16 +162,16 @@ model.update_time = datetime.now() if file: - self.process_model_file(file, model) + await self.process_model_file(file, model) self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE) return model - def delete_model(self, model_id: int): - model = self.db.get(AlgoModel, model_id) + async def delete_model(self, model_id: int): + model = await self.get_model_by_id(model_id) if not model: return None # 查询 device_model_relation 中是否存在启用的绑定关系 @@ -173,17 +180,20 @@ .where(DeviceModelRelation.algo_model_id == model_id) .where(DeviceModelRelation.is_use == 1) ) - relation_in_use = self.db.exec(statement).first() + relation_in_use_exec = await self.db.execute(statement) + relation_in_use = relation_in_use_exec.scalars().first() # 如果存在启用的绑定关系,提示无法删除 if relation_in_use: raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除") - self.db.delete(model) - self.db.commit() + statement = delete(AlgoModel).where(AlgoModel.id == model_id) + await self.db.execute(statement) + await self.db.commit() + return model - def get_models_in_use(self) -> Sequence[AlgoModel]: + async def get_models_in_use(self) -> Sequence[AlgoModel]: """获取所有在 device_model_relation 表里有启用绑定关系的模型信息""" statement = ( select(AlgoModel) @@ -191,10 +201,10 @@ .where(DeviceModelRelation.is_use == 1) .group_by(AlgoModel.id) ) - results = self.db.exec(statement).all() - return results + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_usage(self, algo_model_id) -> bool: + async def get_model_usage(self, algo_model_id) -> bool: statement = ( select(DeviceModelRelation) .where( @@ -202,8 +212,11 @@ DeviceModelRelation.algo_model_id == algo_model_id, ) ) - result = self.db.exec(statement).all() - return len(result) > 0 + result = await self.db.execute(statement) + rows = result.all() + return len(rows) > 0 - def get_model_by_id(self, model_id): - return self.db.get(AlgoModel, model_id) + async def get_model_by_id(self, model_id): + result = await self.db.execute(select(AlgoModel).where(AlgoModel.id == model_id)) + model = result.scalar_one_or_none() + return model diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 592ee98..3416831 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -10,7 +10,7 @@ class DeviceSceneRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: + async def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: statement = ( select(DeviceSceneRelation, Scene) .join(Scene, DeviceSceneRelation.scene_id == Scene.id) @@ -32,11 +32,12 @@ ) # 执行联表查询 - result = self.db.exec(statement).first() + result = await self.db.execute(statement) + result_row = result.first() scene_info = None - if result: - relation, scene = result[0], result[1] + if result_row: + relation, scene = result_row scene_info = DeviceSceneRelationInfo( id=relation.id, device_id=relation.device_id, @@ -48,23 +49,23 @@ ) return scene_info - def add_relation_by_device(self, device_id: int, scene_id: int): + async def add_relation_by_device(self, device_id: int, scene_id: int): new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) - self.db.commit() - self.db.refresh(new_relation) + await self.db.commit() + await self.db.refresh(new_relation) return new_relation - def delete_relation_by_device(self, device_id: int): + async def delete_relation_by_device(self, device_id: int): statement = delete(DeviceSceneRelation).where(DeviceSceneRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relation_by_device(self, device_id: int, scene_id: int): - self.delete_relation_by_device(device_id) - new_relation = self.add_relation_by_device(device_id, scene_id) + async def update_relation_by_device(self, device_id: int, scene_id: int): + await self.delete_relation_by_device(device_id) + new_relation = await self.add_relation_by_device(device_id, scene_id) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/device_service.py b/services/device_service.py index cad46d3..2488fd5 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -4,8 +4,8 @@ from typing import Sequence, Optional, Tuple from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.device_status_manager import DeviceStatusManager from common.global_thread_pool import GlobalThreadPool from common.consts import NotifyChangeType, DEVICE_MODE @@ -15,25 +15,25 @@ class DeviceService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() # 创建一个独立的事件循环并启动线程 - self.loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) - self.loop_thread.start() + # self.loop = asyncio.new_event_loop() + # self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) + # self.loop_thread.start() - def _start_loop(self): - """后台线程运行事件循环""" - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def shutdown(self): - """清理事件循环和线程""" - self.loop.call_soon_threadsafe(self.loop.stop) - self.loop_thread.join() + # def _start_loop(self): + # """后台线程运行事件循环""" + # asyncio.set_event_loop(self.loop) + # self.loop.run_forever() + # + # def shutdown(self): + # """清理事件循环和线程""" + # self.loop.call_soon_threadsafe(self.loop.stop) + # self.loop_thread.join() def register_change_callback(self, callback): """注册设备变化回调函数""" @@ -42,52 +42,53 @@ def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" # loop = asyncio.get_event_loop() # 获取当前的事件循环 - for callback in self.__device_change_callbacks: - if asyncio.iscoroutinefunction(callback): - # 如果是协程函数,使用事件循环运行它 - # loop = asyncio.new_event_loop() - # asyncio.set_event_loop(loop) - asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) - else: - # 如果是普通函数,直接提交到线程池 - self.thread_pool.executor.submit(callback, device_id, change_type) + # for callback in self.__device_change_callbacks: + # if asyncio.iscoroutinefunction(callback): + # # 如果是协程函数,使用事件循环运行它 + # # loop = asyncio.new_event_loop() + # # asyncio.set_event_loop(loop) + # asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) + # else: + # # 如果是普通函数,直接提交到线程池 + # self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_list(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - ) -> Sequence[Device]: + async def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: statement = self.device_query(code, device_type, name) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_device_page(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceInfo], int]: + async def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceInfo], int]: statement = self.device_query(code, device_type, name) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - device_list = results.all() + results = await self.db.execute(statement) + device_list = results.scalars().all() device_info_list = [] if device_list: device_model_relation_service = DeviceModelRelationService(self.db) device_scene_relation_service = DeviceSceneRelationService(self.db) device_status_manager = DeviceStatusManager() for device in device_list: - model_relations = device_model_relation_service.get_device_models(device.id) - scene_relation = device_scene_relation_service.get_device_scene(device.id) + model_relations = await device_model_relation_service.get_device_models(device.id) + scene_relation = await device_scene_relation_service.get_device_scene(device.id) device_info_list.append(DeviceInfo( id=device.id, @@ -121,20 +122,20 @@ statement = statement.where(Device.type == device_type) return statement - def create_device(self, device_data: DeviceCreate): + async def create_device(self, device_data: DeviceCreate): device = Device.model_validate(device_data) device.create_time = datetime.now() device.update_time = datetime.now() - self.handle_device_mode(device) + await self.handle_device_mode(device) self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) return device - def update_device(self, device_data: DeviceUpdate): - device_old = self.db.get(Device, device_data.id) - device = self.db.get(Device, device_data.id) + async def update_device(self, device_data: DeviceUpdate): + device_old = await self.get_device(device_data.id) + device = await self.get_device(device_data.id) if not device: return None @@ -142,41 +143,45 @@ for key, value in update_data.items(): setattr(device, key, value) - self.handle_device_mode(device) + await self.handle_device_mode(device) device.update_time = datetime.now() self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) return device - def delete_device(self, device_id: int): - device = self.db.get(Device, device_id) + async def delete_device(self, device_id: int): + device = await self.get_device(device_id) if not device: return None - self.db.delete(device) - self.db.commit() + statement = delete(Device).where(Device.id == device_id) + await self.db.execute(statement) + await self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device_id) + await model_relation_service.delete_relations_by_device(device_id) scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) return device - def handle_device_mode(self, device): + async def handle_device_mode(self, device): if device.mode == DEVICE_MODE.ALGO: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) elif device.mode == DEVICE_MODE.SCENE: model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) else: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) - def get_device(self, device_id: int): - return self.db.get(Device, device_id) + async def get_device(self, device_id: int): + result = await self.db.execute(select(Device).where(Device.id == device_id)) + frame = result.scalar_one_or_none() + return frame diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index dbd272d..4e6c145 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,23 +1,25 @@ from typing import List -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult class FrameAnalysisResultService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): + async def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): new_results = [FrameAnalysisResult.model_validate(result) for result in results] - self.db.add_all(new_results) - self.db.commit() for result in new_results: - self.db.refresh(result) + self.db.add(result) + await self.db.commit() + for result in new_results: + await self.db.refresh(result) return new_results - def get_results_by_frame(self, frame_id): + async def get_results_by_frame(self, frame_id): statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/services/global_config.py b/services/global_config.py index 7dae73b..fa53b2b 100644 --- a/services/global_config.py +++ b/services/global_config.py @@ -1,3 +1,5 @@ +import asyncio + from common.consts import PUSH_TYPE from db.database import get_db from entity.push_config import PushConfig @@ -6,10 +8,12 @@ class GlobalConfig: _instance = None + _lock = asyncio.Lock() def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) + return cls._instance def __init__(self): @@ -19,49 +23,58 @@ self.gas_push_config = None self.algo_result_push_config = None self.alarm_push_config = None - self.init_config() # 进行初始化 + self._init_done = False - def init_config(self): - # 初始化配置逻辑 - with next(get_db()) as db: - self.config_service = PushConfigService(db) - self.set_gas_push_config(self.config_service.get_push_config(PUSH_TYPE.GAS)) - self.set_algo_result_push_config(self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT)) - self.set_alarm_push_config(self.config_service.get_push_config(PUSH_TYPE.ALARM)) + async def _initialize(self): + await self.init_config() # 调用异步初始化 - self.config_service.register_change_callback(self.on_config_change) + async def init_config(self): + # 确保只初始化一次 + if not self._init_done: + async with self._lock: + if not self._init_done: # 双重检查锁,避免多次初始化 + async for db in get_db(): + self.config_service = PushConfigService(db) + self.config_service.register_change_callback(self.on_config_change) + self.gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.GAS) + self.algo_result_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT) + self.alarm_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALARM) + self._init_done = True - def on_config_change(self, config: PushConfig): + async def on_config_change(self, config: PushConfig): if config.push_type == PUSH_TYPE.GAS: - self.set_gas_push_config(config) + await self.set_gas_push_config(config) elif config.push_type == PUSH_TYPE.ALGO_RESULT: - self.set_algo_result_push_config(config) + await self.set_algo_result_push_config(config) elif config.push_type == PUSH_TYPE.ALARM: - self.set_alarm_push_config(config) + await self.set_alarm_push_config(config) def get_gas_push_config(self): """获取 gas_push_config 配置""" return self.gas_push_config - def set_gas_push_config(self, config): + async def set_gas_push_config(self, config): """设置 gas_push_config 配置""" if config: - self.gas_push_config = config + async with self._lock: + self.gas_push_config = config def get_algo_result_push_config(self): """获取 algo_result_push_config 配置""" return self.algo_result_push_config - def set_algo_result_push_config(self, config): + async def set_algo_result_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.algo_result_push_config = config + async with self._lock: + self.algo_result_push_config = config def get_alarm_push_config(self): """获取 algo_result_push_config 配置""" return self.alarm_push_config - def set_alarm_push_config(self, config): + async def set_alarm_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.alarm_push_config = config + async with self._lock: + self.alarm_push_config = config diff --git a/services/model_service.py b/services/model_service.py index f9531a5..afc61d9 100644 --- a/services/model_service.py +++ b/services/model_service.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid import zipfile @@ -5,10 +6,11 @@ from pathlib import Path from typing import List, Sequence, Optional, Tuple, Type +import aiofiles from fastapi import UploadFile from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.biz_exception import BizException from common.string_utils import snake_to_camel from entity.device_model_relation import DeviceModelRelation @@ -18,7 +20,7 @@ class ModelService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__model_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -32,15 +34,15 @@ for callback in self.__model_change_callbacks: self.thread_pool.executor.submit(callback, algo_model_id, change_type) - def get_model_list(self, - name: Optional[str] = None, - remark: Optional[str] = None, - ) -> Sequence[AlgoModel]: + async def get_model_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[AlgoModel]: statement = self.model_query(name, remark) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_page(self, + async def get_model_page(self, name: Optional[str] = None, remark: Optional[str] = None, offset: int = 0, @@ -50,19 +52,21 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - model_list = self.db.exec(statement) + model_list = await self.db.execute(statement) + rows = model_list.scalars().all() model_info_list: List[AlgoModelInfo] = [] - if model_list: - for model in model_list: + if rows: + for model in rows: model_info_list.append(AlgoModelInfo( **model.dict(), - usage_status="使用中" if self.get_model_usage(model.id) else "未使用" + usage_status="使用中" if await self.get_model_usage(model.id) else "未使用" )) return model_info_list, total # 返回分页数据和总数 @@ -76,7 +80,7 @@ statement = statement.where(AlgoModel.remark.like(f"%{remark}%")) return statement - def process_zip(self, file: UploadFile): + async def process_zip(self, file: UploadFile): model_dir = Path('weights/') model_handle_dir = Path('model_handler/') model_dir.mkdir(parents=True, exist_ok=True) @@ -86,13 +90,16 @@ # 临时保存上传文件 temp_path = Path(f"temp_upload_{uuid.uuid4()}.zip") - with open(temp_path, "wb") as temp_file: - temp_file.write(file.file.read()) + async with aiofiles.open(temp_path, "wb") as temp_file: + content = await file.read() # 异步读取上传文件内容 + await temp_file.write(content) model_file_path = None handle_file_path = None try: + # 使用异步方法读取压缩文件的内容 + loop = asyncio.get_event_loop() with zipfile.ZipFile(temp_path, 'r') as zip_ref: # 获取压缩包文件列表 file_list = zip_ref.namelist() @@ -105,47 +112,47 @@ message=f"Model weight file ({', '.join(SUPPORTED_MODEL_EXTENSIONS)}) is required in the zip." ) - # 解压模型文件到模型目录 - zip_ref.extract(model_file, model_dir) + # 异步解压模型文件到模型目录 model_file_path = model_dir / model_file + await loop.run_in_executor(None, zip_ref.extract, model_file, model_dir) # 检查是否有可选的 Python 脚本 handle_file = next((f for f in file_list if f.endswith(".py")), None) if handle_file: - zip_ref.extract(handle_file, model_handle_dir) handle_file_path = model_handle_dir / handle_file - + await loop.run_in_executor(None, zip_ref.extract, handle_file, model_handle_dir) except zipfile.BadZipFile: raise BizException(status_code=400, message="Invalid zip file.") finally: # 删除临时文件 - temp_path.unlink() + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, temp_path.unlink) return str(model_file_path), str(handle_file_path) if handle_file_path else None - def create_model(self, model_data: AlgoModelCreate, file: UploadFile): - self.process_model_file(file, model_data) + async def create_model(self, model_data: AlgoModelCreate, file: UploadFile): + await self.process_model_file(file, model_data) model = AlgoModel.model_validate(model_data) model.create_time = datetime.now() model.update_time = datetime.now() self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) return model - def process_model_file(self, file, model): - model_file_path, handle_file_path = self.process_zip(file) + async def process_model_file(self, file, model): + model_file_path, handle_file_path = await self.process_zip(file) model.path = model_file_path if handle_file_path: model.handle_task = snake_to_camel(os.path.splitext(os.path.basename(handle_file_path))[0]) else: model.handle_task = 'BaseModelHandler' - def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): - model = self.db.get(AlgoModel, model_data.id) + async def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): + model = await self.get_model_by_id(model_data.id) if not model: return None @@ -155,16 +162,16 @@ model.update_time = datetime.now() if file: - self.process_model_file(file, model) + await self.process_model_file(file, model) self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE) return model - def delete_model(self, model_id: int): - model = self.db.get(AlgoModel, model_id) + async def delete_model(self, model_id: int): + model = await self.get_model_by_id(model_id) if not model: return None # 查询 device_model_relation 中是否存在启用的绑定关系 @@ -173,17 +180,20 @@ .where(DeviceModelRelation.algo_model_id == model_id) .where(DeviceModelRelation.is_use == 1) ) - relation_in_use = self.db.exec(statement).first() + relation_in_use_exec = await self.db.execute(statement) + relation_in_use = relation_in_use_exec.scalars().first() # 如果存在启用的绑定关系,提示无法删除 if relation_in_use: raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除") - self.db.delete(model) - self.db.commit() + statement = delete(AlgoModel).where(AlgoModel.id == model_id) + await self.db.execute(statement) + await self.db.commit() + return model - def get_models_in_use(self) -> Sequence[AlgoModel]: + async def get_models_in_use(self) -> Sequence[AlgoModel]: """获取所有在 device_model_relation 表里有启用绑定关系的模型信息""" statement = ( select(AlgoModel) @@ -191,10 +201,10 @@ .where(DeviceModelRelation.is_use == 1) .group_by(AlgoModel.id) ) - results = self.db.exec(statement).all() - return results + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_usage(self, algo_model_id) -> bool: + async def get_model_usage(self, algo_model_id) -> bool: statement = ( select(DeviceModelRelation) .where( @@ -202,8 +212,11 @@ DeviceModelRelation.algo_model_id == algo_model_id, ) ) - result = self.db.exec(statement).all() - return len(result) > 0 + result = await self.db.execute(statement) + rows = result.all() + return len(rows) > 0 - def get_model_by_id(self, model_id): - return self.db.get(AlgoModel, model_id) + async def get_model_by_id(self, model_id): + result = await self.db.execute(select(AlgoModel).where(AlgoModel.id == model_id)) + model = result.scalar_one_or_none() + return model diff --git a/services/push_config_service.py b/services/push_config_service.py index 8a2c298..14862df 100644 --- a/services/push_config_service.py +++ b/services/push_config_service.py @@ -1,6 +1,8 @@ +import asyncio from datetime import datetime -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.push_config import PushConfigCreate, PushConfig @@ -13,7 +15,7 @@ cls._instance = super().__new__(cls) return cls._instance - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): if not hasattr(self, 'initialized'): self.db = db self.__push_change_callbacks = [] # 用于存储回调函数 @@ -23,12 +25,22 @@ """注册设备变化回调函数""" self.__push_change_callbacks.append(callback) - def notify_change(self, push_config): - for callback in self.__push_change_callbacks: - callback(push_config) + # def notify_change(self, push_config): + # for callback in self.__push_change_callbacks: + # callback(push_config) - def set_push_config(self, push_config_create: PushConfigCreate): - push_config = self.get_push_config(push_config_create.push_type) + def notify_change(self, config): + """通知所有回调函数""" + for callback in self.__push_change_callbacks: + if asyncio.iscoroutinefunction(callback): + # 如果是异步函数,使用 asyncio.create_task() 调度 + asyncio.create_task(callback(config)) + else: + # 如果是同步函数,直接调用 + callback(config) + + async def set_push_config(self, push_config_create: PushConfigCreate): + push_config = await self.get_push_config(push_config_create.push_type) if push_config: update_data = push_config_create.dict(exclude_unset=True) for key, value in update_data.items(): @@ -40,18 +52,18 @@ push_config.update_time = datetime.now() self.db.add(push_config) - self.db.commit() - self.db.refresh(push_config) + await self.db.commit() + await self.db.refresh(push_config) self.notify_change(push_config) return push_config - def get_push_config(self, push_type): + async def get_push_config(self, push_type): statement = select(PushConfig).where(PushConfig.push_type == push_type) - results = self.db.exec(statement) - return results.first() + results = await self.db.execute(statement) + return results.scalars().first() - def get_push_config_list(self): + async def get_push_config_list(self): statement = select(PushConfig) - results = self.db.exec(statement) - return results.all() \ No newline at end of file + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 592ee98..3416831 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -10,7 +10,7 @@ class DeviceSceneRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: + async def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: statement = ( select(DeviceSceneRelation, Scene) .join(Scene, DeviceSceneRelation.scene_id == Scene.id) @@ -32,11 +32,12 @@ ) # 执行联表查询 - result = self.db.exec(statement).first() + result = await self.db.execute(statement) + result_row = result.first() scene_info = None - if result: - relation, scene = result[0], result[1] + if result_row: + relation, scene = result_row scene_info = DeviceSceneRelationInfo( id=relation.id, device_id=relation.device_id, @@ -48,23 +49,23 @@ ) return scene_info - def add_relation_by_device(self, device_id: int, scene_id: int): + async def add_relation_by_device(self, device_id: int, scene_id: int): new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) - self.db.commit() - self.db.refresh(new_relation) + await self.db.commit() + await self.db.refresh(new_relation) return new_relation - def delete_relation_by_device(self, device_id: int): + async def delete_relation_by_device(self, device_id: int): statement = delete(DeviceSceneRelation).where(DeviceSceneRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relation_by_device(self, device_id: int, scene_id: int): - self.delete_relation_by_device(device_id) - new_relation = self.add_relation_by_device(device_id, scene_id) + async def update_relation_by_device(self, device_id: int, scene_id: int): + await self.delete_relation_by_device(device_id) + new_relation = await self.add_relation_by_device(device_id, scene_id) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/device_service.py b/services/device_service.py index cad46d3..2488fd5 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -4,8 +4,8 @@ from typing import Sequence, Optional, Tuple from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.device_status_manager import DeviceStatusManager from common.global_thread_pool import GlobalThreadPool from common.consts import NotifyChangeType, DEVICE_MODE @@ -15,25 +15,25 @@ class DeviceService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() # 创建一个独立的事件循环并启动线程 - self.loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) - self.loop_thread.start() + # self.loop = asyncio.new_event_loop() + # self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) + # self.loop_thread.start() - def _start_loop(self): - """后台线程运行事件循环""" - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def shutdown(self): - """清理事件循环和线程""" - self.loop.call_soon_threadsafe(self.loop.stop) - self.loop_thread.join() + # def _start_loop(self): + # """后台线程运行事件循环""" + # asyncio.set_event_loop(self.loop) + # self.loop.run_forever() + # + # def shutdown(self): + # """清理事件循环和线程""" + # self.loop.call_soon_threadsafe(self.loop.stop) + # self.loop_thread.join() def register_change_callback(self, callback): """注册设备变化回调函数""" @@ -42,52 +42,53 @@ def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" # loop = asyncio.get_event_loop() # 获取当前的事件循环 - for callback in self.__device_change_callbacks: - if asyncio.iscoroutinefunction(callback): - # 如果是协程函数,使用事件循环运行它 - # loop = asyncio.new_event_loop() - # asyncio.set_event_loop(loop) - asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) - else: - # 如果是普通函数,直接提交到线程池 - self.thread_pool.executor.submit(callback, device_id, change_type) + # for callback in self.__device_change_callbacks: + # if asyncio.iscoroutinefunction(callback): + # # 如果是协程函数,使用事件循环运行它 + # # loop = asyncio.new_event_loop() + # # asyncio.set_event_loop(loop) + # asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) + # else: + # # 如果是普通函数,直接提交到线程池 + # self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_list(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - ) -> Sequence[Device]: + async def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: statement = self.device_query(code, device_type, name) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_device_page(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceInfo], int]: + async def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceInfo], int]: statement = self.device_query(code, device_type, name) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - device_list = results.all() + results = await self.db.execute(statement) + device_list = results.scalars().all() device_info_list = [] if device_list: device_model_relation_service = DeviceModelRelationService(self.db) device_scene_relation_service = DeviceSceneRelationService(self.db) device_status_manager = DeviceStatusManager() for device in device_list: - model_relations = device_model_relation_service.get_device_models(device.id) - scene_relation = device_scene_relation_service.get_device_scene(device.id) + model_relations = await device_model_relation_service.get_device_models(device.id) + scene_relation = await device_scene_relation_service.get_device_scene(device.id) device_info_list.append(DeviceInfo( id=device.id, @@ -121,20 +122,20 @@ statement = statement.where(Device.type == device_type) return statement - def create_device(self, device_data: DeviceCreate): + async def create_device(self, device_data: DeviceCreate): device = Device.model_validate(device_data) device.create_time = datetime.now() device.update_time = datetime.now() - self.handle_device_mode(device) + await self.handle_device_mode(device) self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) return device - def update_device(self, device_data: DeviceUpdate): - device_old = self.db.get(Device, device_data.id) - device = self.db.get(Device, device_data.id) + async def update_device(self, device_data: DeviceUpdate): + device_old = await self.get_device(device_data.id) + device = await self.get_device(device_data.id) if not device: return None @@ -142,41 +143,45 @@ for key, value in update_data.items(): setattr(device, key, value) - self.handle_device_mode(device) + await self.handle_device_mode(device) device.update_time = datetime.now() self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) return device - def delete_device(self, device_id: int): - device = self.db.get(Device, device_id) + async def delete_device(self, device_id: int): + device = await self.get_device(device_id) if not device: return None - self.db.delete(device) - self.db.commit() + statement = delete(Device).where(Device.id == device_id) + await self.db.execute(statement) + await self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device_id) + await model_relation_service.delete_relations_by_device(device_id) scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) return device - def handle_device_mode(self, device): + async def handle_device_mode(self, device): if device.mode == DEVICE_MODE.ALGO: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) elif device.mode == DEVICE_MODE.SCENE: model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) else: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) - def get_device(self, device_id: int): - return self.db.get(Device, device_id) + async def get_device(self, device_id: int): + result = await self.db.execute(select(Device).where(Device.id == device_id)) + frame = result.scalar_one_or_none() + return frame diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index dbd272d..4e6c145 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,23 +1,25 @@ from typing import List -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult class FrameAnalysisResultService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): + async def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): new_results = [FrameAnalysisResult.model_validate(result) for result in results] - self.db.add_all(new_results) - self.db.commit() for result in new_results: - self.db.refresh(result) + self.db.add(result) + await self.db.commit() + for result in new_results: + await self.db.refresh(result) return new_results - def get_results_by_frame(self, frame_id): + async def get_results_by_frame(self, frame_id): statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/services/global_config.py b/services/global_config.py index 7dae73b..fa53b2b 100644 --- a/services/global_config.py +++ b/services/global_config.py @@ -1,3 +1,5 @@ +import asyncio + from common.consts import PUSH_TYPE from db.database import get_db from entity.push_config import PushConfig @@ -6,10 +8,12 @@ class GlobalConfig: _instance = None + _lock = asyncio.Lock() def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) + return cls._instance def __init__(self): @@ -19,49 +23,58 @@ self.gas_push_config = None self.algo_result_push_config = None self.alarm_push_config = None - self.init_config() # 进行初始化 + self._init_done = False - def init_config(self): - # 初始化配置逻辑 - with next(get_db()) as db: - self.config_service = PushConfigService(db) - self.set_gas_push_config(self.config_service.get_push_config(PUSH_TYPE.GAS)) - self.set_algo_result_push_config(self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT)) - self.set_alarm_push_config(self.config_service.get_push_config(PUSH_TYPE.ALARM)) + async def _initialize(self): + await self.init_config() # 调用异步初始化 - self.config_service.register_change_callback(self.on_config_change) + async def init_config(self): + # 确保只初始化一次 + if not self._init_done: + async with self._lock: + if not self._init_done: # 双重检查锁,避免多次初始化 + async for db in get_db(): + self.config_service = PushConfigService(db) + self.config_service.register_change_callback(self.on_config_change) + self.gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.GAS) + self.algo_result_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT) + self.alarm_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALARM) + self._init_done = True - def on_config_change(self, config: PushConfig): + async def on_config_change(self, config: PushConfig): if config.push_type == PUSH_TYPE.GAS: - self.set_gas_push_config(config) + await self.set_gas_push_config(config) elif config.push_type == PUSH_TYPE.ALGO_RESULT: - self.set_algo_result_push_config(config) + await self.set_algo_result_push_config(config) elif config.push_type == PUSH_TYPE.ALARM: - self.set_alarm_push_config(config) + await self.set_alarm_push_config(config) def get_gas_push_config(self): """获取 gas_push_config 配置""" return self.gas_push_config - def set_gas_push_config(self, config): + async def set_gas_push_config(self, config): """设置 gas_push_config 配置""" if config: - self.gas_push_config = config + async with self._lock: + self.gas_push_config = config def get_algo_result_push_config(self): """获取 algo_result_push_config 配置""" return self.algo_result_push_config - def set_algo_result_push_config(self, config): + async def set_algo_result_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.algo_result_push_config = config + async with self._lock: + self.algo_result_push_config = config def get_alarm_push_config(self): """获取 algo_result_push_config 配置""" return self.alarm_push_config - def set_alarm_push_config(self, config): + async def set_alarm_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.alarm_push_config = config + async with self._lock: + self.alarm_push_config = config diff --git a/services/model_service.py b/services/model_service.py index f9531a5..afc61d9 100644 --- a/services/model_service.py +++ b/services/model_service.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid import zipfile @@ -5,10 +6,11 @@ from pathlib import Path from typing import List, Sequence, Optional, Tuple, Type +import aiofiles from fastapi import UploadFile from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.biz_exception import BizException from common.string_utils import snake_to_camel from entity.device_model_relation import DeviceModelRelation @@ -18,7 +20,7 @@ class ModelService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__model_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -32,15 +34,15 @@ for callback in self.__model_change_callbacks: self.thread_pool.executor.submit(callback, algo_model_id, change_type) - def get_model_list(self, - name: Optional[str] = None, - remark: Optional[str] = None, - ) -> Sequence[AlgoModel]: + async def get_model_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[AlgoModel]: statement = self.model_query(name, remark) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_page(self, + async def get_model_page(self, name: Optional[str] = None, remark: Optional[str] = None, offset: int = 0, @@ -50,19 +52,21 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - model_list = self.db.exec(statement) + model_list = await self.db.execute(statement) + rows = model_list.scalars().all() model_info_list: List[AlgoModelInfo] = [] - if model_list: - for model in model_list: + if rows: + for model in rows: model_info_list.append(AlgoModelInfo( **model.dict(), - usage_status="使用中" if self.get_model_usage(model.id) else "未使用" + usage_status="使用中" if await self.get_model_usage(model.id) else "未使用" )) return model_info_list, total # 返回分页数据和总数 @@ -76,7 +80,7 @@ statement = statement.where(AlgoModel.remark.like(f"%{remark}%")) return statement - def process_zip(self, file: UploadFile): + async def process_zip(self, file: UploadFile): model_dir = Path('weights/') model_handle_dir = Path('model_handler/') model_dir.mkdir(parents=True, exist_ok=True) @@ -86,13 +90,16 @@ # 临时保存上传文件 temp_path = Path(f"temp_upload_{uuid.uuid4()}.zip") - with open(temp_path, "wb") as temp_file: - temp_file.write(file.file.read()) + async with aiofiles.open(temp_path, "wb") as temp_file: + content = await file.read() # 异步读取上传文件内容 + await temp_file.write(content) model_file_path = None handle_file_path = None try: + # 使用异步方法读取压缩文件的内容 + loop = asyncio.get_event_loop() with zipfile.ZipFile(temp_path, 'r') as zip_ref: # 获取压缩包文件列表 file_list = zip_ref.namelist() @@ -105,47 +112,47 @@ message=f"Model weight file ({', '.join(SUPPORTED_MODEL_EXTENSIONS)}) is required in the zip." ) - # 解压模型文件到模型目录 - zip_ref.extract(model_file, model_dir) + # 异步解压模型文件到模型目录 model_file_path = model_dir / model_file + await loop.run_in_executor(None, zip_ref.extract, model_file, model_dir) # 检查是否有可选的 Python 脚本 handle_file = next((f for f in file_list if f.endswith(".py")), None) if handle_file: - zip_ref.extract(handle_file, model_handle_dir) handle_file_path = model_handle_dir / handle_file - + await loop.run_in_executor(None, zip_ref.extract, handle_file, model_handle_dir) except zipfile.BadZipFile: raise BizException(status_code=400, message="Invalid zip file.") finally: # 删除临时文件 - temp_path.unlink() + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, temp_path.unlink) return str(model_file_path), str(handle_file_path) if handle_file_path else None - def create_model(self, model_data: AlgoModelCreate, file: UploadFile): - self.process_model_file(file, model_data) + async def create_model(self, model_data: AlgoModelCreate, file: UploadFile): + await self.process_model_file(file, model_data) model = AlgoModel.model_validate(model_data) model.create_time = datetime.now() model.update_time = datetime.now() self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) return model - def process_model_file(self, file, model): - model_file_path, handle_file_path = self.process_zip(file) + async def process_model_file(self, file, model): + model_file_path, handle_file_path = await self.process_zip(file) model.path = model_file_path if handle_file_path: model.handle_task = snake_to_camel(os.path.splitext(os.path.basename(handle_file_path))[0]) else: model.handle_task = 'BaseModelHandler' - def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): - model = self.db.get(AlgoModel, model_data.id) + async def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): + model = await self.get_model_by_id(model_data.id) if not model: return None @@ -155,16 +162,16 @@ model.update_time = datetime.now() if file: - self.process_model_file(file, model) + await self.process_model_file(file, model) self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE) return model - def delete_model(self, model_id: int): - model = self.db.get(AlgoModel, model_id) + async def delete_model(self, model_id: int): + model = await self.get_model_by_id(model_id) if not model: return None # 查询 device_model_relation 中是否存在启用的绑定关系 @@ -173,17 +180,20 @@ .where(DeviceModelRelation.algo_model_id == model_id) .where(DeviceModelRelation.is_use == 1) ) - relation_in_use = self.db.exec(statement).first() + relation_in_use_exec = await self.db.execute(statement) + relation_in_use = relation_in_use_exec.scalars().first() # 如果存在启用的绑定关系,提示无法删除 if relation_in_use: raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除") - self.db.delete(model) - self.db.commit() + statement = delete(AlgoModel).where(AlgoModel.id == model_id) + await self.db.execute(statement) + await self.db.commit() + return model - def get_models_in_use(self) -> Sequence[AlgoModel]: + async def get_models_in_use(self) -> Sequence[AlgoModel]: """获取所有在 device_model_relation 表里有启用绑定关系的模型信息""" statement = ( select(AlgoModel) @@ -191,10 +201,10 @@ .where(DeviceModelRelation.is_use == 1) .group_by(AlgoModel.id) ) - results = self.db.exec(statement).all() - return results + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_usage(self, algo_model_id) -> bool: + async def get_model_usage(self, algo_model_id) -> bool: statement = ( select(DeviceModelRelation) .where( @@ -202,8 +212,11 @@ DeviceModelRelation.algo_model_id == algo_model_id, ) ) - result = self.db.exec(statement).all() - return len(result) > 0 + result = await self.db.execute(statement) + rows = result.all() + return len(rows) > 0 - def get_model_by_id(self, model_id): - return self.db.get(AlgoModel, model_id) + async def get_model_by_id(self, model_id): + result = await self.db.execute(select(AlgoModel).where(AlgoModel.id == model_id)) + model = result.scalar_one_or_none() + return model diff --git a/services/push_config_service.py b/services/push_config_service.py index 8a2c298..14862df 100644 --- a/services/push_config_service.py +++ b/services/push_config_service.py @@ -1,6 +1,8 @@ +import asyncio from datetime import datetime -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.push_config import PushConfigCreate, PushConfig @@ -13,7 +15,7 @@ cls._instance = super().__new__(cls) return cls._instance - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): if not hasattr(self, 'initialized'): self.db = db self.__push_change_callbacks = [] # 用于存储回调函数 @@ -23,12 +25,22 @@ """注册设备变化回调函数""" self.__push_change_callbacks.append(callback) - def notify_change(self, push_config): - for callback in self.__push_change_callbacks: - callback(push_config) + # def notify_change(self, push_config): + # for callback in self.__push_change_callbacks: + # callback(push_config) - def set_push_config(self, push_config_create: PushConfigCreate): - push_config = self.get_push_config(push_config_create.push_type) + def notify_change(self, config): + """通知所有回调函数""" + for callback in self.__push_change_callbacks: + if asyncio.iscoroutinefunction(callback): + # 如果是异步函数,使用 asyncio.create_task() 调度 + asyncio.create_task(callback(config)) + else: + # 如果是同步函数,直接调用 + callback(config) + + async def set_push_config(self, push_config_create: PushConfigCreate): + push_config = await self.get_push_config(push_config_create.push_type) if push_config: update_data = push_config_create.dict(exclude_unset=True) for key, value in update_data.items(): @@ -40,18 +52,18 @@ push_config.update_time = datetime.now() self.db.add(push_config) - self.db.commit() - self.db.refresh(push_config) + await self.db.commit() + await self.db.refresh(push_config) self.notify_change(push_config) return push_config - def get_push_config(self, push_type): + async def get_push_config(self, push_type): statement = select(PushConfig).where(PushConfig.push_type == push_type) - results = self.db.exec(statement) - return results.first() + results = await self.db.execute(statement) + return results.scalars().first() - def get_push_config_list(self): + async def get_push_config_list(self): statement = select(PushConfig) - results = self.db.exec(statement) - return results.all() \ No newline at end of file + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/services/scene_service.py b/services/scene_service.py index 17c8f52..dc4c67b 100644 --- a/services/scene_service.py +++ b/services/scene_service.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid import zipfile @@ -5,10 +6,11 @@ from pathlib import Path from typing import Optional, Sequence, Tuple, List +import aiofiles from fastapi import UploadFile from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.biz_exception import BizException from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -18,7 +20,7 @@ class SceneService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__scene_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -32,37 +34,39 @@ for callback in self.__scene_change_callbacks: self.thread_pool.executor.submit(callback, scene_id, change_type) - def get_scene_list(self, - name: Optional[str] = None, - remark: Optional[str] = None, - ) -> Sequence[Scene]: + async def get_scene_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[Scene]: statement = self.scene_query(name, remark) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_scene_page(self, - name: Optional[str] = None, - remark: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[SceneInfo], int]: + async def get_scene_page(self, + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[SceneInfo], int]: statement = self.scene_query(name, remark) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - scene_list = self.db.exec(statement) + scene_list = await self.db.execute(statement) + rows = scene_list.scalars().all() scene_info_list: List[SceneInfo] = [] if scene_list: - for scene in scene_list: + for scene in rows: scene_info_list.append(SceneInfo( **scene.dict(), - usage_status="使用中" if self.get_scene_usage(scene.id) else "未使用" + usage_status="使用中" if await self.get_scene_usage(scene.id) else "未使用" )) return scene_info_list, total # 返回分页数据和总数 @@ -76,7 +80,7 @@ statement = statement.where(Scene.remark.like(f"%{remark}%")) return statement - def process_zip(self, file: UploadFile): + async def process_zip(self, file: UploadFile): model_dir = Path('weights/') scene_handle_dir = Path('scene_handler/') model_dir.mkdir(parents=True, exist_ok=True) @@ -86,61 +90,65 @@ # 临时保存上传文件 temp_path = Path(f"temp_upload_{uuid.uuid4()}.zip") - with open(temp_path, "wb") as temp_file: - temp_file.write(file.file.read()) + async with aiofiles.open(temp_path, "wb") as temp_file: + content = await file.read() # 异步读取上传文件内容 + await temp_file.write(content) model_file_paths = [] handle_file_path = None try: + # 使用异步方法读取压缩文件的内容 + loop = asyncio.get_event_loop() with zipfile.ZipFile(temp_path, 'r') as zip_ref: # 获取压缩包文件列表 file_list = zip_ref.namelist() model_files = [f for f in file_list if Path(f).suffix in SUPPORTED_MODEL_EXTENSIONS] - # 解压所有模型文件到模型目录 + # 异步解压所有模型文件到模型目录 for model_file in model_files: - zip_ref.extract(model_file, model_dir) + await loop.run_in_executor(None, zip_ref.extract, model_file, model_dir) model_file_paths.append(model_dir / model_file) # 检查是否有可选的 Python 脚本 handle_file = next((f for f in file_list if f.endswith(".py")), None) if handle_file: - zip_ref.extract(handle_file, scene_handle_dir) + await loop.run_in_executor(None, zip_ref.extract, handle_file, scene_handle_dir) handle_file_path = scene_handle_dir / handle_file else: raise BizException( status_code=400, - message=f"handle file (.py) is required in the zip." + message=f"Handle file (.py) is required in the zip." ) except zipfile.BadZipFile: raise BizException(status_code=400, message="Invalid zip file.") finally: # 删除临时文件 - temp_path.unlink() + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, temp_path.unlink) return [str(path) for path in model_file_paths], str(handle_file_path) - def process_scene_file(self, file, scene): - model_file_paths, handle_file_path = self.process_zip(file) + async def process_scene_file(self, file, scene): + model_file_paths, handle_file_path = await self.process_zip(file) scene.handle_task = snake_to_camel(os.path.splitext(os.path.basename(handle_file_path))[0]) - def create_scene(self, scene_data: SceneCreate, file: UploadFile): - self.process_scene_file(file, scene_data) + async def create_scene(self, scene_data: SceneCreate, file: UploadFile): + await self.process_scene_file(file, scene_data) scene = Scene.model_validate(scene_data) scene.create_time = datetime.now() scene.update_time = datetime.now() self.db.add(scene) - self.db.commit() - self.db.refresh(scene) + await self.db.commit() + await self.db.refresh(scene) return scene - def update_scene(self, scene_data: SceneUpdate, file: UploadFile): - scene = self.db.get(Scene, scene_data.id) + async def update_scene(self, scene_data: SceneUpdate, file: UploadFile): + scene = await self.get_scene_by_id(scene_data.id) if not scene: return None @@ -150,16 +158,16 @@ scene.update_time = datetime.now() if file: - self.process_scene_file(file, scene) + await self.process_scene_file(file, scene) self.db.add(scene) - self.db.commit() - self.db.refresh(scene) + await self.db.commit() + await self.db.refresh(scene) self.notify_change(scene.id, NotifyChangeType.SCENE_UPDATE) return scene - def delete_scene(self, scene_id: int): - scene = self.db.get(Scene, scene_id) + async def delete_scene(self, scene_id: int): + scene = await self.get_scene_by_id(scene_id) if not scene: return None # 查询 device_scene_relation 中是否存在启用的绑定关系 @@ -167,35 +175,40 @@ select(DeviceSceneRelation) .where(DeviceSceneRelation.scene_id == scene_id) ) - relation_in_use = self.db.exec(statement).first() + relation_in_use_exec = self.db.execute(statement) + relation_in_use = relation_in_use_exec.scalars().first() # 如果存在启用的绑定关系,提示无法删除 if relation_in_use: raise BizException(message=f"场景 {scene.name} 正在被设备使用,无法删除") - self.db.delete(scene) - self.db.commit() + statement = delete(Scene).where(Scene.id == scene_id) + await self.db.execute(statement) + await self.db.commit() return scene - def get_scenes_in_use(self) -> Sequence[Scene]: + async def get_scenes_in_use(self) -> Sequence[Scene]: """获取所有在 device_scene_relation 表里有绑定关系的模型信息""" statement = ( select(Scene) .join(DeviceSceneRelation, DeviceSceneRelation.scene_id == Scene.id) .group_by(Scene.id) ) - results = self.db.exec(statement).all() - return results + results = await self.db.execute(statement) + return results.scalars().all() - def get_scene_usage(self, scene_id) -> bool: + async def get_scene_usage(self, scene_id) -> bool: statement = ( select(DeviceSceneRelation) .where( DeviceSceneRelation.scene_id == scene_id, ) ) - result = self.db.exec(statement).all() - return len(result) > 0 + result = await self.db.exec(statement) + rows = result.all() + return len(rows) > 0 - def get_scene_by_id(self, scene_id): - return self.db.get(Scene, scene_id) \ No newline at end of file + async def get_scene_by_id(self, scene_id): + result = await self.db.execute(select(Scene).where(Scene.id == scene_id)) + scene = result.scalar_one_or_none() + return scene diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 592ee98..3416831 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -10,7 +10,7 @@ class DeviceSceneRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: + async def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: statement = ( select(DeviceSceneRelation, Scene) .join(Scene, DeviceSceneRelation.scene_id == Scene.id) @@ -32,11 +32,12 @@ ) # 执行联表查询 - result = self.db.exec(statement).first() + result = await self.db.execute(statement) + result_row = result.first() scene_info = None - if result: - relation, scene = result[0], result[1] + if result_row: + relation, scene = result_row scene_info = DeviceSceneRelationInfo( id=relation.id, device_id=relation.device_id, @@ -48,23 +49,23 @@ ) return scene_info - def add_relation_by_device(self, device_id: int, scene_id: int): + async def add_relation_by_device(self, device_id: int, scene_id: int): new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) - self.db.commit() - self.db.refresh(new_relation) + await self.db.commit() + await self.db.refresh(new_relation) return new_relation - def delete_relation_by_device(self, device_id: int): + async def delete_relation_by_device(self, device_id: int): statement = delete(DeviceSceneRelation).where(DeviceSceneRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relation_by_device(self, device_id: int, scene_id: int): - self.delete_relation_by_device(device_id) - new_relation = self.add_relation_by_device(device_id, scene_id) + async def update_relation_by_device(self, device_id: int, scene_id: int): + await self.delete_relation_by_device(device_id) + new_relation = await self.add_relation_by_device(device_id, scene_id) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/device_service.py b/services/device_service.py index cad46d3..2488fd5 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -4,8 +4,8 @@ from typing import Sequence, Optional, Tuple from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.device_status_manager import DeviceStatusManager from common.global_thread_pool import GlobalThreadPool from common.consts import NotifyChangeType, DEVICE_MODE @@ -15,25 +15,25 @@ class DeviceService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() # 创建一个独立的事件循环并启动线程 - self.loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) - self.loop_thread.start() + # self.loop = asyncio.new_event_loop() + # self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) + # self.loop_thread.start() - def _start_loop(self): - """后台线程运行事件循环""" - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def shutdown(self): - """清理事件循环和线程""" - self.loop.call_soon_threadsafe(self.loop.stop) - self.loop_thread.join() + # def _start_loop(self): + # """后台线程运行事件循环""" + # asyncio.set_event_loop(self.loop) + # self.loop.run_forever() + # + # def shutdown(self): + # """清理事件循环和线程""" + # self.loop.call_soon_threadsafe(self.loop.stop) + # self.loop_thread.join() def register_change_callback(self, callback): """注册设备变化回调函数""" @@ -42,52 +42,53 @@ def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" # loop = asyncio.get_event_loop() # 获取当前的事件循环 - for callback in self.__device_change_callbacks: - if asyncio.iscoroutinefunction(callback): - # 如果是协程函数,使用事件循环运行它 - # loop = asyncio.new_event_loop() - # asyncio.set_event_loop(loop) - asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) - else: - # 如果是普通函数,直接提交到线程池 - self.thread_pool.executor.submit(callback, device_id, change_type) + # for callback in self.__device_change_callbacks: + # if asyncio.iscoroutinefunction(callback): + # # 如果是协程函数,使用事件循环运行它 + # # loop = asyncio.new_event_loop() + # # asyncio.set_event_loop(loop) + # asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) + # else: + # # 如果是普通函数,直接提交到线程池 + # self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_list(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - ) -> Sequence[Device]: + async def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: statement = self.device_query(code, device_type, name) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_device_page(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceInfo], int]: + async def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceInfo], int]: statement = self.device_query(code, device_type, name) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - device_list = results.all() + results = await self.db.execute(statement) + device_list = results.scalars().all() device_info_list = [] if device_list: device_model_relation_service = DeviceModelRelationService(self.db) device_scene_relation_service = DeviceSceneRelationService(self.db) device_status_manager = DeviceStatusManager() for device in device_list: - model_relations = device_model_relation_service.get_device_models(device.id) - scene_relation = device_scene_relation_service.get_device_scene(device.id) + model_relations = await device_model_relation_service.get_device_models(device.id) + scene_relation = await device_scene_relation_service.get_device_scene(device.id) device_info_list.append(DeviceInfo( id=device.id, @@ -121,20 +122,20 @@ statement = statement.where(Device.type == device_type) return statement - def create_device(self, device_data: DeviceCreate): + async def create_device(self, device_data: DeviceCreate): device = Device.model_validate(device_data) device.create_time = datetime.now() device.update_time = datetime.now() - self.handle_device_mode(device) + await self.handle_device_mode(device) self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) return device - def update_device(self, device_data: DeviceUpdate): - device_old = self.db.get(Device, device_data.id) - device = self.db.get(Device, device_data.id) + async def update_device(self, device_data: DeviceUpdate): + device_old = await self.get_device(device_data.id) + device = await self.get_device(device_data.id) if not device: return None @@ -142,41 +143,45 @@ for key, value in update_data.items(): setattr(device, key, value) - self.handle_device_mode(device) + await self.handle_device_mode(device) device.update_time = datetime.now() self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) return device - def delete_device(self, device_id: int): - device = self.db.get(Device, device_id) + async def delete_device(self, device_id: int): + device = await self.get_device(device_id) if not device: return None - self.db.delete(device) - self.db.commit() + statement = delete(Device).where(Device.id == device_id) + await self.db.execute(statement) + await self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device_id) + await model_relation_service.delete_relations_by_device(device_id) scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) return device - def handle_device_mode(self, device): + async def handle_device_mode(self, device): if device.mode == DEVICE_MODE.ALGO: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) elif device.mode == DEVICE_MODE.SCENE: model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) else: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) - def get_device(self, device_id: int): - return self.db.get(Device, device_id) + async def get_device(self, device_id: int): + result = await self.db.execute(select(Device).where(Device.id == device_id)) + frame = result.scalar_one_or_none() + return frame diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index dbd272d..4e6c145 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,23 +1,25 @@ from typing import List -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult class FrameAnalysisResultService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): + async def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): new_results = [FrameAnalysisResult.model_validate(result) for result in results] - self.db.add_all(new_results) - self.db.commit() for result in new_results: - self.db.refresh(result) + self.db.add(result) + await self.db.commit() + for result in new_results: + await self.db.refresh(result) return new_results - def get_results_by_frame(self, frame_id): + async def get_results_by_frame(self, frame_id): statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/services/global_config.py b/services/global_config.py index 7dae73b..fa53b2b 100644 --- a/services/global_config.py +++ b/services/global_config.py @@ -1,3 +1,5 @@ +import asyncio + from common.consts import PUSH_TYPE from db.database import get_db from entity.push_config import PushConfig @@ -6,10 +8,12 @@ class GlobalConfig: _instance = None + _lock = asyncio.Lock() def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) + return cls._instance def __init__(self): @@ -19,49 +23,58 @@ self.gas_push_config = None self.algo_result_push_config = None self.alarm_push_config = None - self.init_config() # 进行初始化 + self._init_done = False - def init_config(self): - # 初始化配置逻辑 - with next(get_db()) as db: - self.config_service = PushConfigService(db) - self.set_gas_push_config(self.config_service.get_push_config(PUSH_TYPE.GAS)) - self.set_algo_result_push_config(self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT)) - self.set_alarm_push_config(self.config_service.get_push_config(PUSH_TYPE.ALARM)) + async def _initialize(self): + await self.init_config() # 调用异步初始化 - self.config_service.register_change_callback(self.on_config_change) + async def init_config(self): + # 确保只初始化一次 + if not self._init_done: + async with self._lock: + if not self._init_done: # 双重检查锁,避免多次初始化 + async for db in get_db(): + self.config_service = PushConfigService(db) + self.config_service.register_change_callback(self.on_config_change) + self.gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.GAS) + self.algo_result_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT) + self.alarm_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALARM) + self._init_done = True - def on_config_change(self, config: PushConfig): + async def on_config_change(self, config: PushConfig): if config.push_type == PUSH_TYPE.GAS: - self.set_gas_push_config(config) + await self.set_gas_push_config(config) elif config.push_type == PUSH_TYPE.ALGO_RESULT: - self.set_algo_result_push_config(config) + await self.set_algo_result_push_config(config) elif config.push_type == PUSH_TYPE.ALARM: - self.set_alarm_push_config(config) + await self.set_alarm_push_config(config) def get_gas_push_config(self): """获取 gas_push_config 配置""" return self.gas_push_config - def set_gas_push_config(self, config): + async def set_gas_push_config(self, config): """设置 gas_push_config 配置""" if config: - self.gas_push_config = config + async with self._lock: + self.gas_push_config = config def get_algo_result_push_config(self): """获取 algo_result_push_config 配置""" return self.algo_result_push_config - def set_algo_result_push_config(self, config): + async def set_algo_result_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.algo_result_push_config = config + async with self._lock: + self.algo_result_push_config = config def get_alarm_push_config(self): """获取 algo_result_push_config 配置""" return self.alarm_push_config - def set_alarm_push_config(self, config): + async def set_alarm_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.alarm_push_config = config + async with self._lock: + self.alarm_push_config = config diff --git a/services/model_service.py b/services/model_service.py index f9531a5..afc61d9 100644 --- a/services/model_service.py +++ b/services/model_service.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid import zipfile @@ -5,10 +6,11 @@ from pathlib import Path from typing import List, Sequence, Optional, Tuple, Type +import aiofiles from fastapi import UploadFile from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.biz_exception import BizException from common.string_utils import snake_to_camel from entity.device_model_relation import DeviceModelRelation @@ -18,7 +20,7 @@ class ModelService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__model_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -32,15 +34,15 @@ for callback in self.__model_change_callbacks: self.thread_pool.executor.submit(callback, algo_model_id, change_type) - def get_model_list(self, - name: Optional[str] = None, - remark: Optional[str] = None, - ) -> Sequence[AlgoModel]: + async def get_model_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[AlgoModel]: statement = self.model_query(name, remark) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_page(self, + async def get_model_page(self, name: Optional[str] = None, remark: Optional[str] = None, offset: int = 0, @@ -50,19 +52,21 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - model_list = self.db.exec(statement) + model_list = await self.db.execute(statement) + rows = model_list.scalars().all() model_info_list: List[AlgoModelInfo] = [] - if model_list: - for model in model_list: + if rows: + for model in rows: model_info_list.append(AlgoModelInfo( **model.dict(), - usage_status="使用中" if self.get_model_usage(model.id) else "未使用" + usage_status="使用中" if await self.get_model_usage(model.id) else "未使用" )) return model_info_list, total # 返回分页数据和总数 @@ -76,7 +80,7 @@ statement = statement.where(AlgoModel.remark.like(f"%{remark}%")) return statement - def process_zip(self, file: UploadFile): + async def process_zip(self, file: UploadFile): model_dir = Path('weights/') model_handle_dir = Path('model_handler/') model_dir.mkdir(parents=True, exist_ok=True) @@ -86,13 +90,16 @@ # 临时保存上传文件 temp_path = Path(f"temp_upload_{uuid.uuid4()}.zip") - with open(temp_path, "wb") as temp_file: - temp_file.write(file.file.read()) + async with aiofiles.open(temp_path, "wb") as temp_file: + content = await file.read() # 异步读取上传文件内容 + await temp_file.write(content) model_file_path = None handle_file_path = None try: + # 使用异步方法读取压缩文件的内容 + loop = asyncio.get_event_loop() with zipfile.ZipFile(temp_path, 'r') as zip_ref: # 获取压缩包文件列表 file_list = zip_ref.namelist() @@ -105,47 +112,47 @@ message=f"Model weight file ({', '.join(SUPPORTED_MODEL_EXTENSIONS)}) is required in the zip." ) - # 解压模型文件到模型目录 - zip_ref.extract(model_file, model_dir) + # 异步解压模型文件到模型目录 model_file_path = model_dir / model_file + await loop.run_in_executor(None, zip_ref.extract, model_file, model_dir) # 检查是否有可选的 Python 脚本 handle_file = next((f for f in file_list if f.endswith(".py")), None) if handle_file: - zip_ref.extract(handle_file, model_handle_dir) handle_file_path = model_handle_dir / handle_file - + await loop.run_in_executor(None, zip_ref.extract, handle_file, model_handle_dir) except zipfile.BadZipFile: raise BizException(status_code=400, message="Invalid zip file.") finally: # 删除临时文件 - temp_path.unlink() + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, temp_path.unlink) return str(model_file_path), str(handle_file_path) if handle_file_path else None - def create_model(self, model_data: AlgoModelCreate, file: UploadFile): - self.process_model_file(file, model_data) + async def create_model(self, model_data: AlgoModelCreate, file: UploadFile): + await self.process_model_file(file, model_data) model = AlgoModel.model_validate(model_data) model.create_time = datetime.now() model.update_time = datetime.now() self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) return model - def process_model_file(self, file, model): - model_file_path, handle_file_path = self.process_zip(file) + async def process_model_file(self, file, model): + model_file_path, handle_file_path = await self.process_zip(file) model.path = model_file_path if handle_file_path: model.handle_task = snake_to_camel(os.path.splitext(os.path.basename(handle_file_path))[0]) else: model.handle_task = 'BaseModelHandler' - def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): - model = self.db.get(AlgoModel, model_data.id) + async def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): + model = await self.get_model_by_id(model_data.id) if not model: return None @@ -155,16 +162,16 @@ model.update_time = datetime.now() if file: - self.process_model_file(file, model) + await self.process_model_file(file, model) self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE) return model - def delete_model(self, model_id: int): - model = self.db.get(AlgoModel, model_id) + async def delete_model(self, model_id: int): + model = await self.get_model_by_id(model_id) if not model: return None # 查询 device_model_relation 中是否存在启用的绑定关系 @@ -173,17 +180,20 @@ .where(DeviceModelRelation.algo_model_id == model_id) .where(DeviceModelRelation.is_use == 1) ) - relation_in_use = self.db.exec(statement).first() + relation_in_use_exec = await self.db.execute(statement) + relation_in_use = relation_in_use_exec.scalars().first() # 如果存在启用的绑定关系,提示无法删除 if relation_in_use: raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除") - self.db.delete(model) - self.db.commit() + statement = delete(AlgoModel).where(AlgoModel.id == model_id) + await self.db.execute(statement) + await self.db.commit() + return model - def get_models_in_use(self) -> Sequence[AlgoModel]: + async def get_models_in_use(self) -> Sequence[AlgoModel]: """获取所有在 device_model_relation 表里有启用绑定关系的模型信息""" statement = ( select(AlgoModel) @@ -191,10 +201,10 @@ .where(DeviceModelRelation.is_use == 1) .group_by(AlgoModel.id) ) - results = self.db.exec(statement).all() - return results + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_usage(self, algo_model_id) -> bool: + async def get_model_usage(self, algo_model_id) -> bool: statement = ( select(DeviceModelRelation) .where( @@ -202,8 +212,11 @@ DeviceModelRelation.algo_model_id == algo_model_id, ) ) - result = self.db.exec(statement).all() - return len(result) > 0 + result = await self.db.execute(statement) + rows = result.all() + return len(rows) > 0 - def get_model_by_id(self, model_id): - return self.db.get(AlgoModel, model_id) + async def get_model_by_id(self, model_id): + result = await self.db.execute(select(AlgoModel).where(AlgoModel.id == model_id)) + model = result.scalar_one_or_none() + return model diff --git a/services/push_config_service.py b/services/push_config_service.py index 8a2c298..14862df 100644 --- a/services/push_config_service.py +++ b/services/push_config_service.py @@ -1,6 +1,8 @@ +import asyncio from datetime import datetime -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.push_config import PushConfigCreate, PushConfig @@ -13,7 +15,7 @@ cls._instance = super().__new__(cls) return cls._instance - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): if not hasattr(self, 'initialized'): self.db = db self.__push_change_callbacks = [] # 用于存储回调函数 @@ -23,12 +25,22 @@ """注册设备变化回调函数""" self.__push_change_callbacks.append(callback) - def notify_change(self, push_config): - for callback in self.__push_change_callbacks: - callback(push_config) + # def notify_change(self, push_config): + # for callback in self.__push_change_callbacks: + # callback(push_config) - def set_push_config(self, push_config_create: PushConfigCreate): - push_config = self.get_push_config(push_config_create.push_type) + def notify_change(self, config): + """通知所有回调函数""" + for callback in self.__push_change_callbacks: + if asyncio.iscoroutinefunction(callback): + # 如果是异步函数,使用 asyncio.create_task() 调度 + asyncio.create_task(callback(config)) + else: + # 如果是同步函数,直接调用 + callback(config) + + async def set_push_config(self, push_config_create: PushConfigCreate): + push_config = await self.get_push_config(push_config_create.push_type) if push_config: update_data = push_config_create.dict(exclude_unset=True) for key, value in update_data.items(): @@ -40,18 +52,18 @@ push_config.update_time = datetime.now() self.db.add(push_config) - self.db.commit() - self.db.refresh(push_config) + await self.db.commit() + await self.db.refresh(push_config) self.notify_change(push_config) return push_config - def get_push_config(self, push_type): + async def get_push_config(self, push_type): statement = select(PushConfig).where(PushConfig.push_type == push_type) - results = self.db.exec(statement) - return results.first() + results = await self.db.execute(statement) + return results.scalars().first() - def get_push_config_list(self): + async def get_push_config_list(self): statement = select(PushConfig) - results = self.db.exec(statement) - return results.all() \ No newline at end of file + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/services/scene_service.py b/services/scene_service.py index 17c8f52..dc4c67b 100644 --- a/services/scene_service.py +++ b/services/scene_service.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid import zipfile @@ -5,10 +6,11 @@ from pathlib import Path from typing import Optional, Sequence, Tuple, List +import aiofiles from fastapi import UploadFile from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.biz_exception import BizException from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -18,7 +20,7 @@ class SceneService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__scene_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -32,37 +34,39 @@ for callback in self.__scene_change_callbacks: self.thread_pool.executor.submit(callback, scene_id, change_type) - def get_scene_list(self, - name: Optional[str] = None, - remark: Optional[str] = None, - ) -> Sequence[Scene]: + async def get_scene_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[Scene]: statement = self.scene_query(name, remark) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_scene_page(self, - name: Optional[str] = None, - remark: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[SceneInfo], int]: + async def get_scene_page(self, + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[SceneInfo], int]: statement = self.scene_query(name, remark) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - scene_list = self.db.exec(statement) + scene_list = await self.db.execute(statement) + rows = scene_list.scalars().all() scene_info_list: List[SceneInfo] = [] if scene_list: - for scene in scene_list: + for scene in rows: scene_info_list.append(SceneInfo( **scene.dict(), - usage_status="使用中" if self.get_scene_usage(scene.id) else "未使用" + usage_status="使用中" if await self.get_scene_usage(scene.id) else "未使用" )) return scene_info_list, total # 返回分页数据和总数 @@ -76,7 +80,7 @@ statement = statement.where(Scene.remark.like(f"%{remark}%")) return statement - def process_zip(self, file: UploadFile): + async def process_zip(self, file: UploadFile): model_dir = Path('weights/') scene_handle_dir = Path('scene_handler/') model_dir.mkdir(parents=True, exist_ok=True) @@ -86,61 +90,65 @@ # 临时保存上传文件 temp_path = Path(f"temp_upload_{uuid.uuid4()}.zip") - with open(temp_path, "wb") as temp_file: - temp_file.write(file.file.read()) + async with aiofiles.open(temp_path, "wb") as temp_file: + content = await file.read() # 异步读取上传文件内容 + await temp_file.write(content) model_file_paths = [] handle_file_path = None try: + # 使用异步方法读取压缩文件的内容 + loop = asyncio.get_event_loop() with zipfile.ZipFile(temp_path, 'r') as zip_ref: # 获取压缩包文件列表 file_list = zip_ref.namelist() model_files = [f for f in file_list if Path(f).suffix in SUPPORTED_MODEL_EXTENSIONS] - # 解压所有模型文件到模型目录 + # 异步解压所有模型文件到模型目录 for model_file in model_files: - zip_ref.extract(model_file, model_dir) + await loop.run_in_executor(None, zip_ref.extract, model_file, model_dir) model_file_paths.append(model_dir / model_file) # 检查是否有可选的 Python 脚本 handle_file = next((f for f in file_list if f.endswith(".py")), None) if handle_file: - zip_ref.extract(handle_file, scene_handle_dir) + await loop.run_in_executor(None, zip_ref.extract, handle_file, scene_handle_dir) handle_file_path = scene_handle_dir / handle_file else: raise BizException( status_code=400, - message=f"handle file (.py) is required in the zip." + message=f"Handle file (.py) is required in the zip." ) except zipfile.BadZipFile: raise BizException(status_code=400, message="Invalid zip file.") finally: # 删除临时文件 - temp_path.unlink() + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, temp_path.unlink) return [str(path) for path in model_file_paths], str(handle_file_path) - def process_scene_file(self, file, scene): - model_file_paths, handle_file_path = self.process_zip(file) + async def process_scene_file(self, file, scene): + model_file_paths, handle_file_path = await self.process_zip(file) scene.handle_task = snake_to_camel(os.path.splitext(os.path.basename(handle_file_path))[0]) - def create_scene(self, scene_data: SceneCreate, file: UploadFile): - self.process_scene_file(file, scene_data) + async def create_scene(self, scene_data: SceneCreate, file: UploadFile): + await self.process_scene_file(file, scene_data) scene = Scene.model_validate(scene_data) scene.create_time = datetime.now() scene.update_time = datetime.now() self.db.add(scene) - self.db.commit() - self.db.refresh(scene) + await self.db.commit() + await self.db.refresh(scene) return scene - def update_scene(self, scene_data: SceneUpdate, file: UploadFile): - scene = self.db.get(Scene, scene_data.id) + async def update_scene(self, scene_data: SceneUpdate, file: UploadFile): + scene = await self.get_scene_by_id(scene_data.id) if not scene: return None @@ -150,16 +158,16 @@ scene.update_time = datetime.now() if file: - self.process_scene_file(file, scene) + await self.process_scene_file(file, scene) self.db.add(scene) - self.db.commit() - self.db.refresh(scene) + await self.db.commit() + await self.db.refresh(scene) self.notify_change(scene.id, NotifyChangeType.SCENE_UPDATE) return scene - def delete_scene(self, scene_id: int): - scene = self.db.get(Scene, scene_id) + async def delete_scene(self, scene_id: int): + scene = await self.get_scene_by_id(scene_id) if not scene: return None # 查询 device_scene_relation 中是否存在启用的绑定关系 @@ -167,35 +175,40 @@ select(DeviceSceneRelation) .where(DeviceSceneRelation.scene_id == scene_id) ) - relation_in_use = self.db.exec(statement).first() + relation_in_use_exec = self.db.execute(statement) + relation_in_use = relation_in_use_exec.scalars().first() # 如果存在启用的绑定关系,提示无法删除 if relation_in_use: raise BizException(message=f"场景 {scene.name} 正在被设备使用,无法删除") - self.db.delete(scene) - self.db.commit() + statement = delete(Scene).where(Scene.id == scene_id) + await self.db.execute(statement) + await self.db.commit() return scene - def get_scenes_in_use(self) -> Sequence[Scene]: + async def get_scenes_in_use(self) -> Sequence[Scene]: """获取所有在 device_scene_relation 表里有绑定关系的模型信息""" statement = ( select(Scene) .join(DeviceSceneRelation, DeviceSceneRelation.scene_id == Scene.id) .group_by(Scene.id) ) - results = self.db.exec(statement).all() - return results + results = await self.db.execute(statement) + return results.scalars().all() - def get_scene_usage(self, scene_id) -> bool: + async def get_scene_usage(self, scene_id) -> bool: statement = ( select(DeviceSceneRelation) .where( DeviceSceneRelation.scene_id == scene_id, ) ) - result = self.db.exec(statement).all() - return len(result) > 0 + result = await self.db.exec(statement) + rows = result.all() + return len(rows) > 0 - def get_scene_by_id(self, scene_id): - return self.db.get(Scene, scene_id) \ No newline at end of file + async def get_scene_by_id(self, scene_id): + result = await self.db.execute(select(Scene).where(Scene.id == scene_id)) + scene = result.scalar_one_or_none() + return scene diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py index 3d6d076..d645340 100644 --- a/tcp/tcp_client_connector.py +++ b/tcp/tcp_client_connector.py @@ -113,29 +113,29 @@ logger.error(f"Error during query for {self.ip}:{self.port}: {e}") await self.reconnect() - def parse_response(self, data): + async def parse_response(self, data): """解析设备返回的数据""" logger.info(f"Received data from {self.ip}:{self.port}: {format_bytes(data)}") try: res = parse_gas_data(data) logger.info(res) - with next(get_db()) as db: + async for db in get_db(): data_gas_service = DataGasService(db) data_gas = DataGas( device_code=res['device_code'], gas_value=res['gas_value'] ) - data_gas_service.add_data_gas(data_gas) - self.gas_push(data_gas) + await data_gas_service.add_data_gas(data_gas) + await self.gas_push(data_gas) except Exception as e: logger.error(f"Parse and save gas data failed: {e}") - def gas_push(self, data_gas): + async def gas_push(self, data_gas): global_config = GlobalConfig() gas_push_config = global_config.get_gas_push_config() - if gas_push_config: + if gas_push_config and gas_push_config.push_url: last_ts = self.push_ts_dict.get(data_gas.device_code) current_time = datetime.now() @@ -161,7 +161,7 @@ data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout) # if not data: # raise ConnectionResetError("Connection lost or no data received") - self.parse_response(data) + await self.parse_response(data) return data # 返回响应数据 else: return None diff --git a/algo/device_detection_task.py b/algo/device_detection_task.py index dbc3fee..cdb97bf 100644 --- a/algo/device_detection_task.py +++ b/algo/device_detection_task.py @@ -98,7 +98,7 @@ def push_frame_results(self, frame_results): global_config = GlobalConfig() push_config = global_config.get_algo_result_push_config() - if push_config: + if push_config and push_config.push_url: last_ts = self.push_ts current_time = datetime.now() diff --git a/apis/control.py b/apis/control.py index 8517389..4c108c5 100644 --- a/apis/control.py +++ b/apis/control.py @@ -48,7 +48,7 @@ @router.get("/restart") -def restart(): +async def restart(): try: # 立即返回响应的函数 def restart_container_async(): diff --git a/apis/data_gas.py b/apis/data_gas.py index fc8ec54..bcb7cb6 100644 --- a/apis/data_gas.py +++ b/apis/data_gas.py @@ -3,8 +3,8 @@ import pandas as pd from fastapi import APIRouter, Query, Depends -from openpyxl.utils import get_column_letter -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession + from starlette.responses import StreamingResponse from apis.base import StandardResponse, PageResponse, convert_page_param, standard_response @@ -17,19 +17,20 @@ @router.get("/page", response_model=StandardResponse[PageResponse[DataGasInfo]]) -def get_gas_page( +async def get_gas_page( device_code: Optional[str] = None, start_time: Optional[str] = None, end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): + service = DataGasService(db) offset, limit = convert_page_param(offset, limit) - data, total = service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), - offset, limit) + data, total = await service.get_data_gas_page(device_code, parse_datetime(start_time), parse_datetime(end_time), + offset, limit) return standard_response( data=PageResponse(total=total, items=data) @@ -37,38 +38,39 @@ @router.get("/export") -def export_data_gas(device_code: Optional[str] = None, - start_time: Optional[str] = None, - end_time: Optional[str] = None, - db: Session = Depends(get_db)): - service = DataGasService(db) - data = service.get_data_gas_list(device_code, start_time, end_time) +async def export_data_gas(device_code: Optional[str] = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + db: AsyncSession = Depends(get_db)): + async with db: + service = DataGasService(db) + data = await service.get_data_gas_list(device_code, start_time, end_time) - # 将查询结果转换为 DataFrame - data = [ - {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} - for item in data] - df = pd.DataFrame(data) + # 将查询结果转换为 DataFrame + data = [ + {"设备名称": item.device_name, "设备编号": item.device_code, "燃气浓度(ppm.m)": item.gas_value, "时间": item.ts} + for item in data] + df = pd.DataFrame(data) - # 使用 BytesIO 生成内存中的 Excel 文件 - output = BytesIO() - with pd.ExcelWriter(output, engine="openpyxl") as writer: - df.to_excel(writer, index=False) + # 使用 BytesIO 生成内存中的 Excel 文件 + output = BytesIO() + with pd.ExcelWriter(output, engine="openpyxl") as writer: + df.to_excel(writer, index=False) - # 获取工作表 - worksheet = writer.sheets["Sheet1"] + # 获取工作表 + worksheet = writer.sheets["Sheet1"] - # 设置固定的列宽,例如宽度为 20 - fixed_width = 30 - for col in worksheet.columns: - col_letter = col[0].column_letter # 获取列字母 - worksheet.column_dimensions[col_letter].width = fixed_width + # 设置固定的列宽,例如宽度为 20 + fixed_width = 30 + for col in worksheet.columns: + col_letter = col[0].column_letter # 获取列字母 + worksheet.column_dimensions[col_letter].width = fixed_width - output.seek(0) + output.seek(0) - # 返回内存中的 Excel 文件作为响应 - return StreamingResponse( - output, - media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - headers={"Content-Disposition": "attachment; filename=gas.xlsx"} - ) + # 返回内存中的 Excel 文件作为响应 + return StreamingResponse( + output, + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": "attachment; filename=gas.xlsx"} + ) diff --git a/apis/device.py b/apis/device.py index 85fb10c..f170ae3 100644 --- a/apis/device.py +++ b/apis/device.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -18,29 +18,29 @@ @router.get("/list", response_model=StandardResponse[List[DeviceInfo]]) -def get_device_list( +async def get_device_list( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) - devices = list(service.get_device_list(name, code, device_type)) - return standard_response(data=devices) + devices = await service.get_device_list(name, code, device_type) + return standard_response(data=list(devices)) @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceInfo]]) -def get_device_page( +async def get_device_page( name: Optional[str] = None, code: Optional[str] = None, device_type: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - devices, total = service.get_device_page(name, code, device_type, offset, limit) + devices, total = await service.get_device_page(name, code, device_type, offset, limit) return standard_response( data=PageResponse(total=total, items=devices) @@ -48,22 +48,22 @@ @router.post("/add", response_model=StandardResponse[DeviceInfo]) -def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): - device = service.create_device(device_data) +async def create_device(device_data: DeviceCreate, service: DeviceService = Depends(get_service)): + device = await service.create_device(device_data) return standard_response(data=device) @router.post("/update", response_model=StandardResponse[DeviceInfo]) -def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): - device = service.update_device(device_data) +async def update_device(device_data: DeviceUpdate, service: DeviceService = Depends(get_service)): + device = await service.update_device(device_data) if not device: return standard_error_response(data=device_data, message="Device not found") return standard_response(data=device) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_device(device_id: int, service: DeviceService = Depends(get_service)): - device = service.delete_device(device_id) +async def delete_device(device_id: int, service: DeviceService = Depends(get_service)): + device = await service.delete_device(device_id) if not device: return standard_error_response(data=device_id, message="Device not found") return standard_response(data=device_id) diff --git a/apis/device_model_realtion.py b/apis/device_model_realtion.py index e61e1d6..ddb5eb6 100644 --- a/apis/device_model_realtion.py +++ b/apis/device_model_realtion.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db @@ -13,37 +13,23 @@ router = APIRouter() app = get_app() + def get_service(): return app.state.model_relation_service @router.get("/list_by_device", response_model=StandardResponse[List[DeviceModelRelationInfo]]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceModelRelationService(db) - models = list(service.get_device_models(device_id)) - return standard_response(data=models) - - -# @router.post("/add_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -# def add_by_device(relation_data: List[DeviceModelRelationCreate], -# device_id: int = Query(...), -# db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# relations = service.add_relations_by_device(device_id, relation_data) -# return standard_response(data=relations) + models = await service.get_device_models(device_id) + return standard_response(data=list(models)) @router.post("/update_by_device", response_model=StandardResponse[List[DeviceModelRelation]]) -def update_by_device(relation_data: List[DeviceModelRelationCreate], - device_id: int = Query(...), - service: DeviceModelRelationService = Depends(get_service)): - relations = service.update_relations_by_device(device_id, relation_data) +async def update_by_device(relation_data: List[DeviceModelRelationCreate], + device_id: int = Query(...), + service: DeviceModelRelationService = Depends(get_service)): + relations = await service.update_relations_by_device(device_id, relation_data) return standard_response(data=relations) - -# @router.delete("/delete_by_device", response_model=StandardResponse[int]) -# def delete_device(device_id: int, db: Session = Depends(get_db)): -# service = DeviceModelRelationService(db) -# count = service.delete_relations_by_device(device_id) -# return standard_response(data=count) diff --git a/apis/device_scene_realtion.py b/apis/device_scene_realtion.py index 9a30a44..d9aa3a2 100644 --- a/apis/device_scene_realtion.py +++ b/apis/device_scene_realtion.py @@ -1,8 +1,8 @@ -from typing import List, Optional + from fastapi import APIRouter, Depends, Query -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import StandardResponse, standard_response from db.database import get_db from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -13,17 +13,17 @@ @router.get("/get_by_device", response_model=StandardResponse[DeviceSceneRelationInfo]) -def list_by_device( +async def list_by_device( device_id: int, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - scene = service.get_device_scene(device_id) + scene = await service.get_device_scene(device_id) return standard_response(data=scene) @router.post("/update_by_device", response_model=StandardResponse[DeviceSceneRelation]) -def update_by_device(device_id: int, scene_id: int, - db: Session = Depends(get_db)): +async def update_by_device(device_id: int, scene_id: int, + db: AsyncSession = Depends(get_db)): service = DeviceSceneRelationService(db) - relation = service.update_relation_by_device(device_id, scene_id) + relation = await service.update_relation_by_device(device_id, scene_id) return standard_response(data=relation) diff --git a/apis/frame.py b/apis/frame.py index 1cb9294..5146ede 100644 --- a/apis/frame.py +++ b/apis/frame.py @@ -5,7 +5,7 @@ import cv2 from fastapi import APIRouter, Depends, Query, HTTPException, Request -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from fastapi.responses import StreamingResponse, FileResponse from apis.base import StandardResponse, PageResponse, standard_response, standard_error_response, convert_page_param @@ -19,20 +19,20 @@ @router.get("/page/", response_model=StandardResponse[PageResponse[DeviceFrame]]) -def get_frame_page( +async def get_frame_page( device_name: Optional[str] = None, device_code: Optional[str] = None, frame_start_time: Optional[str] = None, frame_end_time: Optional[str] = None, offset: int = Query(0, ge=1), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = DeviceFrameService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - frames, total = service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), - parse_datetime(frame_end_time), offset, limit) + frames, total = await service.get_frame_page(device_name, device_code, parse_datetime(frame_start_time), + parse_datetime(frame_end_time), offset, limit) return standard_response( data=PageResponse(total=total, items=frames) @@ -41,10 +41,10 @@ # 路由:使用 OpenCV 生成内存图像并返回字节流 @router.get("/frame_image/") -def get_frame_image(frame_id, db: Session = Depends(get_db)): +async def get_frame_image(frame_id, db: AsyncSession = Depends(get_db)): try: service = DeviceFrameService(db) - frame = service.get_frame_annotator(frame_id) + frame = await service.get_frame_annotator(frame_id) if frame is None: return standard_error_response(message="Frame does not exist") @@ -63,14 +63,3 @@ except Exception as e: raise HTTPException(status_code=500, detail=f"Error reading image: {str(e)}") - -@router.get("/frame_test") -def get_frame_test(request: Request): - file_path = "test.jpg" - return FileResponse(file_path) - # def iterfile(): - # with open(file_path, 'rb') as file: - # while chunk := file.read(1024): - # yield chunk - # - # return StreamingResponse(iterfile(), media_type="image/jpeg") diff --git a/apis/model.py b/apis/model.py index a1161eb..fa2b191 100644 --- a/apis/model.py +++ b/apis/model.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from app_instance import get_app @@ -18,27 +18,27 @@ @router.get("/list", response_model=StandardResponse[List[AlgoModelInfo]]) -def get_model_list( +async def get_model_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) - models = list(service.get_model_list(name, remark)) - return standard_response(data=models) + models = await service.get_model_list(name, remark) + return standard_response(data=list(models)) @router.get("/page/", response_model=StandardResponse[PageResponse[AlgoModelInfo]]) -def get_model_page( +async def get_model_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = ModelService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - models, total = service.get_model_page(name, remark, offset, limit) + models, total = await service.get_model_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=models) @@ -46,9 +46,9 @@ @router.post("/add", response_model=StandardResponse[AlgoModelInfo]) -def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), +async def create_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -56,25 +56,25 @@ model_data = AlgoModelCreate.parse_raw(json_data) service = ModelService(db) - model = service.create_model(model_data, file) + model = await service.create_model(model_data, file) return standard_response(data=model) @router.post("/update", response_model=StandardResponse[AlgoModelInfo]) -def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), +async def update_model(json_data: str = Form(..., description="JSON数据字段,内容为AlgoModelUpdate结构"), file: UploadFile = File(None, description="模型文件"), service: ModelService = Depends(get_service)): model_data = AlgoModelUpdate.parse_raw(json_data) - model = service.update_model(model_data, file) + model = await service.update_model(model_data, file) if not model: return standard_error_response(data=model_data, message="Model not found") return standard_response(data=model) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_model(model_id: int, db: Session = Depends(get_db)): +async def delete_model(model_id: int, db: AsyncSession = Depends(get_db)): service = ModelService(db) - model = service.delete_model(model_id) + model = await service.delete_model(model_id) if not model: return standard_error_response(data=model_id, message="Model not found") return standard_response(data=model_id) diff --git a/apis/push_config.py b/apis/push_config.py index 5e25579..e234afb 100644 --- a/apis/push_config.py +++ b/apis/push_config.py @@ -1,7 +1,7 @@ from typing import List from fastapi import APIRouter, Depends -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse from db.database import get_db @@ -12,21 +12,21 @@ @router.get("/list", response_model=StandardResponse[List[PushConfig]]) -def get_push_config_list(db: Session = Depends(get_db)): +async def get_push_config_list(db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_configs = service.get_push_config_list() + push_configs = await service.get_push_config_list() return standard_response(data=push_configs) @router.get("/get_by_type", response_model=StandardResponse[PushConfig]) -def get_by_type(push_type: int, db: Session = Depends(get_db)): +async def get_by_type(push_type: int, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.get_push_config(push_type) + push_config = await service.get_push_config(push_type) return standard_response(data=push_config) @router.post("/set_push_config", response_model=StandardResponse[PushConfig]) -def set_push_config(push_config: PushConfigCreate, db: Session = Depends(get_db)): +async def set_push_config(push_config: PushConfigCreate, db: AsyncSession = Depends(get_db)): service = PushConfigService(db) - push_config = service.set_push_config(push_config) + push_config = await service.set_push_config(push_config) return standard_response(data=push_config) diff --git a/apis/scene.py b/apis/scene.py index f84b54b..1b7f706 100644 --- a/apis/scene.py +++ b/apis/scene.py @@ -1,7 +1,7 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query, UploadFile, File, Form -from sqlmodel import Session +from sqlalchemy.ext.asyncio import AsyncSession from apis.base import standard_response, StandardResponse, PageResponse, standard_error_response, convert_page_param from db.database import get_db @@ -14,27 +14,27 @@ @router.get("/list", response_model=StandardResponse[List[Scene]]) -def get_scene_list( +async def get_scene_list( name: Optional[str] = None, remark: Optional[str] = None, - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) - scenes = list(service.get_scene_list(name, remark)) - return standard_response(data=scenes) + scenes = await service.get_scene_list(name, remark) + return standard_response(data=list(scenes)) @router.get("/page/", response_model=StandardResponse[PageResponse[SceneInfo]]) -def get_scene_page( +async def get_scene_page( name: Optional[str] = None, remark: Optional[str] = None, offset: int = Query(0, ge=0), # 从第几页开始 limit: int = Query(10, ge=1), # 每页显示多少条记录 - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) # 获取分页后的设备列表和总数 offset, limit = convert_page_param(offset, limit) - scenes, total = service.get_scene_page(name, remark, offset, limit) + scenes, total = await service.get_scene_page(name, remark, offset, limit) return standard_response( data=PageResponse(total=total, items=scenes) @@ -42,9 +42,9 @@ @router.post("/add", response_model=StandardResponse[SceneInfo]) -def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), +async def create_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneCreate结构"), file: UploadFile = File(..., description="模型文件"), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ): # 检查文件类型 if not file.filename.endswith(".zip"): @@ -52,26 +52,26 @@ scene_data = SceneCreate.parse_raw(json_data) service = SceneService(db) - scene = service.create_scene(scene_data, file) + scene = await service.create_scene(scene_data, file) return standard_response(data=scene) @router.post("/update", response_model=StandardResponse[SceneInfo]) -def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), +async def update_scene(json_data: str = Form(..., description="JSON数据字段,内容为SceneUpdate结构"), file: UploadFile = File(None, description="模型文件"), - db: Session = Depends(get_db)): + db: AsyncSession = Depends(get_db)): service = SceneService(db) scene_data = SceneUpdate.parse_raw(json_data) - scene = service.update_scene(scene_data, file) + scene = await service.update_scene(scene_data, file) if not scene: return standard_error_response(data=scene_data, message="Scene not found") return standard_response(data=scene) @router.delete("/delete", response_model=StandardResponse[int]) -def delete_scene(scene_id: int, db: Session = Depends(get_db)): +async def delete_scene(scene_id: int, db: AsyncSession = Depends(get_db)): service = SceneService(db) - scene = service.delete_scene(scene_id) + scene = await service.delete_scene(scene_id) if not scene: return standard_error_response(data=scene_id, message="Scene not found") return standard_response(data=scene_id) diff --git a/app_instance.py b/app_instance.py index a6e0992..f9d1ad9 100644 --- a/app_instance.py +++ b/app_instance.py @@ -12,6 +12,7 @@ from services.device_model_relation_service import DeviceModelRelationService from services.device_scene_relation_service import DeviceSceneRelationService from services.device_service import DeviceService +from services.global_config import GlobalConfig from services.model_service import ModelService from services.scene_service import SceneService from tcp.tcp_manager import TcpManager @@ -28,7 +29,11 @@ async def lifespan(app: FastAPI): main_loop = asyncio.get_running_loop() - with next(get_db()) as db: + # async with get_db() as db: + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) model_service = ModelService(db) model_relation_service = DeviceModelRelationService(db) @@ -51,7 +56,7 @@ relation_service=model_relation_service, ) app.state.algo_runner = algo_runner - await algo_runner.start() + # await algo_runner.start() scene_runner = SceneRunner( device_service=device_service, @@ -61,7 +66,9 @@ main_loop=main_loop ) app.state.scene_runner = scene_runner - await scene_runner.start() + # await scene_runner.start() + + yield # 允许请求处理 diff --git a/common/global_logger.py b/common/global_logger.py index 3e3a6a9..3e672ca 100644 --- a/common/global_logger.py +++ b/common/global_logger.py @@ -12,6 +12,8 @@ logger = logging.getLogger("casic_safe_logger") logger.setLevel(logging.DEBUG) # 设置日志级别 +logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING) + # 创建一个TimedRotatingFileHandler handler = logging.handlers.TimedRotatingFileHandler( os.path.join(log_dir, 'app.log'), # 日志文件名 diff --git a/db/database.py b/db/database.py index 4f44d2d..01ac013 100644 --- a/db/database.py +++ b/db/database.py @@ -1,22 +1,25 @@ +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, create_engine, Session -from contextlib import contextmanager +from contextlib import contextmanager, asynccontextmanager sqlite_file_name = "./db/safe-algo-pro.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" +sqlite_url = f"sqlite+aiosqlite:///{sqlite_file_name}" # 使用异步SQLite驱动 -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) +engine = create_async_engine(sqlite_url, echo=False, future=True) # 初始化数据库表 -def init_db(): - SQLModel.metadata.create_all(engine) +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) -# 数据库会话管理 -def get_db(): - session = Session(engine) - try: +# 异步数据库会话管理 +# @asynccontextmanager +async def get_db() -> AsyncSession: + async_session = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + async with async_session() as session: yield session - finally: - session.close() diff --git a/main.py b/main.py index e35e458..e830e86 100644 --- a/main.py +++ b/main.py @@ -3,7 +3,9 @@ import logging from fastapi.openapi.docs import get_swagger_ui_html -from starlette.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + from app_instance import get_app from common.global_logger import logger @@ -28,6 +30,14 @@ from apis.router import router app.include_router(router, prefix="/api") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + if __name__ == "__main__": # 重定向 uvicorn 的日志 uvicorn_logger = logging.getLogger("uvicorn") diff --git a/requirements.txt b/requirements.txt index 161192a..aa9e1e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ starlette uvicorn sqlalchemy -aiohttp \ No newline at end of file +aiohttp +aiosqlite +aiofiles \ No newline at end of file diff --git a/services/data_gas_service.py b/services/data_gas_service.py index fe6ce53..6ac5665 100644 --- a/services/data_gas_service.py +++ b/services/data_gas_service.py @@ -2,23 +2,24 @@ from typing import Optional, Tuple, Sequence from sqlalchemy import func -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from entity.data_gas import DataGas, DataGasInfo from entity.device import Device class DataGasService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_data_gas(self, data_gas: DataGas): + async def add_data_gas(self, data_gas: DataGas): self.db.add(data_gas) - self.db.commit() - self.db.refresh(data_gas) + await self.db.commit() + await self.db.refresh(data_gas) return data_gas - def get_data_gas_page(self, + async def get_data_gas_page(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, @@ -29,13 +30,17 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() + + data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -44,18 +49,19 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list, total # 返回分页数据和总数 - def get_data_gas_list(self, + async def get_data_gas_list(self, device_code: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, ) -> Sequence[DataGasInfo]: statement = self.gas_query(device_code, end_time, start_time) - results = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() data_gas_info_list = [ DataGasInfo( id=data_gas.id, @@ -64,7 +70,7 @@ ts=data_gas.ts, device_name=device_name ) - for data_gas, device_name in results + for data_gas, device_name in rows ] return data_gas_info_list diff --git a/services/device_frame_service.py b/services/device_frame_service.py index 5dead7d..1a468ae 100644 --- a/services/device_frame_service.py +++ b/services/device_frame_service.py @@ -3,7 +3,10 @@ from copy import deepcopy from datetime import datetime from typing import Sequence, Optional, Tuple -from sqlmodel import Session + +import aiofiles +import numpy as np +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import func @@ -13,46 +16,56 @@ import cv2 -from sqlmodel import Session, select, delete +from sqlmodel import select, delete from services.frame_analysis_result_service import FrameAnalysisResultService class DeviceFrameService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame(self, device_id, frame_data) -> DeviceFrame: - def save_frame_file(): - # 生成当前年月日作为目录路径 - current_date = datetime.now().strftime('%Y-%m-%d') - # 生成随机 UUID 作为文件名 - file_name = str(uuid.uuid4()) + ".jpeg" - # 创建保存图片的完整路径 - save_path = os.path.join('./storage/frames', current_date, file_name) - # 创建目录(如果不存在) - os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片 - cv2.imwrite(save_path, frame_data) - return save_path + async def add_frame(self, device_id, frame_data) -> DeviceFrame: + async def add_frame(self, device_id, frame_data) -> 'DeviceFrame': + async def save_frame_file(): + # 生成当前年月日作为目录路径 + current_date = datetime.now().strftime('%Y-%m-%d') + # 生成随机 UUID 作为文件名 + file_name = f"{uuid.uuid4()}.jpeg" + # 创建保存图片的完整路径 + save_path = os.path.join('./storage/frames', current_date, file_name) + # 创建目录(如果不存在) + os.makedirs(os.path.dirname(save_path), exist_ok=True) - # 保存图片文件 - file_path = save_frame_file() - device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) - self.db.add(device_frame) - self.db.commit() - self.db.refresh(device_frame) - return device_frame + # 将 frame_data 转换为二进制数据(假设 frame_data 是一个 numpy 数组) + _, encoded_image = cv2.imencode('.jpeg', frame_data) + image_data = encoded_image.tobytes() - def get_frame_page(self, - device_name: Optional[str] = None, - device_code: Optional[str] = None, - frame_start_time: Optional[datetime] = None, - frame_end_time: Optional[datetime] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceFrame], int]: + # 使用 aiofiles 进行异步写入 + async with aiofiles.open(save_path, 'wb') as f: + await f.write(image_data) + + return save_path + + # 异步保存图片文件 + file_path = await save_frame_file() + + # 创建并保存到数据库中 + device_frame = DeviceFrame(device_id=device_id, frame_path=file_path) + self.db.add(device_frame) + await self.db.commit() + await self.db.refresh(device_frame) + return device_frame + + async def get_frame_page(self, + device_name: Optional[str] = None, + device_code: Optional[str] = None, + frame_start_time: Optional[datetime] = None, + frame_end_time: Optional[datetime] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceFrame], int]: statement = ( select(DeviceFrame, Device) .join(Device, DeviceFrame.device_id == Device.id) @@ -69,26 +82,37 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - results = results.all() - frames = [frame for frame, device in results] + results = await self.db.execute(statement) + rows = results.all() + + frames = [frame for frame, device in rows] return frames, total # 返回分页数据和总数 - def get_frame(self, frame_id: int): - return self.db.get(DeviceFrame, frame_id) + async def get_frame(self, frame_id: int): + result = await self.db.execute(select(DeviceFrame).where(DeviceFrame.id == frame_id)) + frame = result.scalar_one_or_none() + return frame - def get_frame_annotator(self, frame_id: int): - device_frame = self.get_frame(frame_id) + async def get_frame_annotator(self, frame_id: int): + device_frame = await self.get_frame(frame_id) if device_frame: - frame_image = cv2.imread(device_frame.frame_path) + # 异步读取图像文件 + async with aiofiles.open(device_frame.frame_path, mode='rb') as f: + file_content = await f.read() + + # 将读取的字节内容转换为 OpenCV 图像 + np_array = np.frombuffer(file_content, dtype=np.uint8) + frame_image = cv2.imdecode(np_array, cv2.IMREAD_COLOR) + frame_analysis_result_service = FrameAnalysisResultService(self.db) - results = frame_analysis_result_service.get_results_by_frame(device_frame.id) + results = await frame_analysis_result_service.get_results_by_frame(device_frame.id) if results: annotator = Annotator(deepcopy(frame_image)) height, width = frame_image.shape[:2] diff --git a/services/device_model_relation_service.py b/services/device_model_relation_service.py index c32b5a5..5f1aa9d 100644 --- a/services/device_model_relation_service.py +++ b/services/device_model_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_model_relation import DeviceModelRelation, DeviceModelRelationInfo, DeviceModelRelationCreate @@ -10,7 +10,7 @@ class DeviceModelRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: + async def get_device_models(self, device_id: int) -> List[DeviceModelRelationInfo]: statement = ( select(DeviceModelRelation, AlgoModel) .join(AlgoModel, DeviceModelRelation.algo_model_id == AlgoModel.id) @@ -32,7 +32,8 @@ ) # 执行联表查询 - result = self.db.exec(statement).all() + results = await self.db.execute(statement) + rows = results.all() models_info = [ DeviceModelRelationInfo( @@ -46,13 +47,11 @@ algo_model_path=model.path, algo_model_remark=model.remark, ) - for relation, model in result + for relation, model in rows ] return models_info - - - def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + async def add_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): new_relations = [ DeviceModelRelation( algo_model_id=relation.algo_model_id, @@ -64,20 +63,21 @@ ) for relation in relations ] - self.db.add_all(new_relations) - self.db.commit() for relation in new_relations: - self.db.refresh(relation) + self.db.add(relation) + await self.db.commit() + for relation in new_relations: + await self.db.refresh(relation) return new_relations - def delete_relations_by_device(self, device_id: int): + async def delete_relations_by_device(self, device_id: int): statement = delete(DeviceModelRelation).where(DeviceModelRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): - self.delete_relations_by_device(device_id) - new_relations = self.add_relations_by_device(device_id, relations) + async def update_relations_by_device(self, device_id: int, relations: List[DeviceModelRelationCreate]): + await self.delete_relations_by_device(device_id) + new_relations = await self.add_relations_by_device(device_id, relations) self.notify_change(device_id, NotifyChangeType.DEVICE_MODEL_RELATION_UPDATE) return new_relations diff --git a/services/device_scene_relation_service.py b/services/device_scene_relation_service.py index 592ee98..3416831 100644 --- a/services/device_scene_relation_service.py +++ b/services/device_scene_relation_service.py @@ -1,8 +1,8 @@ from datetime import datetime from typing import List, Optional -from sqlmodel import Session, select, delete - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool from entity.device_scene_relation import DeviceSceneRelationInfo, DeviceSceneRelation @@ -10,7 +10,7 @@ class DeviceSceneRelationService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__relation_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -24,7 +24,7 @@ for callback in self.__relation_change_callbacks: self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: + async def get_device_scene(self, device_id: int) -> Optional[DeviceSceneRelationInfo]: statement = ( select(DeviceSceneRelation, Scene) .join(Scene, DeviceSceneRelation.scene_id == Scene.id) @@ -32,11 +32,12 @@ ) # 执行联表查询 - result = self.db.exec(statement).first() + result = await self.db.execute(statement) + result_row = result.first() scene_info = None - if result: - relation, scene = result[0], result[1] + if result_row: + relation, scene = result_row scene_info = DeviceSceneRelationInfo( id=relation.id, device_id=relation.device_id, @@ -48,23 +49,23 @@ ) return scene_info - def add_relation_by_device(self, device_id: int, scene_id: int): + async def add_relation_by_device(self, device_id: int, scene_id: int): new_relation = DeviceSceneRelation(device_id=device_id, scene_id=scene_id) new_relation.create_time = datetime.now() new_relation.update_time = datetime.now() self.db.add(new_relation) - self.db.commit() - self.db.refresh(new_relation) + await self.db.commit() + await self.db.refresh(new_relation) return new_relation - def delete_relation_by_device(self, device_id: int): + async def delete_relation_by_device(self, device_id: int): statement = delete(DeviceSceneRelation).where(DeviceSceneRelation.device_id == device_id) - count = self.db.exec(statement) - self.db.commit() - return count.rowcount + result = await self.db.execute(statement) + await self.db.commit() + return result.rowcount - def update_relation_by_device(self, device_id: int, scene_id: int): - self.delete_relation_by_device(device_id) - new_relation = self.add_relation_by_device(device_id, scene_id) + async def update_relation_by_device(self, device_id: int, scene_id: int): + await self.delete_relation_by_device(device_id) + new_relation = await self.add_relation_by_device(device_id, scene_id) self.notify_change(device_id, NotifyChangeType.DEVICE_SCENE_RELATION_UPDATE) return new_relation diff --git a/services/device_service.py b/services/device_service.py index cad46d3..2488fd5 100644 --- a/services/device_service.py +++ b/services/device_service.py @@ -4,8 +4,8 @@ from typing import Sequence, Optional, Tuple from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.device_status_manager import DeviceStatusManager from common.global_thread_pool import GlobalThreadPool from common.consts import NotifyChangeType, DEVICE_MODE @@ -15,25 +15,25 @@ class DeviceService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__device_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() # 创建一个独立的事件循环并启动线程 - self.loop = asyncio.new_event_loop() - self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) - self.loop_thread.start() + # self.loop = asyncio.new_event_loop() + # self.loop_thread = threading.Thread(target=self._start_loop, daemon=True) + # self.loop_thread.start() - def _start_loop(self): - """后台线程运行事件循环""" - asyncio.set_event_loop(self.loop) - self.loop.run_forever() - - def shutdown(self): - """清理事件循环和线程""" - self.loop.call_soon_threadsafe(self.loop.stop) - self.loop_thread.join() + # def _start_loop(self): + # """后台线程运行事件循环""" + # asyncio.set_event_loop(self.loop) + # self.loop.run_forever() + # + # def shutdown(self): + # """清理事件循环和线程""" + # self.loop.call_soon_threadsafe(self.loop.stop) + # self.loop_thread.join() def register_change_callback(self, callback): """注册设备变化回调函数""" @@ -42,52 +42,53 @@ def notify_change(self, device_id, change_type): """当设备发生变化时,调用回调通知变化""" # loop = asyncio.get_event_loop() # 获取当前的事件循环 - for callback in self.__device_change_callbacks: - if asyncio.iscoroutinefunction(callback): - # 如果是协程函数,使用事件循环运行它 - # loop = asyncio.new_event_loop() - # asyncio.set_event_loop(loop) - asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) - else: - # 如果是普通函数,直接提交到线程池 - self.thread_pool.executor.submit(callback, device_id, change_type) + # for callback in self.__device_change_callbacks: + # if asyncio.iscoroutinefunction(callback): + # # 如果是协程函数,使用事件循环运行它 + # # loop = asyncio.new_event_loop() + # # asyncio.set_event_loop(loop) + # asyncio.run_coroutine_threadsafe(callback(device_id, change_type), self.loop) + # else: + # # 如果是普通函数,直接提交到线程池 + # self.thread_pool.executor.submit(callback, device_id, change_type) - def get_device_list(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - ) -> Sequence[Device]: + async def get_device_list(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + ) -> Sequence[Device]: statement = self.device_query(code, device_type, name) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_device_page(self, - name: Optional[str] = None, - code: Optional[str] = None, - device_type: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[DeviceInfo], int]: + async def get_device_page(self, + name: Optional[str] = None, + code: Optional[str] = None, + device_type: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[DeviceInfo], int]: statement = self.device_query(code, device_type, name) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - results = self.db.exec(statement) - device_list = results.all() + results = await self.db.execute(statement) + device_list = results.scalars().all() device_info_list = [] if device_list: device_model_relation_service = DeviceModelRelationService(self.db) device_scene_relation_service = DeviceSceneRelationService(self.db) device_status_manager = DeviceStatusManager() for device in device_list: - model_relations = device_model_relation_service.get_device_models(device.id) - scene_relation = device_scene_relation_service.get_device_scene(device.id) + model_relations = await device_model_relation_service.get_device_models(device.id) + scene_relation = await device_scene_relation_service.get_device_scene(device.id) device_info_list.append(DeviceInfo( id=device.id, @@ -121,20 +122,20 @@ statement = statement.where(Device.type == device_type) return statement - def create_device(self, device_data: DeviceCreate): + async def create_device(self, device_data: DeviceCreate): device = Device.model_validate(device_data) device.create_time = datetime.now() device.update_time = datetime.now() - self.handle_device_mode(device) + await self.handle_device_mode(device) self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_CREATE) return device - def update_device(self, device_data: DeviceUpdate): - device_old = self.db.get(Device, device_data.id) - device = self.db.get(Device, device_data.id) + async def update_device(self, device_data: DeviceUpdate): + device_old = await self.get_device(device_data.id) + device = await self.get_device(device_data.id) if not device: return None @@ -142,41 +143,45 @@ for key, value in update_data.items(): setattr(device, key, value) - self.handle_device_mode(device) + await self.handle_device_mode(device) device.update_time = datetime.now() self.db.add(device) - self.db.commit() - self.db.refresh(device) + await self.db.commit() + await self.db.refresh(device) self.notify_change(device.id, NotifyChangeType.DEVICE_UPDATE) return device - def delete_device(self, device_id: int): - device = self.db.get(Device, device_id) + async def delete_device(self, device_id: int): + device = await self.get_device(device_id) if not device: return None - self.db.delete(device) - self.db.commit() + statement = delete(Device).where(Device.id == device_id) + await self.db.execute(statement) + await self.db.commit() + self.notify_change(device.id, NotifyChangeType.DEVICE_DELETE) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device_id) + await model_relation_service.delete_relations_by_device(device_id) scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) return device - def handle_device_mode(self, device): + async def handle_device_mode(self, device): if device.mode == DEVICE_MODE.ALGO: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) elif device.mode == DEVICE_MODE.SCENE: model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) else: scene_relation_service = DeviceSceneRelationService(self.db) - scene_relation_service.delete_relation_by_device(device.id) + await scene_relation_service.delete_relation_by_device(device.id) model_relation_service = DeviceModelRelationService(self.db) - model_relation_service.delete_relations_by_device(device.id) + await model_relation_service.delete_relations_by_device(device.id) - def get_device(self, device_id: int): - return self.db.get(Device, device_id) + async def get_device(self, device_id: int): + result = await self.db.execute(select(Device).where(Device.id == device_id)) + frame = result.scalar_one_or_none() + return frame diff --git a/services/frame_analysis_result_service.py b/services/frame_analysis_result_service.py index dbd272d..4e6c145 100644 --- a/services/frame_analysis_result_service.py +++ b/services/frame_analysis_result_service.py @@ -1,23 +1,25 @@ from typing import List -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.frame_analysis_result import FrameAnalysisResultCreate, FrameAnalysisResult class FrameAnalysisResultService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db - def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): + async def add_frame_analysis_results(self, results: List[FrameAnalysisResultCreate]): new_results = [FrameAnalysisResult.model_validate(result) for result in results] - self.db.add_all(new_results) - self.db.commit() for result in new_results: - self.db.refresh(result) + self.db.add(result) + await self.db.commit() + for result in new_results: + await self.db.refresh(result) return new_results - def get_results_by_frame(self, frame_id): + async def get_results_by_frame(self, frame_id): statement = select(FrameAnalysisResult).where(FrameAnalysisResult.frame_id == frame_id) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/services/global_config.py b/services/global_config.py index 7dae73b..fa53b2b 100644 --- a/services/global_config.py +++ b/services/global_config.py @@ -1,3 +1,5 @@ +import asyncio + from common.consts import PUSH_TYPE from db.database import get_db from entity.push_config import PushConfig @@ -6,10 +8,12 @@ class GlobalConfig: _instance = None + _lock = asyncio.Lock() def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) + return cls._instance def __init__(self): @@ -19,49 +23,58 @@ self.gas_push_config = None self.algo_result_push_config = None self.alarm_push_config = None - self.init_config() # 进行初始化 + self._init_done = False - def init_config(self): - # 初始化配置逻辑 - with next(get_db()) as db: - self.config_service = PushConfigService(db) - self.set_gas_push_config(self.config_service.get_push_config(PUSH_TYPE.GAS)) - self.set_algo_result_push_config(self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT)) - self.set_alarm_push_config(self.config_service.get_push_config(PUSH_TYPE.ALARM)) + async def _initialize(self): + await self.init_config() # 调用异步初始化 - self.config_service.register_change_callback(self.on_config_change) + async def init_config(self): + # 确保只初始化一次 + if not self._init_done: + async with self._lock: + if not self._init_done: # 双重检查锁,避免多次初始化 + async for db in get_db(): + self.config_service = PushConfigService(db) + self.config_service.register_change_callback(self.on_config_change) + self.gas_push_config = await self.config_service.get_push_config(PUSH_TYPE.GAS) + self.algo_result_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALGO_RESULT) + self.alarm_push_config = await self.config_service.get_push_config(PUSH_TYPE.ALARM) + self._init_done = True - def on_config_change(self, config: PushConfig): + async def on_config_change(self, config: PushConfig): if config.push_type == PUSH_TYPE.GAS: - self.set_gas_push_config(config) + await self.set_gas_push_config(config) elif config.push_type == PUSH_TYPE.ALGO_RESULT: - self.set_algo_result_push_config(config) + await self.set_algo_result_push_config(config) elif config.push_type == PUSH_TYPE.ALARM: - self.set_alarm_push_config(config) + await self.set_alarm_push_config(config) def get_gas_push_config(self): """获取 gas_push_config 配置""" return self.gas_push_config - def set_gas_push_config(self, config): + async def set_gas_push_config(self, config): """设置 gas_push_config 配置""" if config: - self.gas_push_config = config + async with self._lock: + self.gas_push_config = config def get_algo_result_push_config(self): """获取 algo_result_push_config 配置""" return self.algo_result_push_config - def set_algo_result_push_config(self, config): + async def set_algo_result_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.algo_result_push_config = config + async with self._lock: + self.algo_result_push_config = config def get_alarm_push_config(self): """获取 algo_result_push_config 配置""" return self.alarm_push_config - def set_alarm_push_config(self, config): + async def set_alarm_push_config(self, config): """设置 algo_result_push_config 配置""" if config: - self.alarm_push_config = config + async with self._lock: + self.alarm_push_config = config diff --git a/services/model_service.py b/services/model_service.py index f9531a5..afc61d9 100644 --- a/services/model_service.py +++ b/services/model_service.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid import zipfile @@ -5,10 +6,11 @@ from pathlib import Path from typing import List, Sequence, Optional, Tuple, Type +import aiofiles from fastapi import UploadFile from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.biz_exception import BizException from common.string_utils import snake_to_camel from entity.device_model_relation import DeviceModelRelation @@ -18,7 +20,7 @@ class ModelService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__model_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -32,15 +34,15 @@ for callback in self.__model_change_callbacks: self.thread_pool.executor.submit(callback, algo_model_id, change_type) - def get_model_list(self, - name: Optional[str] = None, - remark: Optional[str] = None, - ) -> Sequence[AlgoModel]: + async def get_model_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[AlgoModel]: statement = self.model_query(name, remark) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_page(self, + async def get_model_page(self, name: Optional[str] = None, remark: Optional[str] = None, offset: int = 0, @@ -50,19 +52,21 @@ # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - model_list = self.db.exec(statement) + model_list = await self.db.execute(statement) + rows = model_list.scalars().all() model_info_list: List[AlgoModelInfo] = [] - if model_list: - for model in model_list: + if rows: + for model in rows: model_info_list.append(AlgoModelInfo( **model.dict(), - usage_status="使用中" if self.get_model_usage(model.id) else "未使用" + usage_status="使用中" if await self.get_model_usage(model.id) else "未使用" )) return model_info_list, total # 返回分页数据和总数 @@ -76,7 +80,7 @@ statement = statement.where(AlgoModel.remark.like(f"%{remark}%")) return statement - def process_zip(self, file: UploadFile): + async def process_zip(self, file: UploadFile): model_dir = Path('weights/') model_handle_dir = Path('model_handler/') model_dir.mkdir(parents=True, exist_ok=True) @@ -86,13 +90,16 @@ # 临时保存上传文件 temp_path = Path(f"temp_upload_{uuid.uuid4()}.zip") - with open(temp_path, "wb") as temp_file: - temp_file.write(file.file.read()) + async with aiofiles.open(temp_path, "wb") as temp_file: + content = await file.read() # 异步读取上传文件内容 + await temp_file.write(content) model_file_path = None handle_file_path = None try: + # 使用异步方法读取压缩文件的内容 + loop = asyncio.get_event_loop() with zipfile.ZipFile(temp_path, 'r') as zip_ref: # 获取压缩包文件列表 file_list = zip_ref.namelist() @@ -105,47 +112,47 @@ message=f"Model weight file ({', '.join(SUPPORTED_MODEL_EXTENSIONS)}) is required in the zip." ) - # 解压模型文件到模型目录 - zip_ref.extract(model_file, model_dir) + # 异步解压模型文件到模型目录 model_file_path = model_dir / model_file + await loop.run_in_executor(None, zip_ref.extract, model_file, model_dir) # 检查是否有可选的 Python 脚本 handle_file = next((f for f in file_list if f.endswith(".py")), None) if handle_file: - zip_ref.extract(handle_file, model_handle_dir) handle_file_path = model_handle_dir / handle_file - + await loop.run_in_executor(None, zip_ref.extract, handle_file, model_handle_dir) except zipfile.BadZipFile: raise BizException(status_code=400, message="Invalid zip file.") finally: # 删除临时文件 - temp_path.unlink() + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, temp_path.unlink) return str(model_file_path), str(handle_file_path) if handle_file_path else None - def create_model(self, model_data: AlgoModelCreate, file: UploadFile): - self.process_model_file(file, model_data) + async def create_model(self, model_data: AlgoModelCreate, file: UploadFile): + await self.process_model_file(file, model_data) model = AlgoModel.model_validate(model_data) model.create_time = datetime.now() model.update_time = datetime.now() self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) return model - def process_model_file(self, file, model): - model_file_path, handle_file_path = self.process_zip(file) + async def process_model_file(self, file, model): + model_file_path, handle_file_path = await self.process_zip(file) model.path = model_file_path if handle_file_path: model.handle_task = snake_to_camel(os.path.splitext(os.path.basename(handle_file_path))[0]) else: model.handle_task = 'BaseModelHandler' - def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): - model = self.db.get(AlgoModel, model_data.id) + async def update_model(self, model_data: AlgoModelUpdate, file: UploadFile): + model = await self.get_model_by_id(model_data.id) if not model: return None @@ -155,16 +162,16 @@ model.update_time = datetime.now() if file: - self.process_model_file(file, model) + await self.process_model_file(file, model) self.db.add(model) - self.db.commit() - self.db.refresh(model) + await self.db.commit() + await self.db.refresh(model) self.notify_change(model.id, NotifyChangeType.MODEL_UPDATE) return model - def delete_model(self, model_id: int): - model = self.db.get(AlgoModel, model_id) + async def delete_model(self, model_id: int): + model = await self.get_model_by_id(model_id) if not model: return None # 查询 device_model_relation 中是否存在启用的绑定关系 @@ -173,17 +180,20 @@ .where(DeviceModelRelation.algo_model_id == model_id) .where(DeviceModelRelation.is_use == 1) ) - relation_in_use = self.db.exec(statement).first() + relation_in_use_exec = await self.db.execute(statement) + relation_in_use = relation_in_use_exec.scalars().first() # 如果存在启用的绑定关系,提示无法删除 if relation_in_use: raise BizException(message=f"模型 {model.name} 正在被设备使用,无法删除") - self.db.delete(model) - self.db.commit() + statement = delete(AlgoModel).where(AlgoModel.id == model_id) + await self.db.execute(statement) + await self.db.commit() + return model - def get_models_in_use(self) -> Sequence[AlgoModel]: + async def get_models_in_use(self) -> Sequence[AlgoModel]: """获取所有在 device_model_relation 表里有启用绑定关系的模型信息""" statement = ( select(AlgoModel) @@ -191,10 +201,10 @@ .where(DeviceModelRelation.is_use == 1) .group_by(AlgoModel.id) ) - results = self.db.exec(statement).all() - return results + results = await self.db.execute(statement) + return results.scalars().all() - def get_model_usage(self, algo_model_id) -> bool: + async def get_model_usage(self, algo_model_id) -> bool: statement = ( select(DeviceModelRelation) .where( @@ -202,8 +212,11 @@ DeviceModelRelation.algo_model_id == algo_model_id, ) ) - result = self.db.exec(statement).all() - return len(result) > 0 + result = await self.db.execute(statement) + rows = result.all() + return len(rows) > 0 - def get_model_by_id(self, model_id): - return self.db.get(AlgoModel, model_id) + async def get_model_by_id(self, model_id): + result = await self.db.execute(select(AlgoModel).where(AlgoModel.id == model_id)) + model = result.scalar_one_or_none() + return model diff --git a/services/push_config_service.py b/services/push_config_service.py index 8a2c298..14862df 100644 --- a/services/push_config_service.py +++ b/services/push_config_service.py @@ -1,6 +1,8 @@ +import asyncio from datetime import datetime -from sqlmodel import Session, select +from sqlmodel import select +from sqlalchemy.ext.asyncio import AsyncSession from entity.push_config import PushConfigCreate, PushConfig @@ -13,7 +15,7 @@ cls._instance = super().__new__(cls) return cls._instance - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): if not hasattr(self, 'initialized'): self.db = db self.__push_change_callbacks = [] # 用于存储回调函数 @@ -23,12 +25,22 @@ """注册设备变化回调函数""" self.__push_change_callbacks.append(callback) - def notify_change(self, push_config): - for callback in self.__push_change_callbacks: - callback(push_config) + # def notify_change(self, push_config): + # for callback in self.__push_change_callbacks: + # callback(push_config) - def set_push_config(self, push_config_create: PushConfigCreate): - push_config = self.get_push_config(push_config_create.push_type) + def notify_change(self, config): + """通知所有回调函数""" + for callback in self.__push_change_callbacks: + if asyncio.iscoroutinefunction(callback): + # 如果是异步函数,使用 asyncio.create_task() 调度 + asyncio.create_task(callback(config)) + else: + # 如果是同步函数,直接调用 + callback(config) + + async def set_push_config(self, push_config_create: PushConfigCreate): + push_config = await self.get_push_config(push_config_create.push_type) if push_config: update_data = push_config_create.dict(exclude_unset=True) for key, value in update_data.items(): @@ -40,18 +52,18 @@ push_config.update_time = datetime.now() self.db.add(push_config) - self.db.commit() - self.db.refresh(push_config) + await self.db.commit() + await self.db.refresh(push_config) self.notify_change(push_config) return push_config - def get_push_config(self, push_type): + async def get_push_config(self, push_type): statement = select(PushConfig).where(PushConfig.push_type == push_type) - results = self.db.exec(statement) - return results.first() + results = await self.db.execute(statement) + return results.scalars().first() - def get_push_config_list(self): + async def get_push_config_list(self): statement = select(PushConfig) - results = self.db.exec(statement) - return results.all() \ No newline at end of file + results = await self.db.execute(statement) + return results.scalars().all() diff --git a/services/scene_service.py b/services/scene_service.py index 17c8f52..dc4c67b 100644 --- a/services/scene_service.py +++ b/services/scene_service.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid import zipfile @@ -5,10 +6,11 @@ from pathlib import Path from typing import Optional, Sequence, Tuple, List +import aiofiles from fastapi import UploadFile from sqlalchemy import func -from sqlmodel import Session, select - +from sqlmodel import select, delete +from sqlalchemy.ext.asyncio import AsyncSession from common.biz_exception import BizException from common.consts import NotifyChangeType from common.global_thread_pool import GlobalThreadPool @@ -18,7 +20,7 @@ class SceneService: - def __init__(self, db: Session): + def __init__(self, db: AsyncSession): self.db = db self.__scene_change_callbacks = [] # 用于存储回调函数 self.thread_pool = GlobalThreadPool() @@ -32,37 +34,39 @@ for callback in self.__scene_change_callbacks: self.thread_pool.executor.submit(callback, scene_id, change_type) - def get_scene_list(self, - name: Optional[str] = None, - remark: Optional[str] = None, - ) -> Sequence[Scene]: + async def get_scene_list(self, + name: Optional[str] = None, + remark: Optional[str] = None, + ) -> Sequence[Scene]: statement = self.scene_query(name, remark) - results = self.db.exec(statement) - return results.all() + results = await self.db.execute(statement) + return results.scalars().all() - def get_scene_page(self, - name: Optional[str] = None, - remark: Optional[str] = None, - offset: int = 0, - limit: int = 10 - ) -> Tuple[Sequence[SceneInfo], int]: + async def get_scene_page(self, + name: Optional[str] = None, + remark: Optional[str] = None, + offset: int = 0, + limit: int = 10 + ) -> Tuple[Sequence[SceneInfo], int]: statement = self.scene_query(name, remark) # 查询总记录数 total_statement = select(func.count()).select_from(statement.subquery()) - total = self.db.exec(total_statement).one() + total_result = await self.db.execute(total_statement) + total = total_result.scalar_one() # 添加分页限制 statement = statement.offset(offset).limit(limit) # 执行查询并返回结果 - scene_list = self.db.exec(statement) + scene_list = await self.db.execute(statement) + rows = scene_list.scalars().all() scene_info_list: List[SceneInfo] = [] if scene_list: - for scene in scene_list: + for scene in rows: scene_info_list.append(SceneInfo( **scene.dict(), - usage_status="使用中" if self.get_scene_usage(scene.id) else "未使用" + usage_status="使用中" if await self.get_scene_usage(scene.id) else "未使用" )) return scene_info_list, total # 返回分页数据和总数 @@ -76,7 +80,7 @@ statement = statement.where(Scene.remark.like(f"%{remark}%")) return statement - def process_zip(self, file: UploadFile): + async def process_zip(self, file: UploadFile): model_dir = Path('weights/') scene_handle_dir = Path('scene_handler/') model_dir.mkdir(parents=True, exist_ok=True) @@ -86,61 +90,65 @@ # 临时保存上传文件 temp_path = Path(f"temp_upload_{uuid.uuid4()}.zip") - with open(temp_path, "wb") as temp_file: - temp_file.write(file.file.read()) + async with aiofiles.open(temp_path, "wb") as temp_file: + content = await file.read() # 异步读取上传文件内容 + await temp_file.write(content) model_file_paths = [] handle_file_path = None try: + # 使用异步方法读取压缩文件的内容 + loop = asyncio.get_event_loop() with zipfile.ZipFile(temp_path, 'r') as zip_ref: # 获取压缩包文件列表 file_list = zip_ref.namelist() model_files = [f for f in file_list if Path(f).suffix in SUPPORTED_MODEL_EXTENSIONS] - # 解压所有模型文件到模型目录 + # 异步解压所有模型文件到模型目录 for model_file in model_files: - zip_ref.extract(model_file, model_dir) + await loop.run_in_executor(None, zip_ref.extract, model_file, model_dir) model_file_paths.append(model_dir / model_file) # 检查是否有可选的 Python 脚本 handle_file = next((f for f in file_list if f.endswith(".py")), None) if handle_file: - zip_ref.extract(handle_file, scene_handle_dir) + await loop.run_in_executor(None, zip_ref.extract, handle_file, scene_handle_dir) handle_file_path = scene_handle_dir / handle_file else: raise BizException( status_code=400, - message=f"handle file (.py) is required in the zip." + message=f"Handle file (.py) is required in the zip." ) except zipfile.BadZipFile: raise BizException(status_code=400, message="Invalid zip file.") finally: # 删除临时文件 - temp_path.unlink() + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, temp_path.unlink) return [str(path) for path in model_file_paths], str(handle_file_path) - def process_scene_file(self, file, scene): - model_file_paths, handle_file_path = self.process_zip(file) + async def process_scene_file(self, file, scene): + model_file_paths, handle_file_path = await self.process_zip(file) scene.handle_task = snake_to_camel(os.path.splitext(os.path.basename(handle_file_path))[0]) - def create_scene(self, scene_data: SceneCreate, file: UploadFile): - self.process_scene_file(file, scene_data) + async def create_scene(self, scene_data: SceneCreate, file: UploadFile): + await self.process_scene_file(file, scene_data) scene = Scene.model_validate(scene_data) scene.create_time = datetime.now() scene.update_time = datetime.now() self.db.add(scene) - self.db.commit() - self.db.refresh(scene) + await self.db.commit() + await self.db.refresh(scene) return scene - def update_scene(self, scene_data: SceneUpdate, file: UploadFile): - scene = self.db.get(Scene, scene_data.id) + async def update_scene(self, scene_data: SceneUpdate, file: UploadFile): + scene = await self.get_scene_by_id(scene_data.id) if not scene: return None @@ -150,16 +158,16 @@ scene.update_time = datetime.now() if file: - self.process_scene_file(file, scene) + await self.process_scene_file(file, scene) self.db.add(scene) - self.db.commit() - self.db.refresh(scene) + await self.db.commit() + await self.db.refresh(scene) self.notify_change(scene.id, NotifyChangeType.SCENE_UPDATE) return scene - def delete_scene(self, scene_id: int): - scene = self.db.get(Scene, scene_id) + async def delete_scene(self, scene_id: int): + scene = await self.get_scene_by_id(scene_id) if not scene: return None # 查询 device_scene_relation 中是否存在启用的绑定关系 @@ -167,35 +175,40 @@ select(DeviceSceneRelation) .where(DeviceSceneRelation.scene_id == scene_id) ) - relation_in_use = self.db.exec(statement).first() + relation_in_use_exec = self.db.execute(statement) + relation_in_use = relation_in_use_exec.scalars().first() # 如果存在启用的绑定关系,提示无法删除 if relation_in_use: raise BizException(message=f"场景 {scene.name} 正在被设备使用,无法删除") - self.db.delete(scene) - self.db.commit() + statement = delete(Scene).where(Scene.id == scene_id) + await self.db.execute(statement) + await self.db.commit() return scene - def get_scenes_in_use(self) -> Sequence[Scene]: + async def get_scenes_in_use(self) -> Sequence[Scene]: """获取所有在 device_scene_relation 表里有绑定关系的模型信息""" statement = ( select(Scene) .join(DeviceSceneRelation, DeviceSceneRelation.scene_id == Scene.id) .group_by(Scene.id) ) - results = self.db.exec(statement).all() - return results + results = await self.db.execute(statement) + return results.scalars().all() - def get_scene_usage(self, scene_id) -> bool: + async def get_scene_usage(self, scene_id) -> bool: statement = ( select(DeviceSceneRelation) .where( DeviceSceneRelation.scene_id == scene_id, ) ) - result = self.db.exec(statement).all() - return len(result) > 0 + result = await self.db.exec(statement) + rows = result.all() + return len(rows) > 0 - def get_scene_by_id(self, scene_id): - return self.db.get(Scene, scene_id) \ No newline at end of file + async def get_scene_by_id(self, scene_id): + result = await self.db.execute(select(Scene).where(Scene.id == scene_id)) + scene = result.scalar_one_or_none() + return scene diff --git a/tcp/tcp_client_connector.py b/tcp/tcp_client_connector.py index 3d6d076..d645340 100644 --- a/tcp/tcp_client_connector.py +++ b/tcp/tcp_client_connector.py @@ -113,29 +113,29 @@ logger.error(f"Error during query for {self.ip}:{self.port}: {e}") await self.reconnect() - def parse_response(self, data): + async def parse_response(self, data): """解析设备返回的数据""" logger.info(f"Received data from {self.ip}:{self.port}: {format_bytes(data)}") try: res = parse_gas_data(data) logger.info(res) - with next(get_db()) as db: + async for db in get_db(): data_gas_service = DataGasService(db) data_gas = DataGas( device_code=res['device_code'], gas_value=res['gas_value'] ) - data_gas_service.add_data_gas(data_gas) - self.gas_push(data_gas) + await data_gas_service.add_data_gas(data_gas) + await self.gas_push(data_gas) except Exception as e: logger.error(f"Parse and save gas data failed: {e}") - def gas_push(self, data_gas): + async def gas_push(self, data_gas): global_config = GlobalConfig() gas_push_config = global_config.get_gas_push_config() - if gas_push_config: + if gas_push_config and gas_push_config.push_url: last_ts = self.push_ts_dict.get(data_gas.device_code) current_time = datetime.now() @@ -161,7 +161,7 @@ data = await asyncio.wait_for(self.reader.read(1024), timeout=self.timeout) # if not data: # raise ConnectionResetError("Connection lost or no data received") - self.parse_response(data) + await self.parse_response(data) return data # 返回响应数据 else: return None diff --git a/tcp/tcp_manager.py b/tcp/tcp_manager.py index f28446d..094aa6a 100644 --- a/tcp/tcp_manager.py +++ b/tcp/tcp_manager.py @@ -3,10 +3,12 @@ from common.consts import DEVICE_TYPE, NotifyChangeType +from db.database import get_db from entity.device import Device from services.device_service import DeviceService from common.global_logger import logger +from services.global_config import GlobalConfig from tcp.tcp_client_connector import TcpClientConnector @@ -22,7 +24,7 @@ async def load_and_connect_devices(self): """从数据库加载设备并连接所有设备""" - devices = self.device_service.get_device_list(device_type=DEVICE_TYPE.TREE) # 使用局部变量 + devices = await self.device_service.get_device_list(device_type=DEVICE_TYPE.TREE) # 使用局部变量 logger.info(f"get {len(devices)} tree devices") for device in devices: await self.start_device_connect(device) @@ -47,14 +49,14 @@ async def restart_device_thread(self, device_id): await self.stop_device_connect(device_id) - device = self.device_service.get_device(device_id) + device = await self.device_service.get_device(device_id) await self.start_device_connect(device) async def on_device_change(self, device_id, change_type): """设备变化时的回调处理""" if change_type == NotifyChangeType.DEVICE_CREATE: # 新增设备,加载新设备并连接 - new_device = self.device_service.get_device(device_id) + new_device = await self.device_service.get_device(device_id) await self.start_device_connect(new_device) elif change_type == NotifyChangeType.DEVICE_DELETE: @@ -70,7 +72,7 @@ async def send_message_to_device(self, device_id, message: bytes, have_response): if device_id not in self.connector_map: - device = self.device_service.get_device(device_id) + device = await self.device_service.get_device(device_id) await self.start_device_connect(device) connector = self.connector_map[device_id] if connector: @@ -79,5 +81,9 @@ if __name__ == '__main__': - tcp_manager = TcpManager() - asyncio.run(tcp_manager.start()) + async for db in get_db(): + global_config = GlobalConfig() + await global_config.init_config() + device_service = DeviceService(db) + tcp_manager = TcpManager(device_service) + asyncio.run(tcp_manager.start())