File size: 2,110 Bytes
7f42fd2
 
 
 
 
 
a6e0c21
 
7f42fd2
 
a6e0c21
 
7f42fd2
 
 
 
 
 
 
 
 
 
a6e0c21
 
7f42fd2
 
 
a6e0c21
 
7f42fd2
 
 
 
a6e0c21
7f42fd2
 
 
 
 
 
 
 
 
 
a6e0c21
 
 
7f42fd2
 
 
 
a6e0c21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
from ui.components import create_main_demo_ui
from pipeline_ace_step import ACEStepPipeline
from data_sampler import DataSampler
import os

# 获取当前脚本的绝对路径,用于构建默认的存储路径
APP_ROOT = os.path.dirname(os.path.abspath(__file__))

parser = argparse.ArgumentParser()
# 将 checkpoint_path 的默认值改为应用程序根目录下的 'checkpoints' 文件夹
parser.add_argument("--checkpoint_path", type=str, default=os.path.join(APP_ROOT, "checkpoints"))
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action='store_true', default=False)
parser.add_argument("--bf16", action='store_true', default=True)
parser.add_argument("--torch_compile", type=bool, default=False)

args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)

# 将 persistent_storage_path 的默认值改为应用程序根目录下的 'persistent_data' 文件夹
persistent_storage_path = os.path.join(APP_ROOT, "persistent_data")


def main(args):
    print(f"Using checkpoint path: {args.checkpoint_path}")
    print(f"Using persistent storage path: {persistent_storage_path}")

    model_demo = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype="bfloat16" if args.bf16 else "float32",
        persistent_storage_path=persistent_storage_path, # 传递修改后的路径
        torch_compile=args.torch_compile
    )
    data_sampler = DataSampler()

    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
        sample_data_func=data_sampler.sample,
        load_data_func=data_sampler.load_json,
    )
    demo.queue(default_concurrency_limit=8).launch(
        server_name=args.server_name, # 添加这一行以使用命令行参数
        server_port=args.port,       # 添加这一行以使用命令行参数
        share=args.share             # 添加这一行以使用命令行参数
    )


if __name__ == "__main__":
    main(args)