import gradio as gr
import os
import torch
import tempfile
import sys
import datetime
from huggingface_hub import snapshot_download
import spaces

# Setup paths
PUSA_PATH = os.path.abspath("./PusaV1")
if PUSA_PATH not in sys.path:
    sys.path.insert(0, PUSA_PATH)

# Validate diffsynth presence
DIFFSYNTH_PATH = os.path.join(PUSA_PATH, "diffsynth")
if not os.path.exists(DIFFSYNTH_PATH):
    raise RuntimeError(
        f"'diffsynth' package not found in {PUSA_PATH}. "
        "Ensure PusaV1 is correctly cloned and the folder structure is intact."
    )

# Import core modules from PusaV1
from PusaV1.diffsynth import ModelManager, WanVideoPusaPipeline, save_video

# Constants
MODEL_ZOO_DIR = "./model_zoo"
PUSA_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
WAN_SUBFOLDER = "Wan2.1-T2V-14B"
WAN_MODEL_PATH = os.path.join(PUSA_DIR, WAN_SUBFOLDER)
LORA_PATH = os.path.join(PUSA_DIR, "pusa_v1.pt")


# Ensure model and LoRA weights are downloaded
def ensure_model_downloaded():
    if not os.path.exists(PUSA_DIR):
        print("Downloading RaphaelLiu/PusaV1 to ./model_zoo/PusaV1 ...")
        snapshot_download(
            repo_id="RaphaelLiu/PusaV1",
            local_dir=PUSA_DIR,
            repo_type="model",
            local_dir_use_symlinks=False,
        )
        print("✅ PusaV1 downloaded.")
    if not os.path.exists(WAN_MODEL_PATH):
        print("Downloading Wan-AI/Wan2.1-T2V-14B to ./model_zoo/PusaV1/Wan2.1-T2V-14B ...")
        snapshot_download(
            repo_id="Wan-AI/Wan2.1-T2V-14B",
            local_dir=WAN_MODEL_PATH,
            repo_type="model",
            local_dir_use_symlinks=False,
        )
        print("✅ Wan2.1-T2V-14B downloaded.")


# Subclass ModelManager to force the WanModelPusa architecture.
# Kept for reference: generate_video() below currently uses the stock ModelManager.
class PatchedModelManager(ModelManager):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Route the "WanModel" architecture to the Pusa variant
        custom_architecture_dict = {
            "WanModel": ("diffsynth.models.wan_model", "WanModelPusa", None),
        }
        self.architecture_dict.update(custom_architecture_dict)

    def load_model(self, file_path=None, model_names=None, device=None, torch_dtype=None):
        if file_path is None:
            file_path = self.file_path_list[0]
        print(f"[app.py] Forcing architecture: WanModelPusa for {file_path}")
        for detector in self.model_detector:
            if detector.match(file_path, {}):
                model_names, models = detector.load(
                    file_path,
                    state_dict={},
                    device=device or self.device,
                    torch_dtype=torch_dtype or self.torch_dtype,
                    allowed_model_names=model_names,
                    model_manager=self,
                    forced_architecture="WanModelPusa",
                )
                for name, model in zip(model_names, models):
                    self.model.append(model)
                    self.model_path.append(file_path)
                    self.model_name.append(name)
                return models[0] if models else None
        print("No suitable model detector matched.")
        return None


# Video generation logic.
# Note: written as a method of a host app class that is expected to provide
# load_lora_and_get_pipe() and output_dir; it is not wired into the Gradio UI
# below. A standalone sketch of that helper follows after this function.
def generate_t2v_video(self, prompt, lora_alpha, num_inference_steps, negative_prompt, progress=gr.Progress()):
    """Generate video from text prompt"""
    try:
        progress(0.1, desc="Loading models...")
        lora_path = "./model_zoo/PusaV1/pusa_v1.pt"
        pipe = self.load_lora_and_get_pipe("t2v", lora_path, lora_alpha)

        progress(0.3, desc="Generating video...")
        video = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=720,
            width=1280,
            num_frames=81,
            seed=0,
            tiled=True,
        )

        progress(0.9, desc="Saving video...")
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        video_filename = os.path.join(self.output_dir, f"t2v_output_{timestamp}.mp4")
        save_video(video, video_filename, fps=25, quality=5)

        progress(1.0, desc="Complete!")
        return video_filename, f"Video generated successfully! Saved to {video_filename}"
    except Exception as e:
        return None, f"Error: {str(e)}"
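
# generate_t2v_video above depends on a load_lora_and_get_pipe helper that is not
# defined anywhere in this file. A minimal standalone sketch is given below for
# reference only. It assumes DiffSynth-Studio's ModelManager.load_lora(file_path,
# lora_alpha=...) API (an assumption -- verify against the diffsynth version
# bundled with PusaV1) and mirrors the weight layout used in generate_video below.
def load_lora_and_get_pipe(mode, lora_path, lora_alpha):
    """Hypothetical helper: load the Wan2.1 base weights plus the Pusa LoRA
    and return a ready pipeline. Only the "t2v" mode is sketched."""
    assert mode == "t2v", "only text-to-video is sketched here"
    model_manager = ModelManager(device="cuda")
    model_files = sorted(
        os.path.join(WAN_MODEL_PATH, f)
        for f in os.listdir(WAN_MODEL_PATH)
        if f.endswith(".safetensors")
    )
    model_manager.load_models(
        [
            model_files,
            os.path.join(WAN_MODEL_PATH, "models_t5_umt5-xxl-enc-bf16.pth"),
            os.path.join(WAN_MODEL_PATH, "Wan2.1_VAE.pth"),
        ],
        torch_dtype=torch.bfloat16,
    )
    # Merge the Pusa LoRA into the loaded DiT weights, scaled by lora_alpha
    model_manager.load_lora(lora_path, lora_alpha=lora_alpha)
    return WanVideoPusaPipeline.from_model_manager(
        model_manager, torch_dtype=torch.bfloat16, device="cuda"
    )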
@spaces.GPU(duration=200)
def generate_video(prompt: str):
    # Load base weights with the stock ModelManager (the patched-manager
    # experiments are kept below, commented out, for reference)
    model_manager = ModelManager(device="cuda")
    base_dir = WAN_MODEL_PATH  # ./model_zoo/PusaV1/Wan2.1-T2V-14B
    model_files = sorted(
        [os.path.join(base_dir, f) for f in os.listdir(base_dir) if f.endswith(".safetensors")]
    )
    model_manager.load_models(
        [
            model_files,
            os.path.join(base_dir, "models_t5_umt5-xxl-enc-bf16.pth"),
            os.path.join(base_dir, "Wan2.1_VAE.pth"),
        ],
        torch_dtype=torch.bfloat16,
    )
    # manager = ModelManager(
    #     file_path_list=[WAN_MODEL_PATH],
    #     torch_dtype=torch.float16,
    #     device="cuda"
    # )
    # manager = PatchedModelManager(
    #     file_path_list=[WAN_MODEL_PATH],
    #     torch_dtype=torch.float16,
    #     device="cuda"
    # )
    # model = manager.load_model(WAN_MODEL_PATH)

    # Set up pipeline
    # pipeline = WanVideoPusaPipeline(model=model_manager)
    pipeline = WanVideoPusaPipeline.from_model_manager(
        model_manager, torch_dtype=torch.bfloat16, device="cuda"
    )
    # pipeline.set_lora_adapters(LORA_PATH)

    # Generate video (the pipeline returns the frame list directly,
    # matching the save_video usage in generate_t2v_video above)
    video = pipeline(prompt=prompt)

    # Save video
    tmp_dir = tempfile.mkdtemp()
    output_path = os.path.join(tmp_dir, "video.mp4")
    save_video(video, output_path, fps=8)
    return output_path


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")
    prompt_input = gr.Textbox(
        lines=4,
        label="Prompt",
        placeholder="Describe your video (e.g. A coral reef full of colorful fish...)",
    )
    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Output")
    generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)

if __name__ == "__main__":
    ensure_model_downloaded()
    demo.launch(share=True, show_error=True)
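
# NOTE: the Pusa LoRA at LORA_PATH is downloaded by ensure_model_downloaded() but
# never applied inside generate_video() (its set_lora_adapters line is commented
# out), so the demo currently samples from the base Wan2.1 weights. The
# hypothetical load_lora_and_get_pipe sketch above shows one way to merge it, e.g.:
#
#   model_manager.load_lora(LORA_PATH, lora_alpha=1.4)  # alpha value illustrative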