# StableAvatar / app.py
# (Hugging Face page header: last commit 4599acf "change duration" by YinmingHuang)
import torch
import psutil
import argparse
import gradio as gr
import os
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import load_image
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from omegaconf import OmegaConf
from wan.models.cache_utils import get_teacache_coefficients
from wan.models.wan_fantasy_transformer3d_1B import WanTransformer3DFantasyModel
from wan.models.wan_text_encoder import WanT5EncoderModel
from wan.models.wan_vae import AutoencoderKLWan
from wan.models.wan_image_encoder import CLIPModel
from wan.pipeline.wan_inference_long_pipeline import WanI2VTalkingInferenceLongPipeline
from wan.utils.fp8_optimization import replace_parameters_by_name, convert_weight_dtype_wrapper, convert_model_weight_to_float8
from wan.utils.utils import get_image_to_video_latent, save_videos_grid
import numpy as np
import librosa
import datetime
import random
import math
import subprocess
from moviepy.editor import VideoFileClip
from huggingface_hub import snapshot_download
import shutil
import spaces
# Vocal separation is optional: if audio-separator (or one of its native dependencies)
# cannot be imported, keep the rest of the app usable and expose Separator as None so
# callers can check availability explicitly instead of hitting a NameError.
try:
    from audio_separator.separator import Separator
except Exception:
    Separator = None
    print("Unable to use vocal separation feature. Please install audio-separator[gpu].")
# Choose the compute device and a matching floating-point precision:
# CPU runs use float32; on CUDA, devices with compute capability major >= 8 get
# bfloat16 and older devices get float16.
if not torch.cuda.is_available():
    device = "cpu"
    dtype = torch.float32
else:
    device = "cuda"
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
def filter_kwargs(cls, kwargs):
    """Return the subset of *kwargs* whose keys are parameter names of ``cls.__init__``.

    ``self`` and ``cls`` are never forwarded, even if present in *kwargs*.
    """
    import inspect

    accepted = set(inspect.signature(cls.__init__).parameters) - {"self", "cls"}
    return {name: value for name, value in kwargs.items() if name in accepted}
def load_transformer_model(model_version):
    """Load the StableAvatar transformer checkpoint for *model_version* into ``transformer3d``.

    Args:
        model_version (str): "square" or "rec_vec". Any other value falls back to the
            "square" checkpoint.

    Returns:
        The module-level ``transformer3d`` after loading the checkpoint (non-strict), or
        ``None`` if the checkpoint file does not exist under ``repo_root``.
    """
    checkpoint_names = {
        "square": "transformer3d-square.pt",
        "rec_vec": "transformer3d-rec-vec.pt",
    }
    file_name = checkpoint_names.get(model_version, checkpoint_names["square"])
    transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", file_name)
    print(f"正在加载模型: {transformer_path}")
    if not os.path.exists(transformer_path):
        print(f"错误:模型文件不存在: {transformer_path}")
        return None

    state_dict = torch.load(transformer_path, map_location="cpu")
    # Some checkpoints wrap the weights under a "state_dict" key.
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    # strict=False: load matching keys into the existing module and report mismatches.
    missing, unexpected = transformer3d.load_state_dict(state_dict, strict=False)
    print(f"模型加载成功: {transformer_path}")
    print(f"Missing keys: {len(missing)}; Unexpected keys: {len(unexpected)}")
    return transformer3d
# Download (or reuse from the local HF cache) only the parts of the StableAvatar repo
# this app needs: the StableAvatar transformer weights, the Wan2.1 base model, the
# wav2vec2 audio encoder, assets, and the vocal-separation ONNX model.
REPO_ID = "FrancisRing/StableAvatar"
repo_root = snapshot_download(
    repo_id=REPO_ID,
    allow_patterns=[
        "StableAvatar-1.3B/*",
        "Wan2.1-Fun-V1.1-1.3B-InP/*",
        "wav2vec2-base-960h/*",
        "assets/**",
        "Kim_Vocal_2.onnx",
    ],
)
pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
# ONNX model used by the vocal-separation tab.
audio_separator_model_file = os.path.join(repo_root, "Kim_Vocal_2.onnx")
# Local-checkpoint alternative to the download above (unused):
# model_path = "/datadrive/stableavatar/checkpoints"
# pretrained_model_name_or_path = f"{model_path}/Wan2.1-Fun-V1.1-1.3B-InP"
# pretrained_wav2vec_path = f"{model_path}/wav2vec2-base-960h"
# transformer_path = f"{model_path}/StableAvatar-1.3B/transformer3d-square.pt"
# Model/scheduler configuration; the path is relative to the current working directory.
config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
sampler_name = "Flow"
# Requested frames per clip; generate() rounds this to the VAE's temporal compression.
clip_sample_n_frames = 81
def _model_subpath(section, key, default):
    """Resolve a component directory inside the base Wan2.1 checkpoint from *config*."""
    return os.path.join(pretrained_model_name_or_path, config[section].get(key, default))


