Newer
Older
lynxi-casic-demo / download_tool.py
zhangyingjie on 24 Jan 6 KB 增加后台接口调用
from datetime import datetime
import hashlib
import os
import requests
import shutil
import zipfile

from constants import FILE_MD5_URI, SERVER_BASE_URL
from global_logger import logger

# Root directory under which downloaded files and their ".meta" sidecar files are stored.
download_path = './downloads/'
# Directory that receives the folders extracted from a downloaded archive.
weight_path = './models'
# Directory that receives extracted .py files when the archive type is 'model'.
model_task_path = './model_handler'
# Directory that receives extracted .py files for every other archive type.
scene_task_path = './scene_handler'

def _file_url(file_path):
    """Return the download URL of ``file_path`` below the server's ``/static/`` root."""
    return '{}/static/{}'.format(SERVER_BASE_URL, file_path)

def _file_md5_url(file_path):
    """Return the URL of the server endpoint that reports a file's MD5.

    ``file_path`` does not take part in the URL; the caller in this module
    sends it separately as the ``path`` query parameter.
    """
    return '{}/{}'.format(SERVER_BASE_URL, FILE_MD5_URI)

def _calculate_md5(file_path, chunk_size=8192):
    """Return the hexadecimal MD5 digest of a file.

    The file is read in blocks of ``chunk_size`` bytes so that large files
    are never loaded into memory in one piece.
    """
    digest = hashlib.md5()
    with open(file_path, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                # An empty read marks the end of the file.
                break
            digest.update(block)
    return digest.hexdigest()

def _download_file(file_path):
    """Download ``file_path`` from the static file server into ``download_path``.

    The body is streamed in 8 KiB chunks to ``download_path/file_path``;
    missing parent directories are created first.

    Args:
        file_path: Path of the file relative to the server's ``/static/``
            root. The same relative path is used below ``download_path``.

    Returns:
        The ``requests.Response`` of the GET request. Its body has already
        been consumed and the response has been closed.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    path = os.path.join(download_path, file_path)
    # The context manager closes the streamed response even when writing fails.
    with requests.get(_file_url(file_path), stream=True) as response:
        # Fail before the local file is opened: otherwise the body of an
        # error page would overwrite an existing, valid download.
        response.raise_for_status()
        os.makedirs(os.path.dirname(path), exist_ok=True)

        with open(path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return response

def _get_server_md5(file_path):
    """Ask the server for the MD5 of ``file_path``.

    Sends a GET request to the MD5 endpoint with ``file_path`` as the
    ``path`` query parameter and expects a JSON object whose ``code`` field
    is 200.

    Args:
        file_path: Server-side path of the file to look up.

    Returns:
        The value of the ``data`` field of the JSON answer (presumably the
        hex MD5 string -- confirm against the server API), or ``None`` when
        the HTTP status is not 200, the body is not valid JSON, or the
        answer does not report ``code == 200``. Every failure is logged.
    """
    url = _file_md5_url(file_path)
    params = {'path': file_path}
    logger.debug(f'GET: {url}, {params}')
    response = requests.get(url=url, params=params)
    if response.status_code != 200:
        logger.error(f"get server md5 error: status_code: {response.status_code}, response: {response.text}")
        return None

    try:
        data = response.json()
    except ValueError:
        # A 200 answer with a non-JSON body (e.g. an HTML error page) is
        # reported like any other failure instead of crashing the caller.
        logger.error(f"get server md5 error: response is not valid JSON: {response.text}")
        return None

    # Guard against answers that are not objects or lack the expected keys.
    if not isinstance(data, dict) or data.get('code') != 200:
        logger.error(f'get server md5 return error: {data}')
        return None
    return data.get('data')


def _write_meta_file(meta_file, filename, version, size, md5):
    """Write the ``.meta`` sidecar file that describes a downloaded file.

    The file holds one ``key: value`` line each for filename, version, size
    and md5, followed by a ``last_modified`` line with the current local
    time. An existing file at ``meta_file`` is overwritten.
    """
    with open(meta_file, 'w') as out:
        out.writelines([
            f"filename: {filename}\n",
            f"version: {version}\n",
            f"size: {size}\n",
            f"md5: {md5}\n",
            f"last_modified: {datetime.now()}\n",
        ])

def _read_meta_file(meta_file):
    """Parse a ``.meta`` sidecar file into a dict of stripped strings.

    Returns ``None`` when ``meta_file`` does not exist. Blank lines are
    skipped. A line containing ``": "`` is split at its first occurrence
    into key and value; any other line becomes a key (with every ``:``
    removed) that maps to an empty string.
    """
    if not os.path.exists(meta_file):
        return None

    parsed = {}
    with open(meta_file, 'r') as handle:
        for raw in handle:
            entry = raw.strip()
            if not entry:
                continue

            name, separator, rest = entry.partition(": ")
            if not separator:
                # No "key: value" separator: treat the line as a key whose value is empty.
                name, rest = name.replace(":", "").strip(), ""

            parsed[name.strip()] = rest.strip()
    return parsed

def _is_file_up_to_date(local_path, meta_path, server_version, server_size, server_md5):
    """Tell whether the local file matches what the server currently offers.

    Args:
        local_path: Path of the downloaded file.
        meta_path: Path of the file's ``.meta`` sidecar.
        server_version: Version the server reports; compared as a string.
        server_size: Size the server reports; compared as a string.
        server_md5: MD5 the server reports.

    Returns:
        ``True`` only if both files exist, the recorded version, size and
        MD5 equal the server's values, and the MD5 of the file on disk also
        equals the server's MD5. ``False`` otherwise.
    """
    if not os.path.exists(local_path) or not os.path.exists(meta_path):
        return False

    meta_data = _read_meta_file(meta_path)
    logger.debug(f'meta_data={meta_data}')
    if not meta_data:
        return False

    logger.debug(f'server: {server_version} {server_size} {server_md5}')
    logger.debug(f'meta: {meta_data.get("version")} {meta_data.get("size")} {meta_data.get("md5")}')
    return (
        # Values read from the .meta file are stripped strings, so the
        # server values are normalised the same way before comparing;
        # otherwise a non-string version could never match.
        meta_data.get("version") == str(server_version).strip() and
        meta_data.get("size") == str(server_size) and
        meta_data.get("md5") == server_md5 and
        # Hash the file last: it is the expensive check and is only needed
        # once the recorded metadata already matches. It detects a file that
        # was tampered with or changed by accident after the download.
        _calculate_md5(local_path) == server_md5
    )



def _extract_and_organize(zip_path, type):
    """Unpack a downloaded archive and move its contents into place.

    The archive is extracted into a ``temp_extract`` directory next to it.
    From the top level of the extracted tree:

    * directories are moved to ``weight_path/<archive name without extension>``.
      A copy left there by an earlier version is removed first, so an update
      replaces the old folder instead of being nested inside it. If the
      archive holds several top-level directories, the first becomes that
      target folder and the remaining ones are moved inside it.
    * ``.py`` files are moved to ``model_task_path`` when ``type`` is
      ``'model'`` and to ``scene_task_path`` otherwise.

    Everything else is discarded together with the temporary directory,
    which is removed even when extraction or moving fails.

    Args:
        zip_path: Path of the zip archive to unpack.
        type: Archive kind; only the value ``'model'`` is distinguished.
            The name shadows the builtin but is kept for existing callers.
    """
    folder_target = weight_path
    py_target = model_task_path if type == 'model' else scene_task_path
    zip_name = os.path.splitext(os.path.basename(zip_path))[0]
    logger.debug(f'zip_name={zip_name}, zip_path={zip_path}')

    os.makedirs(folder_target, exist_ok=True)
    os.makedirs(py_target, exist_ok=True)

    extract_dir = os.path.join(os.path.dirname(zip_path), "temp_extract")
    # Leftovers of an interrupted run must not be mixed into this archive.
    if os.path.isdir(extract_dir):
        shutil.rmtree(extract_dir)
    # Create it explicitly: extracting an empty archive creates nothing.
    os.makedirs(extract_dir, exist_ok=True)

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)

        target_folder = os.path.join(folder_target, zip_name)
        stale_target_checked = False
        # Only top-level entries are relocated. Listing them once avoids
        # moving directories out from under a running os.walk().
        for entry in sorted(os.listdir(extract_dir)):
            source = os.path.join(extract_dir, entry)
            if os.path.isdir(source):
                if not stale_target_checked:
                    # shutil.move() puts the source *inside* an existing
                    # destination directory, so drop the previous version.
                    if os.path.isdir(target_folder):
                        shutil.rmtree(target_folder)
                    stale_target_checked = True
                logger.info(f"移动文件夹: {source} 到 {target_folder}")
                shutil.move(source, target_folder)
            elif entry.endswith(".py"):
                target_file = os.path.join(py_target, entry)
                logger.info(f"移动 Python 文件: {source} 到 {target_file}")
                shutil.move(source, target_file)
    finally:
        # Always remove the temporary extraction directory.
        shutil.rmtree(extract_dir, ignore_errors=True)
    logger.info("文件整理完成!")

def check_file(file_path, file_version, type):
    """Make sure the local copy of a server file is current; fetch and unpack it if not.

    The server is asked for the file's size (HEAD request). If a local copy
    exists, its ``.meta`` record and MD5 are compared with the server's
    values and nothing else happens when they match. Otherwise the file is
    downloaded, its ``.meta`` sidecar is rewritten and the archive is
    extracted into the target directories.

    Args:
        file_path: Path of the file relative to the server's static root;
            also used as the relative path below ``download_path``.
        file_version: Version expected by the caller, stored in and compared
            with the ``.meta`` record.
        type: Archive kind passed on to the extraction step (``'model'`` or
            anything else). The name shadows the builtin but is kept for
            existing callers.

    Raises:
        requests.HTTPError: If the HEAD request is answered with a 4xx/5xx
            status, i.e. the file is not available on the server.
    """
    logger.info(f'checking file file_path={file_path} file_version={file_version}')
    filename = file_path
    local_path = os.path.join(download_path, filename)
    meta_path = local_path + ".meta"
    logger.debug(f'filename={filename}')
    logger.debug(f'local_path={local_path}')
    logger.debug(f'meta_path={meta_path}')

    # Fetch the server-side file information.
    response = requests.head(_file_url(file_path))
    # Stop here when the server does not have the file: continuing would
    # record size 0 and replace a good local copy with an error page.
    response.raise_for_status()
    server_size = int(response.headers.get('Content-Length', 0))

    # Skip the download when the local file exists and is already current.
    if os.path.exists(local_path):
        logger.info(f'{local_path}已存在')
        server_md5 = _get_server_md5(file_path)
        if _is_file_up_to_date(local_path, meta_path, file_version, server_size, server_md5):
            logger.info(f"{filename} 已经是最新版本,无需下载")
            return

    # Download the new file.
    logger.info(f"正在下载 {filename}...")
    _download_file(file_path)

    # Refresh the .meta sidecar with the MD5 of what was actually written.
    local_md5 = _calculate_md5(local_path)
    _write_meta_file(meta_path, filename, file_version, server_size, local_md5)
    logger.info(f"{filename} 下载完成并更新元数据")

    # Unpack the archive and move its contents to their target locations.
    _extract_and_organize(local_path, type)