# --- START OF FILE media.py (FINAL WITH LIVE PROGRESS) ---

# --- LIBRARIES ---
import torch
import gradio as gr
import random
import time
from diffusers import AutoPipelineForText2Image, TextToVideoSDPipeline, EulerAncestralDiscreteScheduler
import gc
import os
import imageio
import numpy as np
import threading
from queue import Queue, Empty as QueueEmpty
from PIL import Image

# --- SECURE AUTHENTICATION FOR HUGGING FACE SPACES ---
from huggingface_hub import login

# This code will attempt to read the HF_TOKEN from the Space's secrets.
# On your local machine, this will do nothing unless you set it up, which isn't necessary.
# On the Hugging Face server, it will find the secret you just saved.
HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    print("✅ Found HF_TOKEN secret. Logging in...")
    try:
        login(token=HF_TOKEN)
        print("✅ Hugging Face Authentication successful.")
    except Exception as e:
        print(f"❌ Hugging Face login failed: {e}")
else:
    print("⚠️ No HF_TOKEN secret found. Gated models may not be available on the deployed app.")

# --- CONFIGURATION & STATE ---
# Device and dtype used by every pipeline below: fp16 on GPU, fp32 on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

available_models = {
    "Fast Image (SDXL Turbo)": "stabilityai/sdxl-turbo",
    "Quality Image (SDXL)": "stabilityai/stable-diffusion-xl-base-1.0",
    "Photorealism (Juggernaut)": "RunDiffusion/Juggernaut-XL-v9",
    "Video (Damo-Vilab)": "damo-vilab/text-to-video-ms-1.7b"
}
model_state = { "current_pipe": None, "loaded_model_name": None }
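# Only one pipeline is kept in memory at a time: when a different model is selected,
# the cached pipeline is dropped and its VRAM reclaimed before the new weights load.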

# --- THE FINAL GENERATION FUNCTION WITH LIVE PROGRESS ---
def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
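    # Gradio treats this function as a generator: each `yield` of a dict keyed by
    # output components streams a partial UI update (status text, then the result).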
    # --- Model Loading (Unchanged) ---
    if model_state.get("loaded_model_name") != model_key:
        yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
        if model_state.get("current_pipe"):
            del model_state["current_pipe"]
            gc.collect()
            torch.cuda.empty_cache()
        model_id = available_models[model_key]
        if "Video" in model_key:
            pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
        else:
            pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(device)
        if device == "cuda":
            if "Video" not in model_key:
                pipe.enable_model_cpu_offload()
            pipe.enable_vae_slicing()
        model_state["current_pipe"] = pipe
        model_state["loaded_model_name"] = model_key
        print(f"✅ Model loaded on {device.upper()}.")
    pipe = model_state["current_pipe"]
    generator = torch.Generator(device).manual_seed(int(seed))
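    # The seeded generator makes a run reproducible: re-entering the seed reported in
    # the status box, with the same settings, recreates the same output.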

    # --- Generation Logic ---
    if "Video" in model_key:
        # For video, we'll keep the simple status updates for now
        yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
        video_frames = pipe(prompt=prompt, num_inference_steps=int(steps), height=320, width=576, num_frames=int(num_frames), generator=generator).frames
        # Collapse the batched (batch, frames, height, width, channels) output into a
        # list of uint8 frames that imageio can write to an mp4
        video_frames_5d = np.array(video_frames)
        video_frames_4d = np.squeeze(video_frames_5d)
        video_uint8 = (video_frames_4d * 255).astype(np.uint8)
        list_of_frames = [frame for frame in video_uint8]
        video_path = f"video_{seed}.mp4"
        imageio.mimsave(video_path, list_of_frames, fps=12)
        yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
    else:  # Image Generation with Live Progress
        progress_queue = Queue()

        def run_pipe():
            # This function runs in a separate thread
            start_time = time.time()
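
            # `callback_on_step_end` (available in recent diffusers releases) calls this hook
            # after every denoising step as callback(pipeline, step_index, timestep, callback_kwargs)
            # and expects the (possibly modified) callback_kwargs dict back.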
            def progress_callback(pipe, step, timestep, callback_kwargs):
                # This is called by the pipeline at each step
                elapsed_time = time.time() - start_time
                # Avoid division by zero on the first step
                if elapsed_time > 0:
                    its_per_sec = (step + 1) / elapsed_time
                    progress_queue.put((step + 1, its_per_sec))
                return callback_kwargs

            try:
                # The final image is still generated using the pipeline's high-quality VAE
                final_image = pipe(
                    prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=int(steps),
                    guidance_scale=float(cfg_scale), width=int(width), height=int(height),
                    generator=generator,
                    callback_on_step_end=progress_callback
                ).images[0]
                progress_queue.put(final_image)  # Put the final result on the queue
            except Exception as e:
                print(f"An error occurred in the generation thread: {e}")
                progress_queue.put(None)  # Signal an error

        # Start the generation in the background
        thread = threading.Thread(target=run_pipe)
        thread.start()

        # In the main thread, listen for updates from the queue and yield to Gradio
        total_steps = int(steps)
        yield {status_textbox: "Generating..."}  # Initial status
        while True:
            try:
                update = progress_queue.get(timeout=1.0)  # Wait for an update
                if isinstance(update, Image.Image):  # It's the final image
                    yield {output_image: update, status_textbox: f"Generation complete! Seed: {seed}"}
                    break
                elif isinstance(update, tuple):  # It's a progress update (step, speed)
                    current_step, its_per_sec = update
                    progress_percent = (current_step / total_steps) * 100
                    steps_remaining = total_steps - current_step
                    eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
                    eta_minutes, eta_seconds_rem = divmod(int(eta_seconds), 60)
                    status_text = (
                        f"Generating... {progress_percent:.0f}% ({current_step}/{total_steps}) | "
                        f"{its_per_sec:.2f}it/s | "
                        f"ETA: {eta_minutes:02d}:{eta_seconds_rem:02d}"
                    )
                    yield {status_textbox: status_text}
                elif update is None:  # An error occurred
                    yield {status_textbox: "Error during generation. Check console."}
                    break
            except QueueEmpty:
                if not thread.is_alive():
                    print("⚠️ Generation thread finished unexpectedly.")
                    yield {status_textbox: "Generation failed. Check console for details."}
                    break
        thread.join()

# --- GRADIO UI ---
with gr.Blocks(theme='gradio/soft') as demo:
    # (UI layout is the same, just point to the new function)
    gr.Markdown("# The Generative Media Suite")
    gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
    seed_state = gr.State(-1)
    with gr.Row():
        with gr.Column(scale=2):
            model_selector = gr.Radio(label="Select Model", choices=list(available_models.keys()), value=list(available_models.keys())[0])
            prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="An astronaut riding a horse on Mars, cinematic...")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, deformed, watermark, text")
            with gr.Accordion("Settings", open=True):
                steps_slider = gr.Slider(1, 100, 30, step=1, label="Inference Steps")
                cfg_slider = gr.Slider(0.0, 15.0, 7.5, step=0.5, label="Guidance Scale (CFG)")
                with gr.Row():
                    width_slider = gr.Slider(256, 1024, 768, step=64, label="Width")
                    height_slider = gr.Slider(256, 1024, 768, step=64, label="Height")
                num_frames_slider = gr.Slider(12, 48, 24, step=4, label="Video Frames", visible=False)
                seed_input = gr.Number(-1, label="Seed (-1 for random)", precision=0)
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=3):
            output_image = gr.Image(label="Image Result", interactive=False, height="60vh", visible=True)
            output_video = gr.Video(label="Video Result", interactive=False, height="60vh", visible=False)
            status_textbox = gr.Textbox(label="Status", interactive=False)
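
    # Swap the visible controls when the model changes: Turbo locks steps/CFG to its
    # fixed values, and the video model hides the image size sliders and image output
    # in favour of the frame-count slider and video output.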
    def update_ui_on_model_change(model_key):
        is_video = "Video" in model_key
        is_turbo = "Turbo" in model_key
        return {
            steps_slider: gr.update(interactive=not is_turbo, value=1 if is_turbo else 30),
            cfg_slider: gr.update(interactive=not is_turbo, value=0.0 if is_turbo else 7.5),
            width_slider: gr.update(visible=not is_video),
            height_slider: gr.update(visible=not is_video),
            num_frames_slider: gr.update(visible=is_video),
            output_image: gr.update(visible=not is_video),
            output_video: gr.update(visible=is_video)
        }

    model_selector.change(update_ui_on_model_change, model_selector, [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video])
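
    # Two-step click chain: an unqueued lambda first resolves the seed (-1 becomes a
    # random value stored in seed_state, the input box itself is left untouched), then
    # the streaming generator function runs with that resolved seed.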
    click_event = generate_button.click(
        fn=lambda s: (s if s != -1 else random.randint(0, 2**32 - 1)),
        inputs=seed_input,
        outputs=seed_state,
        queue=False
    ).then(
        fn=generate_media_live_progress,  # Use the new function with progress
        inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
        outputs=[output_image, output_video, status_textbox]
    )
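
# Launch the app: locally this starts a Gradio server (port 7860 by default);
# on a Hugging Face Space the platform runs this file and exposes the same server.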
demo.launch()