# Spaces: Running on Zero
import datetime
import os
import sys
import tempfile

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
# Setup paths: put the local PusaV1 checkout at the front of sys.path so the
# packages inside it can also be resolved by their top-level names.
PUSA_PATH = os.path.abspath("./PusaV1")
if PUSA_PATH not in sys.path:
    sys.path.insert(0, PUSA_PATH)

# Validate diffsynth presence: fail early with an actionable message instead
# of a bare ImportError from the import just below. A package has to be a
# directory, so a stray file named "diffsynth" is rejected as well.
DIFFSYNTH_PATH = os.path.join(PUSA_PATH, "diffsynth")
if not os.path.isdir(DIFFSYNTH_PATH):
    raise RuntimeError(
        f"'diffsynth' package not found in {PUSA_PATH}. "
        "Ensure PusaV1 is correctly cloned and folder structure is intact."
    )

# Import core modules from PusaV1 (must come after the sys.path setup above).
from PusaV1.diffsynth import ModelManager, WanVideoPusaPipeline, save_video
class PatchedModelManager(ModelManager):
    """ModelManager whose "WanModel" architecture entry points at WanModelPusa.

    NOTE(review): a second class with this same name is defined further down
    in this file. That later definition rebinds ``PatchedModelManager``, so
    this one is shadowed once the module has finished importing.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to ``ModelManager`` and patch the mapping."""
        super().__init__(*args, **kwargs)
        # Patch architecture dict here.
        # NOTE(review): assumes ModelManager.__init__ has created
        # ``self.architecture_dict`` -- confirm against diffsynth.
        custom_architecture_dict = {
            "WanModel": ("diffsynth.models.wan_model", "WanModelPusa", None),
        }
        self.architecture_dict.update(custom_architecture_dict)
# Constants | |
import os | |
from huggingface_hub import snapshot_download | |
# Constants: on-disk layout of the downloaded checkpoints.
MODEL_ZOO_DIR = "./model_zoo"
# Target directory for the RaphaelLiu/PusaV1 snapshot.
PUSA_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
# The Wan2.1 base model is stored in a sub-folder of the PusaV1 directory.
WAN_SUBFOLDER = "Wan2.1-T2V-14B"
WAN_MODEL_PATH = os.path.join(PUSA_DIR, WAN_SUBFOLDER)
# Pusa LoRA weights file inside the PusaV1 directory.
LORA_PATH = os.path.join(PUSA_DIR, "pusa_v1.pt")
# Ensure model and weights are downloaded
def ensure_model_downloaded():
    """Download the PusaV1 and Wan2.1-T2V-14B snapshots that are missing.

    Each snapshot is fetched only when its target directory does not exist
    yet; an existing directory is taken as "already downloaded" and is not
    re-validated. Progress is reported with ``print``. Returns ``None``.
    """
    # Order matters: WAN_MODEL_PATH lives inside PUSA_DIR, so fetching the Wan
    # snapshot first would create PUSA_DIR and make the PusaV1 step skip.
    downloads = (
        (
            PUSA_DIR,
            "RaphaelLiu/PusaV1",
            "Downloading RaphaelLiu/PusaV1 to ./model_zoo/PusaV1 ...",
            "✅ PusaV1 downloaded.",
        ),
        (
            WAN_MODEL_PATH,
            "Wan-AI/Wan2.1-T2V-14B",
            "Downloading Wan-AI/Wan2.1-T2V-14B to ./model_zoo/PusaV1/Wan2.1-T2V-14B ...",
            "✅ Wan2.1-T2V-14B downloaded.",
        ),
    )
    for target_dir, repo_id, start_message, done_message in downloads:
        if os.path.exists(target_dir):
            continue
        print(start_message)
        # NOTE(review): recent huggingface_hub releases deprecate
        # ``local_dir_use_symlinks`` -- confirm against the pinned version
        # before removing it.
        snapshot_download(
            repo_id=repo_id,
            local_dir=target_dir,
            repo_type="model",
            local_dir_use_symlinks=False,
        )
        print(done_message)
