# PusaV1 / app.py: Gradio Space for Wan2.1-T2V-14B with the Pusa LoRA
import gradio as gr
import os
import torch
import tempfile
import sys
import datetime
from huggingface_hub import snapshot_download
import spaces
# Setup paths
PUSA_PATH = os.path.abspath("./PusaV1")
if PUSA_PATH not in sys.path:
sys.path.insert(0, PUSA_PATH)
# Validate diffsynth presence
DIFFSYNTH_PATH = os.path.join(PUSA_PATH, "diffsynth")
if not os.path.exists(DIFFSYNTH_PATH):
    raise RuntimeError(
        f"'diffsynth' package not found in {PUSA_PATH}. "
        "Ensure PusaV1 is correctly cloned and the folder structure is intact."
    )
# Import core modules from the PusaV1 checkout (diffsynth is importable because PUSA_PATH is on sys.path)
from diffsynth import ModelManager, WanVideoPusaPipeline, save_video
# Constants
MODEL_ZOO_DIR = "./model_zoo"
PUSA_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
WAN_SUBFOLDER = "Wan2.1-T2V-14B"
WAN_MODEL_PATH = os.path.join(PUSA_DIR, WAN_SUBFOLDER)
LORA_PATH = os.path.join(PUSA_DIR, "pusa_v1.pt")
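# Expected layout after the downloads below:
#   ./model_zoo/PusaV1/                  <- RaphaelLiu/PusaV1 (includes pusa_v1.pt)
#   ./model_zoo/PusaV1/Wan2.1-T2V-14B/   <- Wan-AI/Wan2.1-T2V-14B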
# Ensure model and weights are downloaded
def ensure_model_downloaded():
if not os.path.exists(PUSA_DIR):
print("Downloading RaphaelLiu/PusaV1 to ./model_zoo/PusaV1 ...")
snapshot_download(
repo_id="RaphaelLiu/PusaV1",
local_dir=PUSA_DIR,
repo_type="model",
local_dir_use_symlinks=False,
)
print("✅ PusaV1 downloaded.")
if not os.path.exists(WAN_MODEL_PATH):
print("Downloading Wan-AI/Wan2.1-T2V-14B to ./model_zoo/PusaV1/Wan2.1-T2V-14B ...")
snapshot_download(
repo_id="Wan-AI/Wan2.1-T2V-14B",
            local_dir=WAN_MODEL_PATH,
repo_type="model",
local_dir_use_symlinks=False,
)
print("✅ Wan2.1-T2V-14B downloaded.")
# Subclass ModelManager: register WanModelPusa and force it at load time
class PatchedModelManager(ModelManager):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Map the stock WanModel architecture to its Pusa variant so detectors
        # consulting architecture_dict resolve WanModelPusa
        custom_architecture_dict = {
            "WanModel": ("diffsynth.models.wan_model", "WanModelPusa", None),
        }
        self.architecture_dict.update(custom_architecture_dict)

def load_model(self, file_path=None, model_names=None, device=None, torch_dtype=None):
if file_path is None:
file_path = self.file_path_list[0]
print(f"[app.py] Forcing architecture: WanModelPusa for {file_path}")
for detector in self.model_detector:
if detector.match(file_path, {}):
model_names, models = detector.load(
file_path,
state_dict={},
device=device or self.device,
torch_dtype=torch_dtype or self.torch_dtype,
allowed_model_names=model_names,
model_manager=self,
forced_architecture="WanModelPusa"
)
for name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(name)
return models[0] if models else None
print("No suitable model detector matched.")
return None
    # Optional standalone T2V path. NOTE: relies on helpers not defined in this
    # file (`self.load_lora_and_get_pipe`, `self.output_dir`) and is not wired
    # into the Gradio UI below.
    def generate_t2v_video(self, prompt, lora_alpha, num_inference_steps,
                           negative_prompt, progress=gr.Progress()):
        """Generate a video from a text prompt using the Pusa LoRA."""
try:
progress(0.1, desc="Loading models...")
lora_path = "./model_zoo/PusaV1/pusa_v1.pt"
pipe = self.load_lora_and_get_pipe("t2v", lora_path, lora_alpha)
progress(0.3, desc="Generating video...")
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=num_inference_steps,
height=720, width=1280, num_frames=81,
seed=0, tiled=True
)
progress(0.9, desc="Saving video...")
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
video_filename = os.path.join(self.output_dir, f"t2v_output_{timestamp}.mp4")
save_video(video, video_filename, fps=25, quality=5)
progress(1.0, desc="Complete!")
return video_filename, f"Video generated successfully! Saved to {video_filename}"
except Exception as e:
return None, f"Error: {str(e)}"
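
# Hypothetical usage sketch for PatchedModelManager (comments only; mirrors the
# constructor call that load_model assumes via self.file_path_list):
#   manager = PatchedModelManager(
#       file_path_list=[WAN_MODEL_PATH],
#       torch_dtype=torch.float16,
#       device="cuda",
#   )
#   model = manager.load_model()  # defaults to file_path_list[0], forcing WanModelPusa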
@spaces.GPU(duration=200)
def generate_video(prompt: str):
    # Load models via the patched manager so WanModel resolves to WanModelPusa
    model_manager = PatchedModelManager(device="cuda")
    base_dir = WAN_MODEL_PATH  # ./model_zoo/PusaV1/Wan2.1-T2V-14B
model_files = sorted([os.path.join(base_dir, f) for f in os.listdir(base_dir) if f.endswith('.safetensors')])
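    # Group the sharded DiT *.safetensors as one entry alongside the T5 text
    # encoder and VAE checkpoints; diffsynth's load_models detects each
    # checkpoint's architecture and instantiates the matching module.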
model_manager.load_models(
[
model_files,
os.path.join(base_dir, "models_t5_umt5-xxl-enc-bf16.pth"),
os.path.join(base_dir, "Wan2.1_VAE.pth"),
],
torch_dtype=torch.bfloat16,
)
    # Set up the pipeline from the loaded models
pipeline = WanVideoPusaPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
    # The Pusa LoRA is downloaded but not applied to the pipeline here:
    # pipeline.set_lora_adapters(LORA_PATH)
# Generate video
result = pipeline(prompt)
# Save video
tmp_dir = tempfile.mkdtemp()
output_path = os.path.join(tmp_dir, "video.mp4")
    # diffsynth's save_video takes the frame list directly (as in generate_t2v_video above)
    save_video(result, output_path, fps=8)
return output_path
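
# Hypothetical debugging helper (an addition, not part of the upstream app):
# peek inside the Pusa LoRA checkpoint with plain torch.load to confirm it is
# a flat {name: tensor} state dict before wiring it into the pipeline.
def inspect_lora_checkpoint(path: str = LORA_PATH):
    state = torch.load(path, map_location="cpu")
    if isinstance(state, dict):
        print(f"{len(state)} entries; sample keys: {list(state)[:5]}")
    return state
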
# Gradio UI
with gr.Blocks() as demo:
gr.Markdown("## 🎥 Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")
prompt_input = gr.Textbox(
lines=4,
label="Prompt",
placeholder="Describe your video (e.g. A coral reef full of colorful fish...)"
)
generate_btn = gr.Button("Generate Video")
video_output = gr.Video(label="Output")
generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)
if __name__ == "__main__":
ensure_model_downloaded()
demo.launch(share=True, show_error=True)