File size: 6,642 Bytes
ff052dc
d54a2c6
69a8203
65426a8
886d416
4046baa
65426a8
17d4813
886d416
5a0901c
 
 
 
02f7f0d
 
 
 
 
 
 
5a0901c
886d416
a9a00b3
de49789
 
 
a9a00b3
 
 
de49789
 
 
681163e
de49789
 
 
 
 
886d416
681163e
d54a2c6
bf9f272
 
 
 
d54a2c6
bf9f272
 
 
 
985cc0d
bf9f272
d54a2c6
bf9f272
 
 
 
 
 
 
 
 
 
d54a2c6
3098f22
65426a8
bf9f272
3098f22
d54a2c6
65426a8
 
bf9f272
886d416
de49789
 
886d416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d54a2c6
b9a5c9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886d416
d54a2c6
1f9ee55
4046baa
886d416
b9a5c9f
 
 
 
 
1f9ee55
b9a5c9f
 
 
 
1f9ee55
 
b9a5c9f
 
 
 
 
 
 
 
 
 
 
 
 
de49789
681163e
 
4fc218f
 
 
 
 
de49789
b9a5c9f
4046baa
d54a2c6
afa05d9
 
b9a5c9f
4046baa
d54a2c6
 
65426a8
d54a2c6
65426a8
d54a2c6
 
17d4813
d54a2c6
65426a8
d54a2c6
17d4813
d54a2c6
 
 
 
 
 
 
 
 
 
17d4813
d54a2c6
17d4813
1f9ee55
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import gradio as gr
import os
import torch
import tempfile
import sys
from huggingface_hub import snapshot_download
import spaces

# Make the local PusaV1 checkout importable and fail fast if it is incomplete.
PUSA_PATH = os.path.abspath("./PusaV1")
DIFFSYNTH_PATH = os.path.join(PUSA_PATH, "diffsynth")

# Prepend (only once) so names inside the checkout take precedence on import.
if PUSA_PATH not in sys.path:
    sys.path.insert(0, PUSA_PATH)

# The checkout must contain its bundled `diffsynth` package directory.
if not os.path.exists(DIFFSYNTH_PATH):
    missing_msg = (
        f"'diffsynth' package not found in {PUSA_PATH}. "
        f"Ensure PusaV1 is correctly cloned and folder structure is intact."
    )
    raise RuntimeError(missing_msg)

# Import core modules from PusaV1
from PusaV1.diffsynth import ModelManager, WanVideoPusaPipeline, save_video





