import torch
import psutil
import argparse
import gradio as gr
import os
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import load_image
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from omegaconf import OmegaConf
from wan.models.cache_utils import get_teacache_coefficients
from wan.models.wan_fantasy_transformer3d_1B import WanTransformer3DFantasyModel
from wan.models.wan_text_encoder import WanT5EncoderModel
from wan.models.wan_vae import AutoencoderKLWan
from wan.models.wan_image_encoder import CLIPModel
from wan.pipeline.wan_inference_long_pipeline import WanI2VTalkingInferenceLongPipeline
from wan.utils.fp8_optimization import replace_parameters_by_name, convert_weight_dtype_wrapper, convert_model_weight_to_float8
from wan.utils.utils import get_image_to_video_latent, save_videos_grid
import numpy as np
import librosa
import datetime
import random
import math
import subprocess
from moviepy.editor import VideoFileClip
from huggingface_hub import snapshot_download
import shutil
import spaces

try:
    from audio_separator.separator import Separator
except ImportError:
    print("Unable to use the vocal separation feature. Please install audio-separator[gpu].")

# Select device and dtype: bfloat16 on GPUs with compute capability >= 8 (Ampere or newer),
# float16 on older GPUs, float32 on CPU.
if torch.cuda.is_available():
    device = "cuda"
    if torch.cuda.get_device_capability()[0] >= 8:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32


def filter_kwargs(cls, kwargs):
    import inspect
    sig = inspect.signature(cls.__init__)
    valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
    return filtered_kwargs


def load_transformer_model(model_version):
    """
    Load the transformer checkpoint matching the selected model version.

    Args:
        model_version (str): Model version, "square" or "rec_vec".

    Returns:
        WanTransformer3DFantasyModel: The transformer with the checkpoint loaded,
        or None if the checkpoint file does not exist.
    """
    global transformer3d
    if model_version == "square":
        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
    elif model_version == "rec_vec":
        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-rec-vec.pt")
    else:
        # Fall back to the square version by default
        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")

    print(f"Loading model: {transformer_path}")
    if os.path.exists(transformer_path):
        state_dict = torch.load(transformer_path, map_location="cpu")
        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
        m, u = transformer3d.load_state_dict(state_dict, strict=False)
        print(f"Model loaded successfully: {transformer_path}")
        print(f"Missing keys: {len(m)}; Unexpected keys: {len(u)}")
        return transformer3d
    else:
        print(f"Error: model file does not exist: {transformer_path}")
        return None


REPO_ID = "FrancisRing/StableAvatar"
repo_root = snapshot_download(
    repo_id=REPO_ID,
    allow_patterns=[
        "StableAvatar-1.3B/*",
        "Wan2.1-Fun-V1.1-1.3B-InP/*",
        "wav2vec2-base-960h/*",
        "assets/**",
        "Kim_Vocal_2.onnx",
    ],
)

pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
# ONNX model for vocal separation
audio_separator_model_file = os.path.join(repo_root, "Kim_Vocal_2.onnx")

# model_path = "/datadrive/stableavatar/checkpoints"
# pretrained_model_name_or_path = f"{model_path}/Wan2.1-Fun-V1.1-1.3B-InP"
# pretrained_wav2vec_path = f"{model_path}/wav2vec2-base-960h"
# transformer_path = f"{model_path}/StableAvatar-1.3B/transformer3d-square.pt"

config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
sampler_name = "Flow"
clip_sample_n_frames = 81
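# ---------------------------------------------------------------------------
# Model loading. The tokenizer, T5 text encoder, VAE, wav2vec audio encoder,
# CLIP image encoder, and the Wan 3D transformer are loaded from the snapshot
# downloaded above; the StableAvatar transformer checkpoint is then applied
# and everything is assembled into the long-video talking-head pipeline.
# ---------------------------------------------------------------------------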
tokenizer = AutoTokenizer.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
)
text_encoder = WanT5EncoderModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=dtype,
)
text_encoder = text_encoder.eval()
vae = AutoencoderKLWan.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')),
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
)
wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_wav2vec_path)
wav2vec = Wav2Vec2Model.from_pretrained(pretrained_wav2vec_path).to("cpu")
clip_image_encoder = CLIPModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
)
clip_image_encoder = clip_image_encoder.eval()
transformer3d = WanTransformer3DFantasyModel.from_pretrained(
    os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
    low_cpu_mem_usage=False,
    torch_dtype=dtype,
)
# Load the square-version checkpoint by default
load_transformer_model("square")

Choosen_Scheduler = scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
}[sampler_name]
scheduler = Choosen_Scheduler(
    **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)

pipeline = WanI2VTalkingInferenceLongPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    transformer=transformer3d,
    clip_image_encoder=clip_image_encoder,
    scheduler=scheduler,
    wav2vec_processor=wav2vec_processor,
    wav2vec=wav2vec,
)
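# ---------------------------------------------------------------------------
# Inference entry point. `generate` is the Gradio callback: it applies the
# selected GPU-memory strategy, optionally enables TeaCache, runs the long
# talking-video pipeline, and muxes the input audio back into the result.
# Minimal illustrative call (the file paths and numeric values below are
# hypothetical examples, not defaults shipped with this repo):
#   generate("model_cpu_offload", 0.1, 5, "assets/ref.png", "assets/voice.wav",
#            "a person talking", "", 512, 512, 6.0, 50, 5.0, 4.0, 25, 25, 5,
#            -1, "uniform")
# ---------------------------------------------------------------------------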
@spaces.GPU(duration=120)
def generate(
    GPU_memory_mode,
    teacache_threshold,
    num_skip_start_steps,
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    text_guide_scale,
    audio_guide_scale,
    motion_frame,
    fps,
    overlap_window_length,
    seed_param,
    overlapping_weight_scheme,
    progress=gr.Progress(track_tqdm=True),
):
    global pipeline, transformer3d

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if seed_param < 0:
        seed = random.randint(0, np.iinfo(np.int32).max)
    else:
        seed = seed_param

    # Apply the requested GPU memory strategy before running inference.
    if GPU_memory_mode == "sequential_cpu_offload":
        replace_parameters_by_name(transformer3d, ["modulation", ], device=device)
        transformer3d.freqs = transformer3d.freqs.to(device=device)
        pipeline.enable_sequential_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation", ])
        convert_weight_dtype_wrapper(transformer3d, dtype)
        pipeline.enable_model_cpu_offload(device=device)
    elif GPU_memory_mode == "model_cpu_offload":
        pipeline.enable_model_cpu_offload(device=device)
    else:
        pipeline.to(device=device)

    if teacache_threshold > 0:
        coefficients = get_teacache_coefficients(pretrained_model_name_or_path)
        pipeline.transformer.enable_teacache(
            coefficients,
            num_inference_steps,
            teacache_threshold,
            num_skip_start_steps=num_skip_start_steps,
        )

    with torch.no_grad():
        video_length = (
            int((clip_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1
            if clip_sample_n_frames != 1
            else 1
        )
        input_video, input_video_mask, clip_image = get_image_to_video_latent(
            image_path, None, video_length=video_length, sample_size=[height, width]
        )
        sr = 16000
        vocal_input, sample_rate = librosa.load(audio_path, sr=sr)

        sample = pipeline(
            prompt,
            num_frames=video_length,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
            num_inference_steps=num_inference_steps,
            video=input_video,
            mask_video=input_video_mask,
            clip_image=clip_image,
            text_guide_scale=text_guide_scale,
            audio_guide_scale=audio_guide_scale,
            vocal_input_values=vocal_input,
            motion_frame=motion_frame,
            fps=fps,
            sr=sr,
            cond_file_path=image_path,
            overlap_window_length=overlap_window_length,
            seed=seed,
            overlapping_weight_scheme=overlapping_weight_scheme,
        ).videos

    os.makedirs("outputs", exist_ok=True)
    video_path = os.path.join("outputs", f"{timestamp}.mp4")
    save_videos_grid(sample, video_path, fps=fps)

    # Mux the original audio track back into the generated video.
    output_video_with_audio = os.path.join("outputs", f"{timestamp}_audio.mp4")
    subprocess.run([
        "ffmpeg", "-y", "-loglevel", "quiet",
        "-i", video_path,
        "-i", audio_path,
        "-c:v", "copy",
        "-c:a", "aac",
        "-strict", "experimental",
        output_video_with_audio,
    ], check=True)

    return output_video_with_audio, seed, f"Generated outputs/{timestamp}.mp4 / 已生成outputs/{timestamp}.mp4"


def exchange_width_height(width, height):
    return height, width, "✅ Width and Height Swapped / 宽高交换完毕"


def adjust_width_height(image):
    # Rescale to roughly 512x512 worth of area while keeping the aspect ratio,
    # rounding both sides down to multiples of 16.
    image = load_image(image)
    width, height = image.size
    original_area = width * height
    default_area = 512 * 512
    ratio = math.sqrt(original_area / default_area)
    width = width / ratio // 16 * 16
    height = height / ratio // 16 * 16
    return int(width), int(height), "✅ Adjusted Size Based on Image / 根据图片调整宽高"


def audio_extractor(video_path):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)  # make sure the output directory exists
    out_wav = os.path.abspath(os.path.join("outputs", f"{timestamp}.wav"))

    video = VideoFileClip(video_path)
    audio = video.audio
    audio.write_audiofile(out_wav, codec="pcm_s16le")

    return out_wav, f"Generated {out_wav} / 已生成 {out_wav}", out_wav  # third value feeds gr.File
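# ---------------------------------------------------------------------------
# Vocal separation. The helper below isolates the vocal stem with
# audio-separator and the Kim_Vocal_2 ONNX model downloaded above, then copies
# the separated stem to outputs/<timestamp>.wav for use as driving audio.
# ---------------------------------------------------------------------------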
def vocal_separation(audio_path):
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)
    # audio_separator_model_file = "checkpoints/Kim_Vocal_2.onnx"
    audio_separator = Separator(
        output_dir=os.path.abspath(os.path.join("outputs", timestamp)),
        output_single_stem="vocals",
        model_file_dir=os.path.dirname(audio_separator_model_file),
    )
    audio_separator.load_model(os.path.basename(audio_separator_model_file))
    assert audio_separator.model_instance is not None, "Failed to load the audio separation model."
    outputs = audio_separator.separate(audio_path)
    vocal_audio_file = os.path.join(audio_separator.output_dir, outputs[0])
    destination_file = os.path.abspath(os.path.join("outputs", f"{timestamp}.wav"))
    shutil.copy(vocal_audio_file, destination_file)
    os.remove(vocal_audio_file)
    return destination_file, f"Generated {destination_file} / 已生成 {destination_file}", destination_file


def update_language(language):
    if language == "English":
        return {
            GPU_memory_mode: gr.Dropdown(label="GPU Memory Mode", info="Normal uses 25G VRAM, model_cpu_offload uses 13G VRAM"),
            teacache_threshold: gr.Slider(label="TeaCache Threshold", info="Recommended 0.1, 0 disables TeaCache acceleration"),
            num_skip_start_steps: gr.Slider(label="Skip Start Steps", info="Recommended 5"),
            model_version: gr.Dropdown(label="Model Version", choices=["square", "rec_vec"], value="square"),
            image_path: gr.Image(label="Upload Image"),
            audio_path: gr.Audio(label="Upload Audio"),
            prompt: gr.Textbox(label="Prompt"),
            negative_prompt: gr.Textbox(label="Negative Prompt"),
            generate_button: gr.Button("🎬 Start Generation"),
            width: gr.Slider(label="Width"),
            height: gr.Slider(label="Height"),
            exchange_button: gr.Button("🔄 Swap Width/Height"),
            adjust_button: gr.Button("Adjust Size Based on Image"),
            guidance_scale: gr.Slider(label="Guidance Scale"),
            num_inference_steps: gr.Slider(label="Sampling Steps (Recommended 50)"),
            text_guide_scale: gr.Slider(label="Text Guidance Scale"),
            audio_guide_scale: gr.Slider(label="Audio Guidance Scale"),
            motion_frame: gr.Slider(label="Motion Frame"),
            fps: gr.Slider(label="FPS"),
            overlap_window_length: gr.Slider(label="Overlap Window Length"),
            seed_param: gr.Number(label="Seed (positive integer, -1 for random)"),
            overlapping_weight_scheme: gr.Dropdown(label="Overlapping Weight Scheme", choices=["uniform", "log"], value="uniform"),
            info: gr.Textbox(label="Status"),
            video_output: gr.Video(label="Generated Result"),
            seed_output: gr.Textbox(label="Seed"),
            video_path: gr.Video(label="Upload Video"),
            extractor_button: gr.Button("🎬 Start Extraction"),
            info2: gr.Textbox(label="Status"),
            audio_output: gr.Audio(label="Generated Result"),
            audio_path3: gr.Audio(label="Upload Audio"),
            separation_button: gr.Button("🎬 Start Separation"),
            info3: gr.Textbox(label="Status"),
            audio_output3: gr.Audio(label="Generated Result"),
            example_title: gr.Markdown(value="### Select the following example cases for testing:"),
            example1_label: gr.Markdown(value="**Example 1**"),
            example2_label: gr.Markdown(value="**Example 2**"),
            example3_label: gr.Markdown(value="**Example 3**"),
            example4_label: gr.Markdown(value="**Example 4**"),
            example5_label: gr.Markdown(value="**Example 5**"),
            example1_btn: gr.Button("🚀 Use Example 1", variant="secondary"),
            example2_btn: gr.Button("🚀 Use Example 2", variant="secondary"),
            example3_btn: gr.Button("🚀 Use Example 3", variant="secondary"),
            example4_btn: gr.Button("🚀 Use Example 4", variant="secondary"),
            example5_btn: gr.Button("🚀 Use Example 5", variant="secondary"),
            parameter_settings_title: gr.Accordion(label="Parameter Settings", open=True),
            example_cases_title: gr.Accordion(label="Example Cases", open=True),
            stableavatar_title: gr.TabItem(label="StableAvatar"),
            audio_extraction_title: gr.TabItem(label="Audio Extraction"),
            vocal_separation_title: gr.TabItem(label="Vocal Separation"),
        }
    else:
        # Chinese UI labels (kept in Chinese: this branch implements the app's Chinese localization).
        return {
            GPU_memory_mode: gr.Dropdown(label="显存模式", info="Normal占用25G显存,model_cpu_offload占用13G显存"),
            teacache_threshold: gr.Slider(label="teacache threshold", info="推荐参数0.1,0为禁用teacache加速"),
            num_skip_start_steps: gr.Slider(label="跳过开始步数", info="推荐参数5"),
            model_version: gr.Dropdown(label="模型版本", choices=["square", "rec_vec"], value="square"),
            image_path: gr.Image(label="上传图片"),
            audio_path: gr.Audio(label="上传音频"),
            prompt: gr.Textbox(label="提示词"),
            negative_prompt: gr.Textbox(label="负面提示词"),
            generate_button: gr.Button("🎬 开始生成"),
            width: gr.Slider(label="宽度"),
            height: gr.Slider(label="高度"),
            exchange_button: gr.Button("🔄 交换宽高"),
            adjust_button: gr.Button("根据图片调整宽高"),
            guidance_scale: gr.Slider(label="guidance scale"),
            num_inference_steps: gr.Slider(label="采样步数(推荐50步)", minimum=1, maximum=100, step=1, value=50),
            text_guide_scale: gr.Slider(label="text guidance scale"),
            audio_guide_scale: gr.Slider(label="audio guidance scale"),
            motion_frame: gr.Slider(label="motion frame"),
            fps: gr.Slider(label="帧率"),
            overlap_window_length: gr.Slider(label="overlap window length"),
            seed_param: gr.Number(label="种子,请输入正整数,-1为随机"),
            overlapping_weight_scheme: gr.Dropdown(label="Overlapping Weight Scheme", choices=["uniform", "log"], value="uniform"),
            info: gr.Textbox(label="提示信息"),
            video_output: gr.Video(label="生成结果"),
            seed_output: gr.Textbox(label="种子"),
            video_path: gr.Video(label="上传视频"),
            extractor_button: gr.Button("🎬 开始提取"),
            info2: gr.Textbox(label="提示信息"),
            audio_output: gr.Audio(label="生成结果"),
            audio_path3: gr.Audio(label="上传音频"),
            separation_button: gr.Button("🎬 开始分离"),
            info3: gr.Textbox(label="提示信息"),
            audio_output3: gr.Audio(label="生成结果"),
            example_title: gr.Markdown(value="### 选择以下示例案例进行测试:"),
            example1_label: gr.Markdown(value="**示例 1**"),
            example2_label: gr.Markdown(value="**示例 2**"),
            example3_label: gr.Markdown(value="**示例 3**"),
            example4_label: gr.Markdown(value="**示例 4**"),
            example5_label: gr.Markdown(value="**示例 5**"),
            example1_btn: gr.Button("🚀 使用示例 1", variant="secondary"),
            example2_btn: gr.Button("🚀 使用示例 2", variant="secondary"),
            example3_btn: gr.Button("🚀 使用示例 3", variant="secondary"),
            example4_btn: gr.Button("🚀 使用示例 4", variant="secondary"),
            example5_btn: gr.Button("🚀 使用示例 5", variant="secondary"),
            parameter_settings_title: gr.Accordion(label="参数设置", open=True),
            example_cases_title: gr.Accordion(label="示例案例", open=True),
            stableavatar_title: gr.TabItem(label="StableAvatar"),
            audio_extraction_title: gr.TabItem(label="音频提取"),
            vocal_separation_title: gr.TabItem(label="人声分离"),
        }


BANNER_HTML = """