# Subclass ModelManager to force WanModelPusa
class PatchedModelManager(ModelManager):
    """ModelManager variant that asks detectors to load as ``WanModelPusa``.

    NOTE(review): ``generate_t2v_video`` relies on ``self.load_lora_and_get_pipe``
    and ``self.output_dir``, neither of which is defined in this class --
    confirm that the base class (or a mixin) provides them.
    """

    def load_model(self, file_path=None, model_names=None, device=None, torch_dtype=None):
        """Load ``file_path`` through the first detector that matches it.

        Args:
            file_path: Checkpoint path; defaults to ``self.file_path_list[0]``.
            model_names: Passed to the detector as ``allowed_model_names``.
            device: Target device; falls back to ``self.device`` when falsy.
            torch_dtype: Target dtype; falls back to ``self.torch_dtype``
                when falsy.

        Returns:
            The first model returned by the matching detector, or ``None``
            when no detector matched or the detector returned no models.
        """
        if file_path is None:
            file_path = self.file_path_list[0]
        print(f"[app.py] Forcing architecture: WanModelPusa for {file_path}")

        for detector in self.model_detector:
            # Detectors are probed with an empty state dict.
            if not detector.match(file_path, {}):
                continue
            # NOTE(review): assumes detector.load accepts a
            # ``forced_architecture`` keyword -- confirm against diffsynth.
            model_names, models = detector.load(
                file_path,
                state_dict={},
                device=device or self.device,
                torch_dtype=torch_dtype or self.torch_dtype,
                allowed_model_names=model_names,
                model_manager=self,
                forced_architecture="WanModelPusa",
            )
            # Register every loaded model in the manager's parallel lists.
            for name, model in zip(model_names, models):
                self.model.append(model)
                self.model_path.append(file_path)
                self.model_name.append(name)
            return models[0] if models else None

        print("No suitable model detector matched.")
        return None

    # Video generation logic
    def generate_t2v_video(self, prompt, lora_alpha, num_inference_steps,
                           negative_prompt, progress=gr.Progress()):
        """Generate video from text prompt.

        Runs the "t2v" pipe at 720x1280, 81 frames, seed 0, tiled, and saves
        the result as a timestamped MP4 under ``self.output_dir``.

        Returns:
            ``(video_path, status_message)`` on success, or
            ``(None, "Error: ...")`` when any step raises.
        """
        # Top-level UI boundary: every failure is turned into a status string
        # for the caller instead of propagating.
        try:
            progress(0.1, desc="Loading models...")
            lora_path = "./model_zoo/PusaV1/pusa_v1.pt"
            pipe = self.load_lora_and_get_pipe("t2v", lora_path, lora_alpha)

            progress(0.3, desc="Generating video...")
            video = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                height=720, width=1280, num_frames=81,
                seed=0, tiled=True,
            )

            progress(0.9, desc="Saving video...")
            # ``datetime`` comes from the file-level imports.
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            video_filename = os.path.join(self.output_dir, f"t2v_output_{timestamp}.mp4")
            save_video(video, video_filename, fps=25, quality=5)

            progress(1.0, desc="Complete!")
            return video_filename, f"Video generated successfully! Saved to {video_filename}"
        except Exception as e:
            return None, f"Error: {str(e)}"
# NOTE(review): on a ZeroGPU Space, CUDA is only attached while a function
# decorated with ``spaces.GPU`` is running. The default time budget may be too
# short for a 14B model -- confirm and pass ``duration=`` if needed.
@spaces.GPU
def generate_video(prompt: str):
    """Generate a video for ``prompt`` and return the path of the saved MP4.

    Loads the Wan2.1-T2V-14B weights from ``WAN_MODEL_PATH`` with a plain
    ``ModelManager`` (not ``PatchedModelManager``), builds a
    ``WanVideoPusaPipeline`` in bfloat16 on CUDA, runs it on the prompt and
    writes the frames to ``video.mp4`` in a fresh temporary directory.

    Args:
        prompt: Text description of the video to generate.

    Returns:
        Path of the written ``video.mp4``. The temporary directory is left
        on disk so the file is still there when the caller reads it.

    Raises:
        FileNotFoundError: If the model directory is missing or contains no
            ``.safetensors`` files.
    """
    # The models are reloaded on every call; nothing is cached between calls.
    model_manager = ModelManager(device="cuda")

    base_dir = WAN_MODEL_PATH
    # Sorted so that the shard files are always passed in the same order.
    model_files = sorted(
        os.path.join(base_dir, name)
        for name in os.listdir(base_dir)
        if name.endswith(".safetensors")
    )
    if not model_files:
        raise FileNotFoundError(f"No .safetensors files found in {base_dir}")

    model_manager.load_models(
        [
            model_files,  # all shards are handed over as one nested list
            os.path.join(base_dir, "models_t5_umt5-xxl-enc-bf16.pth"),
            os.path.join(base_dir, "Wan2.1_VAE.pth"),
        ],
        torch_dtype=torch.bfloat16,
    )

    # Set up pipeline
    # NOTE(review): the Pusa LoRA (LORA_PATH) is never attached here, although
    # the UI title advertises it -- confirm whether that is intended.
    pipeline = WanVideoPusaPipeline.from_model_manager(
        model_manager, torch_dtype=torch.bfloat16, device="cuda"
    )

    # Generate video
    result = pipeline(prompt)
    # The other generation path in this file passes the pipeline output
    # straight to save_video, while this function used ``result.frames``;
    # accept either shape.
    frames = getattr(result, "frames", result)

    # Save video
    tmp_dir = tempfile.mkdtemp()
    output_path = os.path.join(tmp_dir, "video.mp4")
    save_video(frames, output_path, fps=8)
    return output_path
# Gradio UI: one prompt box, one button, one video player.
with gr.Blocks() as demo:
    gr.Markdown("## 🎥 Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")
    prompt_input = gr.Textbox(
        lines=4,
        label="Prompt",
        placeholder="Describe your video (e.g. A coral reef full of colorful fish...)",
    )
    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Output")
    # Clicking the button sends the prompt text to generate_video and shows
    # the file path it returns in the video component.
    generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)

if __name__ == "__main__":
    # Fetch any missing weights before the UI starts accepting requests.
    ensure_model_downloaded()
    demo.launch(share=True, show_error=True)