# NOTE(review): a second `class PatchedModelManager(ModelManager)` is defined
# further down in this file and rebinds this name before anything can use it,
# so this class (and its architecture override) is dead code. Nothing in this
# file instantiates PatchedModelManager; `generate_video` uses plain ModelManager.
class PatchedModelManager(ModelManager):
    """ModelManager that re-points the "WanModel" architecture entry at WanModelPusa."""

    def __init__(self, *args, **kwargs):
        """Run the base initializer, then patch ``self.architecture_dict``.

        All arguments are forwarded unchanged to ``ModelManager.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Patch architecture dict here
        # The tuple is presumably (module path, class name, <extra>) as used by
        # diffsynth's architecture lookup — TODO confirm against diffsynth.
        custom_architecture_dict = {
            "WanModel": ("diffsynth.models.wan_model", "WanModelPusa", None),
        }
        # Adds or overwrites the "WanModel" key. Assumes the base class defines
        # `architecture_dict`; otherwise this raises AttributeError — TODO confirm.
        self.architecture_dict.update(custom_architecture_dict)



# NOTE: the two imports below repeat top-of-file imports and are redundant.
import os
from huggingface_hub import snapshot_download

# On-disk layout of the downloaded weights:
#   ./model_zoo/PusaV1/                  <- RaphaelLiu/PusaV1 snapshot
#   ./model_zoo/PusaV1/pusa_v1.pt        <- Pusa LoRA checkpoint (LORA_PATH)
#   ./model_zoo/PusaV1/Wan2.1-T2V-14B/   <- Wan-AI/Wan2.1-T2V-14B snapshot
MODEL_ZOO_DIR = "./model_zoo"
WAN_SUBFOLDER = "Wan2.1-T2V-14B"
PUSA_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
LORA_PATH = os.path.join(PUSA_DIR, "pusa_v1.pt")
WAN_MODEL_PATH = os.path.join(PUSA_DIR, WAN_SUBFOLDER)

def _download_if_missing(repo_id, local_dir):
    """Download Hugging Face model repo ``repo_id`` into ``local_dir`` unless it exists.

    Only the directory's existence is checked, so a partial directory left by an
    interrupted download is not retried; delete it to force a fresh download.
    """
    if os.path.exists(local_dir):
        return
    print(f"Downloading {repo_id} to {local_dir} ...")
    # `local_dir_use_symlinks` is deprecated in huggingface_hub (ignored since
    # 0.23, files are written directly into `local_dir`), so it is not passed.
    snapshot_download(repo_id=repo_id, local_dir=local_dir, repo_type="model")
    print(f"✅ {repo_id.split('/')[-1]} downloaded.")


# Ensure model and weights are downloaded
def ensure_model_downloaded():
    """Make sure the PusaV1 repo and the Wan2.1-T2V-14B base model are on disk.

    Downloads RaphaelLiu/PusaV1 into PUSA_DIR and Wan-AI/Wan2.1-T2V-14B into
    WAN_MODEL_PATH, skipping each one whose target directory already exists.
    """
    # Order matters: WAN_MODEL_PATH lives inside PUSA_DIR, so downloading Wan
    # first would create PUSA_DIR and make the PusaV1 download get skipped.
    _download_if_missing("RaphaelLiu/PusaV1", PUSA_DIR)
    _download_if_missing("Wan-AI/Wan2.1-T2V-14B", WAN_MODEL_PATH)
# Subclass ModelManager to force WanModelPusa


class PatchedModelManager(ModelManager):
    """ModelManager whose ``load_model`` asks detectors for the WanModelPusa architecture.

    NOTE(review): this definition rebinds the name ``PatchedModelManager`` that
    is declared earlier in this file, so the earlier class (with its
    ``__init__`` override) is never used. This class itself is not instantiated
    anywhere in this file; ``generate_video`` uses plain ``ModelManager``.
    """

    def load_model(self, file_path=None, model_names=None, device=None, torch_dtype=None):
        """Load ``file_path`` with the first detector that matches it.

        Args:
            file_path: File to load; defaults to ``self.file_path_list[0]``.
            model_names: Forwarded to the detector as ``allowed_model_names``.
            device: Target device; falls back to ``self.device`` when falsy.
            torch_dtype: Target dtype; falls back to ``self.torch_dtype`` when falsy.

        Returns:
            The first model the matching detector produced, or ``None`` if it
            produced none or no detector matched.
        """
        path = self.file_path_list[0] if file_path is None else file_path
        print(f"[app.py] Forcing architecture: WanModelPusa for {path}")

        # First detector that accepts the file (short-circuits like a for/return loop).
        matching = next(
            (detector for detector in self.model_detector if detector.match(path, {})),
            None,
        )
        if matching is None:
            print("No suitable model detector matched.")
            return None

        loaded_names, loaded_models = matching.load(
            path,
            state_dict={},
            device=device or self.device,
            torch_dtype=torch_dtype or self.torch_dtype,
            allowed_model_names=model_names,
            model_manager=self,
            forced_architecture="WanModelPusa",
        )
        # Register every loaded model with the manager's parallel bookkeeping lists.
        for name, model in zip(loaded_names, loaded_models):
            self.model.append(model)
            self.model_path.append(path)
            self.model_name.append(name)
        return loaded_models[0] if loaded_models else None

# Video generation logic



def generate_t2v_video(self, prompt, lora_alpha, num_inference_steps,
                       negative_prompt, progress=gr.Progress()):
    """Generate a 1280x720 (width x height), 81-frame video from a text prompt.

    NOTE(review): this is a module-level function that takes ``self``; it looks
    lifted from a class and is not wired into the Gradio UI in this file.
    ``self`` must provide ``load_lora_and_get_pipe(mode, lora_path, lora_alpha)``
    and ``output_dir``.

    Args:
        self: Object providing ``load_lora_and_get_pipe`` and ``output_dir``.
        prompt: Text prompt.
        lora_alpha: LoRA strength forwarded to ``load_lora_and_get_pipe``.
        num_inference_steps: Number of inference steps passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, status_message)`` on success, or ``(None, "Error: ...")``
        if any step raises.
    """
    # `datetime` was used below but never imported, which made every call fail
    # with NameError right before saving.
    import datetime
    import traceback

    try:
        progress(0.1, desc="Loading models...")
        pipe = self.load_lora_and_get_pipe("t2v", LORA_PATH, lora_alpha)

        progress(0.3, desc="Generating video...")
        video = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            height=720, width=1280, num_frames=81,
            seed=0, tiled=True,
        )

        progress(0.9, desc="Saving video...")
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        video_filename = os.path.join(self.output_dir, f"t2v_output_{timestamp}.mp4")
        save_video(video, video_filename, fps=25, quality=5)

        progress(1.0, desc="Complete!")
        return video_filename, f"Video generated successfully! Saved to {video_filename}"

    except Exception as e:
        # UI boundary: report the error to the caller, but keep the traceback
        # in the server log instead of silently discarding it.
        traceback.print_exc()
        return None, f"Error: {str(e)}"



            
@spaces.GPU(duration=200)
def generate_video(prompt: str):
    """Generate a video for ``prompt`` with Wan2.1-T2V-14B and return its file path.

    On every call this loads the ``.safetensors`` shards, the T5 encoder and the
    VAE from ``WAN_MODEL_PATH``, builds a ``WanVideoPusaPipeline`` on CUDA in
    bfloat16, and writes the result to ``video.mp4`` in a fresh temp directory.

    NOTE(review): the Pusa LoRA at ``LORA_PATH`` is not applied here, so output
    comes from the base Wan weights even though the UI title mentions the LoRA.
    Confirm the intended LoRA-loading API before enabling it.

    Args:
        prompt: Text description of the video.

    Returns:
        Path to the written MP4 file.

    Raises:
        gr.Error: If the prompt is empty or the Wan weights are not on disk.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    base_dir = WAN_MODEL_PATH
    if not os.path.isdir(base_dir):
        raise gr.Error(f"Model directory not found: {base_dir}")

    # Sorted so the shard order is deterministic; the shards are passed together
    # as one nested-list entry.
    model_files = sorted(
        os.path.join(base_dir, f) for f in os.listdir(base_dir) if f.endswith(".safetensors")
    )
    if not model_files:
        raise gr.Error(f"No .safetensors model shards found in {base_dir}")

    model_manager = ModelManager(device="cuda")
    model_manager.load_models(
        [
            model_files,
            os.path.join(base_dir, "models_t5_umt5-xxl-enc-bf16.pth"),
            os.path.join(base_dir, "Wan2.1_VAE.pth"),
        ],
        torch_dtype=torch.bfloat16,
    )

    pipeline = WanVideoPusaPipeline.from_model_manager(
        model_manager, torch_dtype=torch.bfloat16, device="cuda"
    )

    # The pipeline's return value is the frames themselves (generate_t2v_video
    # hands it straight to save_video); the old `result.frames` access would
    # fail on it.
    frames = pipeline(prompt=prompt)

    # Gradio serves the file from this path after we return, so the temp
    # directory is intentionally not cleaned up here.
    output_path = os.path.join(tempfile.mkdtemp(), "video.mp4")
    save_video(frames, output_path, fps=8)

    return output_path

# Gradio UI: a prompt box, a button that calls `generate_video`, and a video player.
# NOTE(review): the title advertises the Pusa LoRA, but `generate_video` does not
# load LORA_PATH.
with gr.Blocks() as demo:
    # Title emoji repaired: it was mojibake (UTF-8 🎥 decoded with the wrong code page).
    gr.Markdown("## 🎥 Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")

    prompt_input = gr.Textbox(
        lines=4,
        label="Prompt",
        placeholder="Describe your video (e.g. A coral reef full of colorful fish...)"
    )

    generate_btn = gr.Button("Generate Video")
    video_output = gr.Video(label="Output")

    # generate_video returns a file path, which gr.Video displays.
    generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)



if __name__ == "__main__":
    # Fetch any missing weights before serving; generate_video reads them from disk.
    ensure_model_downloaded()
    demo.launch(show_error=True, share=True)