import datetime
import os
import shutil
import sys
import tempfile

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import snapshot_download
from PIL import Image

# === Robust base paths ===
# Ensures compatibility inside Hugging Face Spaces (or any container).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PUSA_ROOT = os.path.join(BASE_DIR, "PusaV1")
MODEL_ZOO_DIR = os.path.join(PUSA_ROOT, "model_zoo")
MODEL_ZOO_SUB_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
WAN_SUBFOLDER = "Wan2.1-T2V-14B"
WAN_MODEL_PATH = os.path.join(MODEL_ZOO_SUB_DIR, WAN_SUBFOLDER)
LORA_PATH = os.path.join(MODEL_ZOO_SUB_DIR, "pusa_v1.pt")

# Add PUSA_ROOT to sys.path so Python can import diffsynth.
if PUSA_ROOT not in sys.path:
    sys.path.insert(0, PUSA_ROOT)

# === Validate diffsynth ===
DIFFSYNTH_PATH = os.path.join(PUSA_ROOT, "diffsynth")
if not os.path.exists(DIFFSYNTH_PATH):
    raise RuntimeError(
        f"'diffsynth' package not found in {PUSA_ROOT}. "
        f"Ensure PusaV1 is correctly cloned and the folder structure is intact."
    )

from diffsynth import (
    ModelManager,
    PusaMultiFramesPipeline,
    PusaV2VPipeline,
    WanVideoPusaPipeline,
    save_video,
)


# === Ensure models exist; skip download if already present ===
def ensure_model_downloaded():
    print("Checking model presence...\n")

    if not os.path.exists(MODEL_ZOO_SUB_DIR):
        print("Downloading RaphaelLiu/PusaV1 to ./PusaV1/model_zoo/PusaV1 ...")
        snapshot_download(
            repo_id="RaphaelLiu/PusaV1",
            local_dir=MODEL_ZOO_SUB_DIR,
            repo_type="model",
            local_dir_use_symlinks=False,
        )
        print("PusaV1 base model downloaded.")
    else:
        print("PusaV1 base folder already exists.")

    if not os.path.exists(WAN_MODEL_PATH):
        print("Downloading Wan-AI/Wan2.1-T2V-14B to ./PusaV1/model_zoo/PusaV1/Wan2.1-T2V-14B ...")
        snapshot_download(
            repo_id="Wan-AI/Wan2.1-T2V-14B",
            local_dir=WAN_MODEL_PATH,
            repo_type="model",
            local_dir_use_symlinks=False,
        )
        print("Wan2.1-T2V-14B model downloaded.")
    else:
        print("Wan2.1-T2V-14B folder already exists.")

    # if not os.path.exists(LORA_PATH):
    #     raise FileNotFoundError(
    #         f"Expected LoRA weights 'pusa_v1.pt' not found at {LORA_PATH}. "
    #         f"Please make sure it exists in your repo."
    #     )
    # else:
    #     print("LoRA weights (pusa_v1.pt) found.")

    # List contents of model_zoo for verification (after any downloads).
    print(f"\nVerifying files under: {MODEL_ZOO_SUB_DIR}\n")
    for root, dirs, files in os.walk(MODEL_ZOO_SUB_DIR):
        for file in files:
            full_path = os.path.relpath(os.path.join(root, file), start=MODEL_ZOO_SUB_DIR)
            print(" -", full_path)


class PusaVideoDemo:
    def __init__(self):
        print("Initializing PusaVideoDemo...")

        # Check the Wan2.1 base model path before doing anything else.
        if not os.path.exists(WAN_MODEL_PATH):
            raise FileNotFoundError(f"WAN_MODEL_PATH not found: {WAN_MODEL_PATH}")
        print(f"WAN_MODEL_PATH resolved: {WAN_MODEL_PATH}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_manager = None
        self.multi_frames_pipe = None
        self.v2v_pipe = None
        self.t2v_pipe = None
        self.base_dir = WAN_MODEL_PATH
        self.output_dir = "outputs"
        os.makedirs(self.output_dir, exist_ok=True)

    def load_models(self):
        """Load all models once for efficiency."""
        if self.model_manager is None:
            print("Loading models...")
            self.model_manager = ModelManager(device="cpu")
            model_files = sorted(
                [
                    os.path.join(self.base_dir, f)
                    for f in os.listdir(self.base_dir)
                    if f.endswith(".safetensors")
                ]
            )
            self.model_manager.load_models(
                [
                    model_files,
                    os.path.join(self.base_dir, "models_t5_umt5-xxl-enc-bf16.pth"),
                    os.path.join(self.base_dir, "Wan2.1_VAE.pth"),
                ],
                torch_dtype=torch.bfloat16,
            )
            print("Models loaded successfully!")

    def load_lora_and_get_pipe(self, pipe_type, lora_path, lora_alpha):
        """Load LoRA weights and return the appropriate pipeline."""
        self.load_models()

        # Load the Pusa LoRA on top of the Wan2.1 base weights.
        self.model_manager.load_lora(lora_path, lora_alpha=lora_alpha)

        if pipe_type == "multi_frames":
            pipe = PusaMultiFramesPipeline.from_model_manager(
                self.model_manager, torch_dtype=torch.bfloat16, device=self.device
            )
            pipe.enable_vram_management(num_persistent_param_in_dit=6 * 10**9)
        elif pipe_type == "v2v":
            pipe = PusaV2VPipeline.from_model_manager(
                self.model_manager, torch_dtype=torch.bfloat16, device=self.device
            )
            pipe.enable_vram_management(num_persistent_param_in_dit=6 * 10**9)
        elif pipe_type == "t2v":
            pipe = WanVideoPusaPipeline.from_model_manager(
                self.model_manager, torch_dtype=torch.bfloat16, device=self.device
            )
            pipe.enable_vram_management(num_persistent_param_in_dit=None)
        else:
            raise ValueError(f"Unknown pipe_type: {pipe_type}")

        return pipe
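    # Worked example of the center-crop arithmetic used by process_video_frames()
    # below (illustrative numbers, not from the original app): a 1920x800 source
    # has ratio 2.40 > 16/9, so the crop keeps a centered int(800 * 16/9) = 1422
    # pixel-wide window before resizing to 1280x720; a 1080x1920 portrait source
    # has ratio 0.56 < 16/9, so the crop keeps a centered int(1080 / (16/9)) = 607
    # pixel-tall window.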
    def process_video_frames(self, video_path):
        """Process video frames for the V2V pipeline."""
        if not os.path.isfile(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")

        cap = cv2.VideoCapture(video_path)
        frames = []

        # Original video dimensions.
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Target dimensions and aspect ratios for center-cropping.
        target_width = 1280
        target_height = 720
        target_ratio = target_width / target_height
        original_ratio = width / height

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Crop to the target aspect ratio, keeping the center of the frame.
            if original_ratio > target_ratio:
                # Video is wider than target: crop width from the center.
                new_width = int(height * target_ratio)
                start_x = (width - new_width) // 2
                frame = frame[:, start_x:start_x + new_width]
            else:
                # Video is taller than target: crop height from the center.
                new_height = int(width / target_ratio)
                start_y = (height - new_height) // 2
                frame = frame[start_y:start_y + new_height]

            # Resize to the target dimensions and convert BGR -> RGB.
            frame = cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_LANCZOS4)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))

        cap.release()
        return frames

    def generate_i2v_video(self, image_path, prompt, noise_multiplier, lora_alpha, num_inference_steps,
                           negative_prompt, progress=gr.Progress()):
        """Generate a video from a single image (I2V)."""
        try:
            progress(0.1, desc="Loading models...")
            # Use the module-level LORA_PATH so the weights resolve regardless of CWD.
            pipe = self.load_lora_and_get_pipe("multi_frames", LORA_PATH, lora_alpha)

            progress(0.2, desc="Processing input image...")
            if image_path is None:
                raise ValueError("No image provided")

            # Gradio with type="filepath" returns the path directly.
            img = Image.open(image_path)
            processed_image = img.convert("RGB").resize((1280, 720), Image.LANCZOS)

            # I2V always conditions on position 0 (the first frame).
            multi_frame_images = {0: (processed_image, float(noise_multiplier))}

            progress(0.4, desc="Generating video...")
            video = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                multi_frame_images=multi_frame_images,
                num_inference_steps=num_inference_steps,
                height=720, width=1280, num_frames=81,
                seed=0, tiled=True,
            )

            progress(0.9, desc="Saving video...")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            video_filename = os.path.join(
                self.output_dir,
                f"i2v_output_{timestamp}_noise_{noise_multiplier}_alpha_{lora_alpha}.mp4",
            )
            save_video(video, video_filename, fps=25, quality=5)

            progress(1.0, desc="Complete!")
            return video_filename, f"Video generated successfully! Saved to {video_filename}"
        except Exception as e:
            return None, f"Error: {str(e)}"

    def generate_multi_frames_video(self, image1, image2, image3, num_imgs, prompt, cond_position,
                                    noise_multipliers, lora_alpha, num_inference_steps, negative_prompt,
                                    progress=gr.Progress()):
        """Generate a video from multiple key frames (start-end frames, multi-frame)."""
        try:
            progress(0.1, desc="Loading models...")
            pipe = self.load_lora_and_get_pipe("multi_frames", LORA_PATH, lora_alpha)

            progress(0.2, desc="Processing input images...")
            # Parse comma-separated conditioning positions (e.g. "0,20") and
            # per-frame noise multipliers (e.g. "0.2,0.4").
            cond_pos_list = [int(x.strip()) for x in cond_position.split(",")]
            noise_mult_list = [float(x.strip()) for x in noise_multipliers.split(",")]

            # Collect images based on num_imgs.
            image_paths = [image1, image2]
            if num_imgs == "3" and image3 is not None:
                image_paths.append(image3)

            # Filter out None values.
            image_paths = [path for path in image_paths if path is not None]

            if len(image_paths) != len(cond_pos_list) or len(image_paths) != len(noise_mult_list):
                raise ValueError(
                    "The number of images, conditioning positions, and noise multipliers must be the same."
                )

            # Process images.
            processed_images = []
            for img_path in image_paths:
                img = Image.open(img_path)
                processed_images.append(img.convert("RGB").resize((1280, 720), Image.LANCZOS))

            multi_frame_images = {
                cond_pos: (img, noise_mult)
                for cond_pos, img, noise_mult in zip(cond_pos_list, processed_images, noise_mult_list)
            }

            progress(0.4, desc="Generating video...")
            video = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                multi_frame_images=multi_frame_images,
                num_inference_steps=num_inference_steps,
                height=720, width=1280, num_frames=81,
                seed=0, tiled=True,
            )

            progress(0.9, desc="Saving video...")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            video_filename = os.path.join(self.output_dir, f"multi_frame_output_{timestamp}.mp4")
            save_video(video, video_filename, fps=25, quality=5)

            progress(1.0, desc="Complete!")
            return video_filename, f"Video generated successfully! Saved to {video_filename}"
        except Exception as e:
            return None, f"Error: {str(e)}"

    def generate_v2v_video(self, video_path, prompt, cond_position, noise_multipliers, lora_alpha,
                           num_inference_steps, negative_prompt, progress=gr.Progress()):
        """Generate a video from a video (V2V completion / extension)."""
        try:
            progress(0.1, desc="Loading models...")
            pipe = self.load_lora_and_get_pipe("v2v", LORA_PATH, lora_alpha)

            progress(0.2, desc="Processing input video...")
            # Parse conditioning positions and noise multipliers.
            cond_pos_list = [int(x.strip()) for x in cond_position.split(",")]
            noise_mult_list = [float(x.strip()) for x in noise_multipliers.split(",")]

            # Process the conditioning video.
            conditioning_video = self.process_video_frames(video_path)

            progress(0.4, desc="Generating video...")
            video = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                conditioning_video=conditioning_video,
                conditioning_indices=cond_pos_list,
                conditioning_noise_multipliers=noise_mult_list,
                num_inference_steps=num_inference_steps,
                height=720, width=1280, num_frames=81,
                seed=0, tiled=True,
            )

            progress(0.9, desc="Saving video...")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            output_filename = os.path.basename(video_path).split(".")[0]
            video_filename = os.path.join(self.output_dir, f"v2v_{output_filename}_{timestamp}.mp4")
            save_video(video, video_filename, fps=25, quality=5)

            progress(1.0, desc="Complete!")
            return video_filename, f"Video generated successfully! Saved to {video_filename}"
        except Exception as e:
            return None, f"Error: {str(e)}"

    @spaces.GPU(duration=200)
    def generate_t2v_video(self, prompt, lora_alpha, num_inference_steps, negative_prompt, progress=gr.Progress()):
        """Generate a video from a text prompt (T2V)."""
        try:
            progress(0.1, desc="Loading models...")
            pipe = self.load_lora_and_get_pipe("t2v", LORA_PATH, lora_alpha)

            progress(0.3, desc="Generating video...")
            video = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                height=720, width=1280, num_frames=81,
                seed=0, tiled=True,
            )

            progress(0.9, desc="Saving video...")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            video_filename = os.path.join(self.output_dir, f"t2v_output_{timestamp}.mp4")
            save_video(video, video_filename, fps=25, quality=5)

            progress(1.0, desc="Complete!")
            return video_filename, f"Video generated successfully! Saved to {video_filename}"
        except Exception as e:
            return None, f"Error: {str(e)}"
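
# Usage sketch (illustrative only, not part of the original app): the pipelines
# built by load_lora_and_get_pipe() condition on a {frame_index: (image, noise_multiplier)}
# mapping, so a hypothetical start/end-frame generation outside Gradio could look like:
#
#     demo = PusaVideoDemo()
#     pipe = demo.load_lora_and_get_pipe("multi_frames", LORA_PATH, lora_alpha=1.4)
#     video = pipe(
#         prompt="a sailboat drifting at sunset",
#         multi_frame_images={0: (first_frame, 0.2), 80: (last_frame, 0.5)},
#         num_inference_steps=30, height=720, width=1280, num_frames=81,
#         seed=0, tiled=True,
#     )
#     save_video(video, "sketch.mp4", fps=25, quality=5)
#
# The lora_alpha, noise multipliers, prompt, and frame indices above are
# placeholder values, not recommendations from the original app.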
def create_demo():
    demo_instance = PusaVideoDemo()

    # Set a custom Gradio cache directory to avoid permission issues.
    try:
        cache_dir = os.path.join(os.getcwd(), "gradio_cache")
        os.makedirs(cache_dir, exist_ok=True)
        os.environ["GRADIO_TEMP_DIR"] = cache_dir
    except Exception:
        pass  # Fall back to the default temp dir if this fails.

    # Helper function to safely load demo files.
    def safe_file_path(file_path):
        """Return the file path if it exists, None otherwise."""
        try:
            if os.path.exists(file_path):
                return file_path
        except Exception:
            pass
        return None

    # Custom CSS for the dark "Cosmic Flow" design.
    css = """
    /* === Main Theme: "Cosmic Flow" === */
    :root {
        --color-primary: #22d3ee;          /* Cosmic Cyan */
        --color-secondary: #ec4899;        /* Galactic Pink */
        --color-accent: #a78bfa;           /* Astral Violet */
        --color-background-dark: #0f172a;  /* Midnight Slate */
        --color-background-light: #1e293b; /* Twilight Slate */
        --color-surface: rgba(30, 41, 59, 0.6); /* Glassy Slate */
        --color-surface-hover: rgba(30, 41, 59, 0.9);
        --color-text-light: #f1f5f9;       /* Starlight White */
        --color-text-medium: #94a3b8;      /* Nebula Gray */
        --color-text-dark: #64748b;        /* Meteor Gray */
        --font-main: 'Inter', 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif;
        --radius-lg: 20px;
        --radius-md: 12px;
        --radius-sm: 8px;
    }

    /* === Global Styles === */
    .gradio-container {
        font-family: var(--font-main) !important;
        background: linear-gradient(135deg, var(--color-background-dark) 0%, var(--color-background-light) 100%) !important;
        color: var(--color-text-light) !important;
    }
    * {
        color: var(--color-text-light);
        border-color: rgba(148, 163, 184, 0.1); /* slate-400/10% */
    }

    /* === Glassmorphism Containers === */
    .gr-panel, .gr-box, .gr-group, .gr-column, .gr-tabitem, .gr-accordion {
        background: var(--color-surface) !important;
        backdrop-filter: blur(12px) !important;
        -webkit-backdrop-filter: blur(12px) !important;
        border: 1px solid rgba(148, 163, 184, 0.1) !important;
        border-radius: var(--radius-lg) !important;
        box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2) !important;
        transition: all 0.3s ease !important;
    }
    .gr-panel:hover, .gr-box:hover, .gr-group:hover, .gr-column:hover {
        background: var(--color-surface-hover) !important;
        border-color: rgba(148, 163, 184, 0.2) !important;
        transform: translateY(-2px) scale(1.01);
        box-shadow: 0 12px 40px rgba(0, 0, 0, 0.3) !important;
    }

    /* === Header (Static Nebula) === */
    .fancy-header {
        text-align: center !important;
        background-color: var(--color-background-dark) !important;
        padding: 40px !important;
        border-radius: var(--radius-lg) !important;
        margin-bottom: 40px !important;
        border: 1px solid rgba(148, 163, 184, 0.2) !important;
        position: relative !important;
        overflow: hidden !important;
        box-shadow: 0 20px 60px rgba(15, 23, 42, 0.5) !important;
    }
    .fancy-header::before {
        content: '' !important;
        position: absolute !important;
        top: -150px; left: -150px; right: -150px; bottom: -150px;
        background: radial-gradient(ellipse at 20% 25%, var(--color-primary), transparent 40%),
                    radial-gradient(ellipse at 80% 30%, var(--color-accent), transparent 40%),
                    radial-gradient(ellipse at 50% 90%, var(--color-secondary), transparent 45%) !important;
        opacity: 0.2 !important;
        filter: blur(80px) !important;
        transform: scale(1.2) !important;
        z-index: 0 !important;
    }
    .fancy-header > * {
        position: relative !important; /* Ensures content is on top of the nebula effect */
        z-index: 1 !important;
    }

    /* === Tabs === */
    .gr-tabs { background: transparent !important; }
    .gr-tab-nav {
        background: rgba(30, 41, 59, 0.8) !important;
        border-radius: var(--radius-lg) !important;
        padding: 6px !important;
        border: none !important;
    }
    .gr-tab-nav button {
        background: transparent !important;
        color: var(--color-text-medium) !important;
        border-radius: var(--radius-md) !important;
        font-weight: 600 !important;
        transition: all 0.3s ease !important;
        padding: 12px 20px !important;
        border: none !important;
    }
    .gr-tab-nav button:hover {
        background: rgba(167, 139, 250, 0.2) !important;
        color: var(--color-text-light) !important;
    }
    .gr-tab-nav button.selected {
        background: linear-gradient(135deg, var(--color-primary) 0%, var(--color-accent) 100%) !important;
        color: white !important;
        box-shadow: 0 8px 25px rgba(34, 211, 238, 0.3) !important;
    }

    /* === Primary Generate Button === */
    .generate-btn, .primary-btn, button.primary, .gr-button-primary {
        background: linear-gradient(135deg, var(--color-primary) 0%, var(--color-secondary) 100%) !important;
        background-size: 250% 250% !important;
        border: 2px solid transparent !important;
        border-radius: var(--radius-lg) !important;
        color: white !important;
        font-weight: 700 !important;
        padding: 18px 36px !important;
        text-transform: uppercase !important;
        letter-spacing: 1.5px !important;
        transition: all 0.4s ease !important;
        box-shadow: 0 10px 30px rgba(34, 211, 238, 0.2), 0 10px 30px rgba(236, 72, 153, 0.2) !important;
        position: relative;
        overflow: hidden;
        z-index: 1;
    }
    .generate-btn::before, .primary-btn::before {
        content: '' !important;
        position: absolute !important;
        top: 0; left: -100%;
        width: 100%; height: 100%;
        background: linear-gradient(120deg, transparent, rgba(255,255,255,0.4), transparent);
        transition: left 0.6s ease;
        z-index: -1;
    }
    .generate-btn:hover::before, .primary-btn:hover::before { left: 100%; }
    .generate-btn:hover, .primary-btn:hover {
        transform: translateY(-5px) scale(1.03) !important;
        box-shadow: 0 15px 40px rgba(34, 211, 238, 0.4), 0 15px 40px rgba(236, 72, 153, 0.4) !important;
        background-position: 100% 50% !important;
    }

    /* === Secondary & Tertiary Buttons (e.g., "Load Example") === */
    button:not(.primary):not(.selected) {
        background: rgba(148, 163, 184, 0.1) !important;
        border: 1px solid rgba(148, 163, 184, 0.2) !important;
        color: var(--color-text-medium) !important;
        border-radius: var(--radius-md) !important;
        padding: 10px 20px !important;
        font-weight: 500 !important;
        transition: all 0.3s ease !important;
    }
    button:not(.primary):not(.selected):hover {
        background: var(--color-accent) !important;
        border-color: var(--color-accent) !important;
        color: white !important;
        transform: translateY(-2px);
        box-shadow: 0 6px 20px rgba(167, 139, 250, 0.3) !important;
    }

    /* === Input Fields & Textareas === */
    input, textarea, .gr-textbox, .gr-number {
        background: rgba(15, 23, 42, 0.8) !important; /* Midnight Slate dark */
        border: 1px solid rgba(148, 163, 184, 0.2) !important;
        border-radius: var(--radius-md) !important;
        color: var(--color-text-light) !important;
        padding: 12px !important;
        transition: all 0.3s ease !important;
    }
    input:focus, textarea:focus, .gr-textbox:focus-within, .gr-number:focus-within {
        border-color: var(--color-primary) !important;
        box-shadow: 0 0 15px rgba(34, 211, 238, 0.2) !important;
        outline: none !important;
    }
    input::placeholder, textarea::placeholder { color: var(--color-text-dark) !important; }

    /* === Sliders === */
    .gr-slider {
        --slider-track-color: rgba(15, 23, 42, 0.9);
        --slider-range-color: linear-gradient(90deg, var(--color-primary) 0%, var(--color-accent) 100%);
        --slider-handle-color: white;
        --slider-handle-shadow: 0 4px 15px rgba(34, 211, 238, 0.4);
    }
    .gradio-container .gr-slider .gr-slider-track { background: var(--slider-track-color) !important; }
    .gradio-container .gr-slider .gr-slider-range { background: var(--slider-range-color) !important; }
    .gradio-container .gr-slider .gr-slider-handle {
        background: var(--slider-handle-color) !important;
        border: 2px solid var(--color-primary) !important;
        box-shadow: var(--slider-handle-shadow) !important;
    }

    /* === File Upload === */
    .gr-file, .gr-upload {
        background: rgba(15, 23, 42, 0.7) !important;
        border: 2px dashed var(--color-text-dark) !important;
        border-radius: var(--radius-lg) !important;
        transition: all 0.3s ease !important;
    }
    .gr-file:hover, .gr-upload:hover {
        border-color: var(--color-primary) !important;
        background: rgba(34, 211, 238, 0.1) !important;
    }
    .gr-file *, .gr-upload * {
        color: var(--color-text-medium) !important;
        background: transparent !important;
    }

    /* === Markdown & Text === */
    .gr-markdown { color: var(--color-text-light) !important; }
    .gr-markdown h1, .gr-markdown h2, .gr-markdown h3 {
        background: linear-gradient(90deg, var(--color-primary) 0%, var(--color-secondary) 100%);
        -webkit-background-clip: text;
        -moz-background-clip: text;
        background-clip: text;
        -webkit-text-fill-color: transparent;
        margin-bottom: 1rem;
    }
    .gr-markdown a {
        color: var(--color-primary) !important;
        text-decoration: none !important;
        transition: all 0.2s ease;
    }
    .gr-markdown a:hover {
        color: var(--color-secondary) !important;
        text-decoration: underline !important;
    }
    label {
        color: var(--color-text-medium) !important;
        font-weight: 600 !important;
        margin-bottom: 8px !important;
        text-transform: uppercase;
        font-size: 0.8rem;
        letter-spacing: 0.5px;
    }
    .gr-info { color: var(--color-text-dark) !important; font-style: italic; }

    /* === Progress Bar === */
    .gr-progress {
        background: rgba(15, 23, 42, 0.8) !important;
        border-radius: var(--radius-sm) !important;
    }
    .gr-progress-bar {
        background: linear-gradient(90deg, var(--color-primary) 0%, var(--color-accent) 100%) !important;
        border-radius: var(--radius-sm) !important;
    }

    /* === Scrollbar === */
    ::-webkit-scrollbar { width: 10px; }
    ::-webkit-scrollbar-track { background: var(--color-background-light); }
    ::-webkit-scrollbar-thumb {
        background: linear-gradient(var(--color-accent), var(--color-primary));
        border-radius: 5px;
    }
    ::-webkit-scrollbar-thumb:hover { background: linear-gradient(var(--color-primary), var(--color-secondary)); }

    /* === Final cleanup & overrides === */
    .gradio-container .prose { color: var(--color-text-light) !important; }
    .gradio-container .gr-button * { color: inherit !important; }
    """

    with gr.Blocks(
        css=css,
        title="✨ Pusa V1.0 - Revolutionary AI Video Generation ✨",
        theme=gr.themes.Default(primary_hue="purple", neutral_hue="gray").set(
            body_background_fill="linear-gradient(135deg, #0f172a 0%, #1e293b 100%)",
            background_fill_primary="#1e293b",
            background_fill_secondary="#0f172a",
            border_color_primary="rgba(148, 163, 184, 0.1)",
        ),
    ) as demo:
        # Header
        gr.HTML("""
🔥 BREAKTHROUGH PERFORMANCE: Surpassing Wan-I2V on Vbench-I2V with only $500 training cost! 🔥
🚀 4 Powerful Modes: I2V • Multi-Frame • V2V • T2V 🚀
Explore real examples showcasing the power and versatility of Pusa V1.0 across different generation modes.
Note: Demo files should be placed in the ./demos/ and ./assets/ directories to display properly.
Pusa V1.0 demonstrates that high-quality video generation doesn't require massive computational resources. Our vectorized timestep adaptation approach opens new possibilities for democratizing video AI research and applications.
Pusa V1.0 leverages vectorized timestep adaptation (VTA) for fine-grained temporal control within a unified video diffusion framework. The model achieves unprecedented efficiency, surpassing Wan-I2V on Vbench-I2V with a training cost of only $500 and just 4K training samples.
✨ Made with ❤️ for the AI Community ✨
Experience the future of video generation with Pusa V1.0 🚀