# Text side: T5 tokenizer and encoder.
tokenizer = AutoTokenizer.from_pretrained(
    _model_subpath('text_encoder_kwargs', 'tokenizer_subpath', 'tokenizer')
)
text_encoder = WanT5EncoderModel.from_pretrained(
    _model_subpath('text_encoder_kwargs', 'text_encoder_subpath', 'text_encoder'),
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=dtype,
).eval()

# Video latent autoencoder.
vae = AutoencoderKLWan.from_pretrained(
    _model_subpath('vae_kwargs', 'vae_subpath', 'vae'),
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
)

# Audio side: wav2vec2 feature extractor and encoder (kept on CPU here).
wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")

# Reference-image encoder.
clip_image_encoder = CLIPModel.from_pretrained(
    _model_subpath('image_encoder_kwargs', 'image_encoder_subpath', 'image_encoder')
).eval()

# Base Wan2.1 transformer; StableAvatar weights are loaded on top of it afterwards.
transformer3d = WanTransformer3DFantasyModel.from_pretrained(
    _model_subpath('transformer_additional_kwargs', 'transformer_subpath', 'transformer'),
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
    low_cpu_mem_usage=False,
    torch_dtype=dtype,
)
# Load the "square" StableAvatar weights by default. If the checkpoint is missing, the
# transformer would keep only the base Wan2.1 weights loaded by from_pretrained above,
# so fail fast instead of serving a model without the StableAvatar weights.
if load_transformer_model("square") is None:
    raise FileNotFoundError(
        "StableAvatar transformer checkpoint (square) was not found in the downloaded repository."
    )
# Map sampler names to scheduler classes (only flow matching is wired up here).
scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
}
Choosen_Scheduler = scheduler_dict[sampler_name]
# Forward only the config entries the scheduler's __init__ accepts.
scheduler = Choosen_Scheduler(
    **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)
# Assemble the long-form talking-head image-to-video pipeline from the components above.
pipeline = WanI2VTalkingInferenceLongPipeline(
    # denoiser and sampling schedule
    transformer=transformer3d,
    scheduler=scheduler,
    # latent codec
    vae=vae,
    # text conditioning
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    # reference-image conditioning
    clip_image_encoder=clip_image_encoder,
    # audio conditioning
    wav2vec_processor=wav2vec_processor,
    wav2vec=wav2vec,
)
def _apply_gpu_memory_mode(GPU_memory_mode):
    """Place the shared pipeline on ``device`` according to *GPU_memory_mode*.

    Recognised modes: "sequential_cpu_offload", "model_cpu_offload_and_qfloat8",
    "model_cpu_offload". Any other value (e.g. "Normal") moves the whole pipeline to
    ``device``.
    """
    if GPU_memory_mode == "sequential_cpu_offload":
        # Modulation params and the freqs tensor are moved to the device explicitly;
        # presumably the offload hooks do not manage them. TODO confirm.
        replace_parameters_by_name(transformer3d, ["modulation", ], device=device)
        transformer3d.freqs = transformer3d.freqs.to(device=device)
        pipeline.enable_sequential_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        # float8 weight storage (modulation excluded), wrapped to compute in `dtype`.
        convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation", ])
        convert_weight_dtype_wrapper(transformer3d, dtype)
        pipeline.enable_model_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload":
        pipeline.enable_model_cpu_offload(device=device)
    else:
        pipeline.to(device=device)


def _configure_teacache(teacache_threshold, num_inference_steps, num_skip_start_steps):
    """Enable TeaCache on the shared transformer when the threshold is > 0, else disable it."""
    transformer = pipeline.transformer
    if teacache_threshold > 0:
        coefficients = get_teacache_coefficients(pretrained_model_name_or_path)
        transformer.enable_teacache(
            coefficients,
            num_inference_steps,
            teacache_threshold,
            num_skip_start_steps=num_skip_start_steps,
        )
    elif hasattr(transformer, "disable_teacache"):
        # The pipeline is shared across requests; presumably a TeaCache enabled by an
        # earlier request would otherwise stay active after the threshold is set to 0.
        transformer.disable_teacache()


