File size: 5,956 Bytes
506ecd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fdd10a
506ecd3
 
 
 
 
 
 
 
5fdd10a
 
 
 
506ecd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fdd10a
 
 
 
ff769a6
 
a6fb107
 
 
 
839c1dd
506ecd3
839c1dd
506ecd3
839c1dd
506ecd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import requests
import zipfile
import os
import argparse

def download_file_from_google_drive(file_id, destination):
    """Download a shared Google Drive file to a local path.

    Large files trigger Google's virus-scan warning page; in that case a
    confirmation token is extracted from the first response's cookies and
    the request is retried with it.

    Args:
        file_id (str): ID of the Google Drive file to fetch.
        destination (str): Local path the file is written to.
    """
    URL = "https://docs.google.com/uc?export=download"
    session = requests.Session()

    # First attempt; small files come straight back from this request.
    response = session.get(URL, params={'id': file_id}, stream=True)

    # Retry with the confirmation token when the warning page intervened.
    confirm = get_confirm_token(response)
    if confirm:
        response = session.get(
            URL, params={'id': file_id, 'confirm': confirm}, stream=True
        )

    save_response_content(response, destination)

def get_confirm_token(response):
    """Return Google Drive's download-confirmation token, if present.

    When a file is too large to virus-scan, Google sets a cookie whose
    name starts with 'download_warning'; its value is the token needed
    to confirm the download.

    Args:
        response (requests.Response): Response to the initial request.

    Returns:
        str | None: The token value, or None when no confirmation is needed.
    """
    return next(
        (
            value
            for name, value in response.cookies.items()
            if name.startswith('download_warning')
        ),
        None,
    )

def save_response_content(response, destination, chunk_size=32768):
    """Stream a response body to a file, chunk by chunk.

    Writing in chunks keeps memory flat, so arbitrarily large downloads
    are safe.

    Args:
        response (requests.Response): Streaming response object.
        destination (str): Local path to write to.
        chunk_size (int, optional): Bytes written per iteration.
            Defaults to 32768.
    """
    with open(destination, "wb") as out:
        for block in response.iter_content(chunk_size):
            # Empty chunks are keep-alive fillers; skip them.
            if not block:
                continue
            out.write(block)

def download_model_from_modelscope(destination, hf_cache_dir):
    """Download IndexTTS-2 and its dependency models from ModelScope.

    Dependency models are placed into a directory layout that mimics the
    HuggingFace cache (``models--<org>--<name>`` subdirectories).

    Args:
        destination (str): Directory for the main IndexTTS-2 checkpoint.
        hf_cache_dir (str): Root of the HuggingFace-style cache directory
            for the dependency models.
    """
    print(f"[ModelScope] Downloading models to {destination},model cache dir={hf_cache_dir}")
    # Local import: modelscope is an optional dependency, only needed here.
    from modelscope import snapshot_download

    # Pre-create the HF-cache-style directories for the dependency models.
    for sub in (
        "models--amphion--MaskGCT",
        "models--facebook--w2v-bert-2.0",
        "models--nvidia--bigvgan_v2_22khz_80band_256x",
        "models--funasr--campplus",
    ):
        os.makedirs(os.path.join(hf_cache_dir, sub), exist_ok=True)

    # Fix: use the function's parameters instead of the hard-coded
    # "checkpoints/..." paths, so callers control where files land
    # (matches the HuggingFace sibling function's behavior).
    snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
    snapshot_download(
        "amphion/MaskGCT",
        local_dir=os.path.join(hf_cache_dir, "models--amphion--MaskGCT"),
    )
    snapshot_download(
        "facebook/w2v-bert-2.0",
        local_dir=os.path.join(hf_cache_dir, "models--facebook--w2v-bert-2.0"),
    )
    snapshot_download(
        "nv-community/bigvgan_v2_22khz_80band_256x",
        local_dir=os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"),
    )
    # Bug fix: the original repeated the bigvgan download here instead of
    # fetching campplus, leaving the campplus directory created above empty.
    # NOTE(review): "funasr/campplus" mirrors the HuggingFace repo id —
    # confirm the model is published under the same id on ModelScope.
    snapshot_download(
        "funasr/campplus",
        local_dir=os.path.join(hf_cache_dir, "models--funasr--campplus"),
    )

def download_model_from_huggingface(destination, hf_cache_dir):
    """Download IndexTTS-2 and its dependency models from HuggingFace.

    Dependency models are placed into a HuggingFace-cache-style layout
    (``models--<org>--<name>`` subdirectories) under *hf_cache_dir*.

    Args:
        destination (str): Directory for the main IndexTTS-2 checkpoint.
        hf_cache_dir (str): Root of the cache directory for dependency models.
    """
    print(f"[HuggingFace] Downloading models to {destination},model cache dir={hf_cache_dir}")
    # Local import: huggingface_hub is only needed when this path is taken.
    from huggingface_hub import snapshot_download

    # Pre-create the cache-style directories for the dependency models.
    for sub in (
        "models--amphion--MaskGCT",
        "models--facebook--w2v-bert-2.0",
        "models--nvidia--bigvgan_v2_22khz_80band_256x",
        "models--funasr--campplus",
    ):
        os.makedirs(os.path.join(hf_cache_dir, sub), exist_ok=True)

    # (repo_id, local_dir, progress label). A local_dir of None means the
    # snapshot goes to the default HF cache — w2v-bert-2.0 is deliberately
    # fetched that way in the original.
    downloads = (
        ("IndexTeam/IndexTTS-2", destination, "IndexTTS-2"),
        ("facebook/w2v-bert-2.0", None, "w2v-bert-2.0"),
        (
            "nvidia/bigvgan_v2_22khz_80band_256x",
            os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x"),
            "bigvgan_v2_22khz_80band_256x",
        ),
        (
            "funasr/campplus",
            os.path.join(hf_cache_dir, "models--funasr--campplus"),
            "campplus",
        ),
    )
    for repo_id, local_dir, label in downloads:
        if local_dir is None:
            snapshot_download(repo_id)
        else:
            snapshot_download(repo_id, local_dir=local_dir)
        print(f"[HuggingFace] {label} Download finished")

# Script entry point: optionally pull model weights, then fetch example audio.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="下载文件和模型工具")
    parser.add_argument(
        '--model_source',
        choices=['modelscope', 'huggingface'],
        default=None,
        help='模型下载来源',
    )
    args = parser.parse_args()

    # Model downloads are opt-in; skipped entirely when no source is given.
    hf_cache = os.path.join("checkpoints", "hf_cache")
    if args.model_source == 'modelscope':
        download_model_from_modelscope("checkpoints", hf_cache)
    elif args.model_source == 'huggingface':
        download_model_from_huggingface("checkpoints", hf_cache)

    # Example audio always gets fetched, regardless of model source.
    print("Downloading example files from Google Drive...")
    file_id = "1o_dCMzwjaA2azbGOxAE7-4E7NbJkgdgO"
    destination = "example_wavs.zip"
    download_file_from_google_drive(file_id, destination)
    print(f"File downloaded to: {destination}")

    # Unpack the archive into the examples directory.
    examples_dir = "examples"
    with zipfile.ZipFile(destination, 'r') as zip_ref:
        zip_ref.extractall(examples_dir)
    print(f"File extracted to: {examples_dir}")