Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,956 Bytes
506ecd3 5fdd10a 506ecd3 5fdd10a 506ecd3 5fdd10a ff769a6 a6fb107 839c1dd 506ecd3 839c1dd 506ecd3 839c1dd 506ecd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import requests
import zipfile
import os
import argparse
def download_file_from_google_drive(file_id, destination):
    """Download a shared Google Drive file, identified by its ID, to disk.

    The file is first requested without confirmation. When the response
    carries a confirmation token (as detected by ``get_confirm_token``), the
    request is repeated with that token as the ``confirm`` parameter. The body
    of the final response is streamed to ``destination``.

    Args:
        file_id (str): ID of the Google Drive file.
        destination (str): Local path the downloaded bytes are written to.

    Raises:
        requests.HTTPError: If the final response has a 4xx/5xx status code.
    """
    # Base download endpoint; the file is selected through the ``id`` parameter.
    URL = "https://docs.google.com/uc?export=download"

    # The context manager closes the session (and its pooled connections)
    # even if the download fails part-way through.
    with requests.Session() as session:
        # Initial GET request.
        response = session.get(URL, params={'id': file_id}, stream=True)
        token = get_confirm_token(response)

        if token:
            # The first streamed response is never read; close it so its
            # connection is released before the confirmed request is sent.
            response.close()
            params = {'id': file_id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        # Fail loudly instead of saving an HTTP error page as the file.
        response.raise_for_status()

        # Stream the response body to the destination file.
        save_response_content(response, destination)
def get_confirm_token(response):
    """Return the download-confirmation token found in the response cookies.

    Args:
        response (requests.Response): Response whose cookies are inspected.

    Returns:
        str: Value of the first cookie whose name starts with
        ``download_warning``, or ``None`` when no such cookie is present.
    """
    matching_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    # ``next`` with a default yields the first match, or None if there is none.
    return next(matching_values, None)
def save_response_content(response, destination, chunk_size=32768):
    """Stream a response body into a file, chunk by chunk.

    Args:
        response (requests.Response): Streaming response to read from.
        destination (str): Local path to write; opened in binary write mode.
        chunk_size (int, optional): Size passed to ``response.iter_content``
            for each chunk. Defaults to 32768.
    """
    with open(destination, "wb") as out_file:
        # Drop empty (keep-alive) chunks; write every other chunk in order.
        non_empty_chunks = filter(None, response.iter_content(chunk_size))
        out_file.writelines(non_empty_chunks)
def download_model_from_modelscope(destination, hf_cache_dir):
    """Download the IndexTTS-2 checkpoints and auxiliary models from ModelScope.

    Args:
        destination (str): Directory that receives the ``IndexTeam/IndexTTS-2``
            snapshot.
        hf_cache_dir (str): Directory under which one ``models--<org>--<name>``
            sub-directory per auxiliary model is created and filled.
    """
    print(f"[ModelScope] Downloading models to {destination},model cache dir={hf_cache_dir}")
    # Imported lazily so that modelscope is only required when this source is
    # actually selected.
    from modelscope import snapshot_download

    maskgct_dir = os.path.join(hf_cache_dir, "models--amphion--MaskGCT")
    w2v_bert_dir = os.path.join(hf_cache_dir, "models--facebook--w2v-bert-2.0")
    bigvgan_dir = os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x")
    campplus_dir = os.path.join(hf_cache_dir, "models--funasr--campplus")
    for cache_dir in (maskgct_dir, w2v_bert_dir, bigvgan_dir, campplus_dir):
        os.makedirs(cache_dir, exist_ok=True)

    # Honor the caller-supplied locations instead of hard-coded "checkpoints"
    # paths, so the directories created above are the ones that get filled.
    snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
    snapshot_download("amphion/MaskGCT", local_dir=maskgct_dir)
    snapshot_download("facebook/w2v-bert-2.0", local_dir=w2v_bert_dir)
    snapshot_download("nv-community/bigvgan_v2_22khz_80band_256x", local_dir=bigvgan_dir)
    # TODO(review): campplus is not downloaded from ModelScope. The previous
    # code repeated the bigvgan download here (copy-paste), which left
    # ``campplus_dir`` empty; the ModelScope repo id for campplus is not known
    # from this file and must be confirmed before adding the call.
def download_model_from_huggingface(destination,hf_cache_dir):
    """Download the IndexTTS-2 checkpoints and auxiliary models from HuggingFace.

    Args:
        destination (str): Directory that receives the ``IndexTeam/IndexTTS-2``
            snapshot.
        hf_cache_dir (str): Directory under which one ``models--<org>--<name>``
            sub-directory per auxiliary model is created.
    """
    print(f"[HuggingFace] Downloading models to {destination},model cache dir={hf_cache_dir}")
    # Imported lazily so that huggingface_hub is only required when this source
    # is actually selected.
    from huggingface_hub import snapshot_download

    # Create every per-model cache directory up front, including those that
    # receive no download below.
    cache_dir_names = (
        "models--amphion--MaskGCT",
        "models--facebook--w2v-bert-2.0",
        "models--nvidia--bigvgan_v2_22khz_80band_256x",
        "models--funasr--campplus",
    )
    for dir_name in cache_dir_names:
        os.makedirs(os.path.join(hf_cache_dir, dir_name), exist_ok=True)

    snapshot_download("IndexTeam/IndexTTS-2", local_dir=destination)
    print("[HuggingFace] IndexTTS-2 Download finished")

    # NOTE: amphion/MaskGCT is not downloaded from this source.
    # NOTE: w2v-bert-2.0 is fetched without ``local_dir``, so huggingface_hub
    # decides where the snapshot is stored rather than ``hf_cache_dir``.
    snapshot_download("facebook/w2v-bert-2.0")
    print("[HuggingFace] w2v-bert-2.0 Download finished")

    bigvgan_dir = os.path.join(hf_cache_dir, "models--nvidia--bigvgan_v2_22khz_80band_256x")
    snapshot_download("nvidia/bigvgan_v2_22khz_80band_256x", local_dir=bigvgan_dir)
    print("[HuggingFace] bigvgan_v2_22khz_80band_256x Download finished")

    campplus_dir = os.path.join(hf_cache_dir, "models--funasr--campplus")
    snapshot_download("funasr/campplus", local_dir=campplus_dir)
    print("[HuggingFace] campplus Download finished")
# Usage example
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="下载文件和模型工具")
    cli.add_argument(
        '--model_source',
        choices=['modelscope', 'huggingface'],
        default=None,
        help='模型下载来源',
    )
    options = cli.parse_args()

    # Map each accepted --model_source value to its downloader; when the
    # option is omitted no model download happens.
    model_downloaders = {
        'modelscope': download_model_from_modelscope,
        'huggingface': download_model_from_huggingface,
    }
    download_models = model_downloaders.get(options.model_source)
    if download_models is not None:
        download_models("checkpoints", os.path.join("checkpoints", "hf_cache"))

    print("Downloading example files from Google Drive...")
    example_file_id = "1o_dCMzwjaA2azbGOxAE7-4E7NbJkgdgO"
    archive_path = "example_wavs.zip"  # replace with the local path you want
    download_file_from_google_drive(example_file_id, archive_path)
    print(f"File downloaded to: {archive_path}")

    # Unpack the downloaded zip archive into the examples directory.
    extract_dir = "examples"
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(extract_dir)
    print(f"File extracted to: {extract_dir}")
|