def _mux_audio(video_file, audio_file, output_file):
    """Copy the video stream of *video_file* and add *audio_file* as AAC into *output_file*."""
    subprocess.run([
        "ffmpeg", "-y", "-loglevel", "quiet", "-i", video_file, "-i", audio_file,
        "-c:v", "copy", "-c:a", "aac", "-strict", "experimental",
        output_file
    ], check=True)


@spaces.GPU(duration=120)
def generate(
    GPU_memory_mode,
    teacache_threshold,
    num_skip_start_steps,
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    text_guide_scale,
    audio_guide_scale,
    motion_frame,
    fps,
    overlap_window_length,
    seed_param,
    overlapping_weight_scheme,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a talking-avatar video from a reference image and an audio file.

    The result is written to ``outputs/<timestamp>.mp4`` and muxed with the input audio
    into ``outputs/<timestamp>_audio.mp4``.

    Args:
        seed_param: Seed; negative or empty means "pick a random seed".
        (other args): UI values forwarded to the memory-mode setup, TeaCache and the pipeline.

    Returns:
        (path of the muxed video, seed actually used, status message).

    Raises:
        gr.Error: if the image or audio input is missing.
        subprocess.CalledProcessError: if ffmpeg fails to mux the audio.
    """
    if not image_path:
        raise gr.Error("Please upload a reference image / 请上传参考图片")
    if not audio_path:
        raise gr.Error("Please upload an audio file / 请上传音频")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # gr.Number delivers a float (or None when cleared), but
    # torch.Generator.manual_seed requires an int.
    if seed_param is None or seed_param < 0:
        seed = random.randint(0, np.iinfo(np.int32).max)
    else:
        seed = int(seed_param)

    _apply_gpu_memory_mode(GPU_memory_mode)
    _configure_teacache(teacache_threshold, num_inference_steps, num_skip_start_steps)

    with torch.no_grad():
        # Round (frames - 1) down to a multiple of the VAE's temporal compression, plus 1.
        ratio = vae.config.temporal_compression_ratio
        video_length = int((clip_sample_n_frames - 1) // ratio * ratio) + 1 if clip_sample_n_frames != 1 else 1
        input_video, input_video_mask, clip_image = get_image_to_video_latent(
            image_path, None, video_length=video_length, sample_size=[height, width]
        )
        sr = 16000
        vocal_input, _ = librosa.load(audio_path, sr=sr)
        sample = pipeline(
            prompt,
            num_frames=video_length,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
            num_inference_steps=num_inference_steps,
            video=input_video,
            mask_video=input_video_mask,
            clip_image=clip_image,
            text_guide_scale=text_guide_scale,
            audio_guide_scale=audio_guide_scale,
            vocal_input_values=vocal_input,
            motion_frame=motion_frame,
            fps=fps,
            sr=sr,
            cond_file_path=image_path,
            overlap_window_length=overlap_window_length,
            seed=seed,
            overlapping_weight_scheme=overlapping_weight_scheme,
        ).videos

    os.makedirs("outputs", exist_ok=True)
    video_path = os.path.join("outputs", f"{timestamp}.mp4")
    save_videos_grid(sample, video_path, fps=fps)
    output_video_with_audio = os.path.join("outputs", f"{timestamp}_audio.mp4")
    _mux_audio(video_path, audio_path, output_video_with_audio)
    return output_video_with_audio, seed, f"Generated outputs/{timestamp}.mp4 / 已生成outputs/{timestamp}.mp4"
def exchange_width_height(width, height):
    """Swap the two dimensions and return them with a bilingual status message."""
    swapped = (height, width)
    return (*swapped, "✅ Width and Height Swapped / 宽高交换完毕")
def adjust_width_height(image):
    """Pick a width/height with the image's aspect ratio and roughly 512*512 pixels.

    Both sides are rounded down to a multiple of 16, with a floor of 16 so extreme
    aspect ratios never produce a zero-sized dimension.

    Returns:
        (width, height, status message).

    Raises:
        gr.Error: if no image has been uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first / 请先上传图片")
    width, height = load_image(image).size
    default_area = 512 * 512
    # Scale factor that maps the original area onto default_area.
    ratio = math.sqrt(width * height / default_area)
    width = max(16, int(width / ratio // 16 * 16))
    height = max(16, int(height / ratio // 16 * 16))
    return width, height, "✅ Adjusted Size Based on Image / 根据图片调整宽高"
def audio_extractor(video_path):
    """Extract the audio track of *video_path* into ``outputs/<timestamp>.wav`` (16-bit PCM).

    Returns:
        (wav path for the audio player, status message, wav path for the download widget).

    Raises:
        gr.Error: if no video was uploaded or the video has no audio track.
    """
    if not video_path:
        raise gr.Error("Please upload a video first / 请先上传视频")
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)
    out_wav = os.path.abspath(os.path.join("outputs", f"{timestamp}.wav"))
    video = VideoFileClip(video_path)
    try:
        if video.audio is None:
            raise gr.Error("The uploaded video has no audio track / 上传的视频没有音轨")
        video.audio.write_audiofile(out_wav, codec="pcm_s16le")
    finally:
        # Release the ffmpeg reader processes held by the clip.
        video.close()
    return out_wav, f"Generated {out_wav} / 已生成 {out_wav}", out_wav
def vocal_separation(audio_path):
    """Isolate the vocal stem of *audio_path* into ``outputs/<timestamp>.wav``.

    Uses the Kim_Vocal_2 ONNX model via audio-separator.

    Returns:
        (wav path for the audio player, status message, wav path for the download widget).

    Raises:
        gr.Error: if audio-separator is unavailable or no audio was uploaded.
        RuntimeError: if the separation model fails to load or produces no output.
    """
    # Separator is None/undefined when the optional import at the top of the file failed.
    separator_cls = globals().get("Separator")
    if separator_cls is None:
        raise gr.Error("Unable to use vocal separation feature. Please install audio-separator[gpu].")
    if not audio_path:
        raise gr.Error("Please upload an audio file first / 请先上传音频")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)
    audio_separator = separator_cls(
        output_dir=os.path.abspath(os.path.join("outputs", timestamp)),
        output_single_stem="vocals",
        model_file_dir=os.path.dirname(audio_separator_model_file),
    )
    audio_separator.load_model(os.path.basename(audio_separator_model_file))
    if audio_separator.model_instance is None:
        raise RuntimeError("Fail to load audio separate model.")
    outputs = audio_separator.separate(audio_path)
    if not outputs:
        raise RuntimeError("Audio separation produced no output file.")
    vocal_audio_file = os.path.join(audio_separator.output_dir, outputs[0])
    destination_file = os.path.abspath(os.path.join("outputs", f"{timestamp}.wav"))
    shutil.move(vocal_audio_file, destination_file)
    return destination_file, f"Generated {destination_file} / 已生成 {destination_file}", destination_file
def update_language(language):
    """Return per-component updates that relabel the UI for *language*.

    "English" selects English labels; any other value selects Chinese. Only labels,
    info text and button/markdown text are updated. Component values are never re-sent,
    so switching language does not reset the user's current selections. For example, it
    does not reset model_version, whose change handler would reload the model.
    """
    if language == "English":
        return {
            GPU_memory_mode: gr.Dropdown(label="GPU Memory Mode", info="Normal uses 25G VRAM, model_cpu_offload uses 13G VRAM"),
            teacache_threshold: gr.Slider(label="TeaCache Threshold", info="Recommended 0.1, 0 disables TeaCache acceleration"),
            num_skip_start_steps: gr.Slider(label="Skip Start Steps", info="Recommended 5"),
            model_version: gr.Dropdown(label="Model Version"),
            image_path: gr.Image(label="Upload Image"),
            audio_path: gr.Audio(label="Upload Audio"),
            prompt: gr.Textbox(label="Prompt"),
            negative_prompt: gr.Textbox(label="Negative Prompt"),
            generate_button: gr.Button("🎬 Start Generation"),
            width: gr.Slider(label="Width"),
            height: gr.Slider(label="Height"),
            exchange_button: gr.Button("🔄 Swap Width/Height"),
            adjust_button: gr.Button("Adjust Size Based on Image"),
            guidance_scale: gr.Slider(label="Guidance Scale"),
            num_inference_steps: gr.Slider(label="Sampling Steps (Recommended 50)"),
            text_guide_scale: gr.Slider(label="Text Guidance Scale"),
            audio_guide_scale: gr.Slider(label="Audio Guidance Scale"),
            motion_frame: gr.Slider(label="Motion Frame"),
            fps: gr.Slider(label="FPS"),
            overlap_window_length: gr.Slider(label="Overlap Window Length"),
            seed_param: gr.Number(label="Seed (positive integer, -1 for random)"),
            overlapping_weight_scheme: gr.Dropdown(label="Overlapping Weight Scheme"),
            info: gr.Textbox(label="Status"),
            video_output: gr.Video(label="Generated Result"),
            seed_output: gr.Textbox(label="Seed"),
            video_path: gr.Video(label="Upload Video"),
            extractor_button: gr.Button("🎬 Start Extraction"),
            info2: gr.Textbox(label="Status"),
            audio_output: gr.Audio(label="Generated Result"),
            audio_path3: gr.Audio(label="Upload Audio"),
            separation_button: gr.Button("🎬 Start Separation"),
            info3: gr.Textbox(label="Status"),
            audio_output3: gr.Audio(label="Generated Result"),
            example_title: gr.Markdown(value="### Select the following example cases for testing:"),
            example1_label: gr.Markdown(value="**Example 1**"),
            example2_label: gr.Markdown(value="**Example 2**"),
            example3_label: gr.Markdown(value="**Example 3**"),
            example4_label: gr.Markdown(value="**Example 4**"),
            example5_label: gr.Markdown(value="**Example 5**"),
            example1_btn: gr.Button("🚀 Use Example 1", variant="secondary"),
            example2_btn: gr.Button("🚀 Use Example 2", variant="secondary"),
            example3_btn: gr.Button("🚀 Use Example 3", variant="secondary"),
            example4_btn: gr.Button("🚀 Use Example 4", variant="secondary"),
            example5_btn: gr.Button("🚀 Use Example 5", variant="secondary"),
            parameter_settings_title: gr.Accordion(label="Parameter Settings", open=True),
            example_cases_title: gr.Accordion(label="Example Cases", open=True),
            stableavatar_title: gr.TabItem(label="StableAvatar"),
            audio_extraction_title: gr.TabItem(label="Audio Extraction"),
            vocal_separation_title: gr.TabItem(label="Vocal Separation")
        }
    return {
        GPU_memory_mode: gr.Dropdown(label="显存模式", info="Normal占用25G显存,model_cpu_offload占用13G显存"),
        teacache_threshold: gr.Slider(label="teacache threshold", info="推荐参数0.1,0为禁用teacache加速"),
        num_skip_start_steps: gr.Slider(label="跳过开始步数", info="推荐参数5"),
        model_version: gr.Dropdown(label="模型版本"),
        image_path: gr.Image(label="上传图片"),
        audio_path: gr.Audio(label="上传音频"),
        prompt: gr.Textbox(label="提示词"),
        negative_prompt: gr.Textbox(label="负面提示词"),
        generate_button: gr.Button("🎬 开始生成"),
        width: gr.Slider(label="宽度"),
        height: gr.Slider(label="高度"),
        exchange_button: gr.Button("🔄 交换宽高"),
        adjust_button: gr.Button("根据图片调整宽高"),
        guidance_scale: gr.Slider(label="guidance scale"),
        num_inference_steps: gr.Slider(label="采样步数(推荐50步)"),
        text_guide_scale: gr.Slider(label="text guidance scale"),
        audio_guide_scale: gr.Slider(label="audio guidance scale"),
        motion_frame: gr.Slider(label="motion frame"),
        fps: gr.Slider(label="帧率"),
        overlap_window_length: gr.Slider(label="overlap window length"),
        seed_param: gr.Number(label="种子,请输入正整数,-1为随机"),
        overlapping_weight_scheme: gr.Dropdown(label="Overlapping Weight Scheme"),
        info: gr.Textbox(label="提示信息"),
        video_output: gr.Video(label="生成结果"),
        seed_output: gr.Textbox(label="种子"),
        video_path: gr.Video(label="上传视频"),
        extractor_button: gr.Button("🎬 开始提取"),
        info2: gr.Textbox(label="提示信息"),
        audio_output: gr.Audio(label="生成结果"),
        audio_path3: gr.Audio(label="上传音频"),
        separation_button: gr.Button("🎬 开始分离"),
        info3: gr.Textbox(label="提示信息"),
        audio_output3: gr.Audio(label="生成结果"),
        example_title: gr.Markdown(value="### 选择以下示例案例进行测试:"),
        example1_label: gr.Markdown(value="**示例 1**"),
        example2_label: gr.Markdown(value="**示例 2**"),
        example3_label: gr.Markdown(value="**示例 3**"),
        example4_label: gr.Markdown(value="**示例 4**"),
        example5_label: gr.Markdown(value="**示例 5**"),
        example1_btn: gr.Button("🚀 使用示例 1", variant="secondary"),
        example2_btn: gr.Button("🚀 使用示例 2", variant="secondary"),
        example3_btn: gr.Button("🚀 使用示例 3", variant="secondary"),
        example4_btn: gr.Button("🚀 使用示例 4", variant="secondary"),
        example5_btn: gr.Button("🚀 使用示例 5", variant="secondary"),
        parameter_settings_title: gr.Accordion(label="参数设置", open=True),
        example_cases_title: gr.Accordion(label="示例案例", open=True),
        stableavatar_title: gr.TabItem(label="StableAvatar"),
        audio_extraction_title: gr.TabItem(label="音频提取"),
        vocal_separation_title: gr.TabItem(label="人声分离")
    }
# Header banner rendered at the top of the app: project name plus badge links
# (arXiv, project page, GitHub, YouTube). Every badge image carries alt text so
# screen readers and broken-image fallbacks still convey the link target.
BANNER_HTML = """
<div class="hero">
<div class="brand">
<!-- 如有项目 logo,可放到仓库并换成你的地址;没有就删这一行 -->
<!-- <img src="https://raw.githubusercontent.com/Francis-Rings/StableAvatar/main/assets/logo.png" alt="StableAvatar Logo"> -->
<span class="brand-text">STABLEAVATAR</span>
</div>
<div class="titles">
<h1>StableAvatar</h1>
<div class="badges">
<a class="badge" href="https://arxiv.org/abs/2508.08248" target="_blank" rel="noopener">
<img src="https://img.shields.io/badge/arXiv-2508.08248-b31b1b" alt="arXiv 2508.08248">
</a>
<a class="badge" href="https://francis-rings.github.io/StableAvatar/" target="_blank" rel="noopener">
<img src="https://img.shields.io/badge/Webpage-Visit-2266ee" alt="Project webpage">
</a>
<a class="badge" href="https://github.com/Francis-Rings/StableAvatar" target="_blank" rel="noopener">
<img src="https://img.shields.io/badge/GitHub-Repo-181717?logo=github&logoColor=white" alt="GitHub repository">
</a>
<a class="badge" href="https://www.youtube.com/watch?v=6lhvmbzvv3Y" target="_blank" rel="noopener">
<img src="https://img.shields.io/badge/YouTube-Demo-ff0000?logo=youtube&logoColor=white" alt="YouTube demo">
</a>
</div>
</div>
</div>
<hr class="divider">
"""
# Styles for the banner above; passed to gr.Blocks(css=...).
BANNER_CSS = """
.hero{display:flex;align-items:center;gap:18px;padding:18px;border-radius:14px;
color:inherit;margin-bottom:12px}
.brand-text{font-weight:800;letter-spacing:2px}
.brand img{height:46px}
.titles h1{font-size:28px;margin:0 0 6px 0}
.badges{display:flex;gap:10px;flex-wrap:wrap}
.badge img{height:22px}
.divider{border:0;border-top:1px solid rgba(0,0,0,0.12);margin:6px 0 18px}
"""
# with gr.Blocks(theme=gr.themes.Base()) as demo:
# gr.Markdown("""
# <div>
# <h2 style="font-size: 30px;text-align: center;">StableAvatar</h2>
# </div>
# """)
with gr.Blocks(theme=gr.themes.Base(), css=BANNER_CSS) as demo:
    gr.HTML(BANNER_HTML)
    language_radio = gr.Radio(
        choices=["English", "中文"],
        value="English",
        label="Language / 语言",
    )
    with gr.Accordion("Model Settings / 模型设置", open=False):
        with gr.Row():
            GPU_memory_mode = gr.Dropdown(
                label="显存模式",
                info="Normal占用25G显存,model_cpu_offload占用13G显存",
                # These strings must match the mode names generate() dispatches on.
                choices=["Normal", "model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"],
                value="model_cpu_offload",
            )
            teacache_threshold = gr.Slider(label="teacache threshold", info="推荐参数0.1,0为禁用teacache加速", minimum=0, maximum=1, step=0.01, value=0)
            num_skip_start_steps = gr.Slider(label="跳过开始步数", info="推荐参数5", minimum=0, maximum=100, step=1, value=5)
        with gr.Row():
            model_version = gr.Dropdown(
                label="模型版本",
                choices=["square", "rec_vec"],
                value="square",
            )

    stableavatar_title = gr.TabItem(label="StableAvatar")
    with stableavatar_title:
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    image_path = gr.Image(label="上传图片", type="filepath", height=280)
                    audio_path = gr.Audio(label="上传音频", type="filepath")
                prompt = gr.Textbox(label="提示词", value="")
                negative_prompt = gr.Textbox(label="负面提示词", value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
                generate_button = gr.Button("🎬 开始生成", variant='primary')
                parameter_settings_title = gr.Accordion(label="参数设置", open=True)
                with parameter_settings_title:
                    with gr.Row():
                        width = gr.Slider(label="宽度", minimum=256, maximum=2048, step=16, value=512)
                        height = gr.Slider(label="高度", minimum=256, maximum=2048, step=16, value=512)
                    with gr.Row():
                        exchange_button = gr.Button("🔄 交换宽高")
                        adjust_button = gr.Button("根据图片调整宽高")
                    with gr.Row():
                        guidance_scale = gr.Slider(label="guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=6.0)
                        num_inference_steps = gr.Slider(label="采样步数(推荐50步)", minimum=1, maximum=100, step=1, value=50)
                    with gr.Row():
                        text_guide_scale = gr.Slider(label="text guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=3.0)
                        audio_guide_scale = gr.Slider(label="audio guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=5.0)
                    with gr.Row():
                        motion_frame = gr.Slider(label="motion frame", minimum=1, maximum=50, step=1, value=25)
                        fps = gr.Slider(label="帧率", minimum=1, maximum=60, step=1, value=25)
                    with gr.Row():
                        overlap_window_length = gr.Slider(label="overlap window length", minimum=1, maximum=20, step=1, value=10)
                        # precision=0 makes the component deliver an int (seeds must be integers).
                        seed_param = gr.Number(label="种子,请输入正整数,-1为随机", value=42, precision=0)
                    with gr.Row():
                        overlapping_weight_scheme = gr.Dropdown(label="Overlapping Weight Scheme", choices=["uniform", "log"], value="uniform")
            with gr.Column():
                info = gr.Textbox(label="提示信息", interactive=False)
                video_output = gr.Video(label="生成结果", interactive=False)
                seed_output = gr.Textbox(label="种子")

        # Example cases, shown inside the StableAvatar tab.
        example_cases_title = gr.Accordion(label="示例案例", open=True)
        with example_cases_title:
            example_title = gr.Markdown(value="### 选择以下示例案例进行测试:")
            with gr.Row():
                with gr.Column():
                    example1_label = gr.Markdown(value="**示例 1**")
                    example1_image = gr.Image(value="example_case/case-1/reference.png", label="", interactive=False, height=120, show_label=False)
                    example1_audio = gr.Audio(value="example_case/case-1/audio.wav", label="", interactive=False, show_label=False)
                    example1_btn = gr.Button("🚀 使用示例 1", variant="secondary", size="sm")
                with gr.Column():
                    example2_label = gr.Markdown(value="**示例 2**")
                    example2_image = gr.Image(value="example_case/case-2/reference.png", label="", interactive=False, height=120, show_label=False)
                    example2_audio = gr.Audio(value="example_case/case-2/audio.wav", label="", interactive=False, show_label=False)
                    example2_btn = gr.Button("🚀 使用示例 2", variant="secondary", size="sm")
                with gr.Column():
                    example3_label = gr.Markdown(value="**示例 3**")
                    example3_image = gr.Image(value="example_case/case-6/reference.png", label="", interactive=False, height=120, show_label=False)
                    example3_audio = gr.Audio(value="example_case/case-6/audio.wav", label="", interactive=False, show_label=False)
                    example3_btn = gr.Button("🚀 使用示例 3", variant="secondary", size="sm")
                with gr.Column():
                    example4_label = gr.Markdown(value="**示例 4**")
                    example4_image = gr.Image(value="example_case/case-45/reference.png", label="", interactive=False, height=120, show_label=False)
                    example4_audio = gr.Audio(value="example_case/case-45/audio.wav", label="", interactive=False, show_label=False)
                    example4_btn = gr.Button("🚀 使用示例 4", variant="secondary", size="sm")
                with gr.Column():
                    example5_label = gr.Markdown(value="**示例 5**")
                    example5_image = gr.Image(value="example_case/case-3/reference.jpg", label="", interactive=False, height=120, show_label=False)
                    example5_audio = gr.Audio(value="example_case/case-3/audio.wav", label="", interactive=False, show_label=False)
                    example5_btn = gr.Button("🚀 使用示例 5", variant="secondary", size="sm")

    audio_extraction_title = gr.TabItem(label="音频提取")
    with audio_extraction_title:
        with gr.Row():
            with gr.Column():
                video_path = gr.Video(label="上传视频", height=500)
                extractor_button = gr.Button("🎬 开始提取", variant='primary')
            with gr.Column():
                info2 = gr.Textbox(label="提示信息", interactive=False)
                audio_output = gr.Audio(label="生成结果", interactive=False)
                audio_file = gr.File(label="download audio file")

    vocal_separation_title = gr.TabItem(label="人声分离")
    with vocal_separation_title:
        with gr.Row():
            with gr.Column():
                audio_path3 = gr.Audio(label="上传音频", type="filepath")
                separation_button = gr.Button("🎬 开始分离", variant='primary')
            with gr.Column():
                info3 = gr.Textbox(label="提示信息", interactive=False)
                audio_output3 = gr.Audio(label="生成结果", interactive=False)
                audio_file3 = gr.File(label="download audio file")

    # Every component whose label/text update_language() rewrites.
    all_components = [
        GPU_memory_mode, teacache_threshold, num_skip_start_steps, model_version,
        image_path, audio_path, prompt, negative_prompt, generate_button,
        width, height, exchange_button, adjust_button, guidance_scale, num_inference_steps,
        text_guide_scale, audio_guide_scale, motion_frame, fps, overlap_window_length,
        seed_param, overlapping_weight_scheme, info, video_output, seed_output,
        video_path, extractor_button, info2, audio_output,
        audio_path3, separation_button, info3, audio_output3,
        example_title, example1_label, example2_label, example3_label, example4_label,
        example1_btn, example2_btn, example3_btn, example4_btn, example5_label, example5_btn,
        parameter_settings_title, example_cases_title,
        stableavatar_title, audio_extraction_title, vocal_separation_title,
    ]
    language_radio.change(
        fn=update_language,
        inputs=[language_radio],
        outputs=all_components
    )

    # Reload the transformer weights whenever the selected model version changes.
    def on_model_version_change(model_version):
        """Load the checkpoint for *model_version* and return a status message for the UI."""
        if load_transformer_model(model_version) is not None:
            return f"✅ 模型已切换到 {model_version} 版本"
        return "❌ 模型切换失败,请检查文件是否存在"

    model_version.change(
        fn=on_model_version_change,
        inputs=[model_version],
        outputs=[info]
    )
    # Apply the initial language's labels on page load.
    demo.load(fn=update_language, inputs=[language_radio], outputs=all_components)

    # Example-case buttons: fill the image, audio and prompt inputs from example_case/.
    def _read_example(case_dir, image_name):
        """Return (image path, audio path, prompt) for one bundled example case.

        The prompt comes from ``<case_dir>/prompt.txt``; if that file is missing or
        unreadable, the prompt is "".
        """
        try:
            with open(f"{case_dir}/prompt.txt", "r", encoding="utf-8") as f:
                prompt_text = f.read().strip()
        except (OSError, UnicodeDecodeError):
            prompt_text = ""
        return f"{case_dir}/{image_name}", f"{case_dir}/audio.wav", prompt_text

    # Separate named loaders keep each button's API endpoint name stable.
    def load_example1():
        return _read_example("example_case/case-1", "reference.png")

    def load_example2():
        return _read_example("example_case/case-2", "reference.png")

    def load_example3():
        return _read_example("example_case/case-6", "reference.png")

    def load_example4():
        return _read_example("example_case/case-45", "reference.png")

    def load_example5():
        return _read_example("example_case/case-3", "reference.jpg")

    example1_btn.click(fn=load_example1, outputs=[image_path, audio_path, prompt])
    example2_btn.click(fn=load_example2, outputs=[image_path, audio_path, prompt])
    example3_btn.click(fn=load_example3, outputs=[image_path, audio_path, prompt])
    example4_btn.click(fn=load_example4, outputs=[image_path, audio_path, prompt])
    example5_btn.click(fn=load_example5, outputs=[image_path, audio_path, prompt])

    gr.on(
        triggers=[generate_button.click, prompt.submit, negative_prompt.submit],
        fn=generate,
        inputs=[
            GPU_memory_mode,
            teacache_threshold,
            num_skip_start_steps,
            image_path,
            audio_path,
            prompt,
            negative_prompt,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            text_guide_scale,
            audio_guide_scale,
            motion_frame,
            fps,
            overlap_window_length,
            seed_param,
            overlapping_weight_scheme,
        ],
        outputs=[video_output, seed_output, info]
    )
    exchange_button.click(
        fn=exchange_width_height,
        inputs=[width, height],
        outputs=[width, height, info]
    )
    adjust_button.click(
        fn=adjust_width_height,
        inputs=[image_path],
        outputs=[width, height, info]
    )
    extractor_button.click(
        fn=audio_extractor,
        inputs=[video_path],
        outputs=[audio_output, info2, audio_file]
    )
    separation_button.click(
        fn=vocal_separation,
        inputs=[audio_path3],
        outputs=[audio_output3, info3, audio_file3]
    )
if __name__ == "__main__":
    # Listen on all interfaces; the PORT environment variable overrides the default 7860.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port, share=False, inbrowser=False)