# Source: Generative_Suite / media.py
# Last commit: "Update media.py" by cheeseman182 (de3c817, verified)
# --- START OF FILE media.py (FINAL WITH LIVE PROGRESS) ---
# --- LIBRARIES ---
import torch
import gradio as gr
import random
import time
from diffusers import AutoPipelineForText2Image, TextToVideoSDPipeline, EulerAncestralDiscreteScheduler
import gc
import os
import imageio
import numpy as np
import threading
from queue import Queue, Empty as QueueEmpty
from PIL import Image
# --- SECURE AUTHENTICATION FOR HUGGING FACE SPACES ---
import os
from huggingface_hub import login
# This code will attempt to read the HF_TOKEN from the Space's secrets.
# On your local machine, this will do nothing unless you set it up, which isn't necessary.
# On the Hugging Face server, it will find the secret you just saved.
# Optional authentication: the token is taken from the environment and, when
# present, used to log in to the Hugging Face Hub. A failed login is reported
# on the console but never aborts start-up.
HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
    print("⚠️ No HF_TOKEN secret found. Gated models may not be available on the deployed app.")
else:
    print("✅ Found HF_TOKEN secret. Logging in...")
    try:
        login(token=HF_TOKEN)
        print("✅ Hugging Face Authentication successful.")
    except Exception as login_error:
        print(f"❌ Hugging Face login failed: {login_error}")
# --- CONFIGURATION & STATE ---
# Maps the label shown in the model selector to the Hugging Face repository
# id loaded for it. Labels are inspected by substring elsewhere in this file
# ("Video", "Turbo"), and the first entry is the selector's default.
available_models = {
    "Fast Image (SDXL Turbo)": "stabilityai/sdxl-turbo",
    "Quality Image (SDXL)": "stabilityai/stable-diffusion-xl-base-1.0",
    "Photorealism (Juggernaut)": "RunDiffusion/Juggernaut-XL-v9",
    "Video (Damo-Vilab)": "damo-vilab/text-to-video-ms-1.7b",
}

# Cache for the single pipeline kept in memory: the pipeline object itself
# and the selector label it was loaded for (both None until the first load).
model_state = dict(current_pipe=None, loaded_model_name=None)
# --- THE FINAL GENERATION FUNCTION WITH LIVE PROGRESS ---
def _resolve_device_and_dtype():
    """Return the ``(device, torch_dtype)`` pair used to load and run pipelines.

    CUDA with half precision is used when a GPU is available; otherwise the
    pipelines run on the CPU in float32.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    return device, torch_dtype


def _load_pipeline(model_key, device, torch_dtype):
    """Build the diffusers pipeline for *model_key* and place it on *device*."""
    model_id = available_models[model_key]
    is_video = "Video" in model_key

    if is_video:
        pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    else:
        pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
        # NOTE(review): the scheduler swap is assumed to apply to image
        # pipelines only - confirm against the intended behaviour.
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    if device == "cuda":
        if is_video:
            pipe.to(device)
        else:
            # CPU offload manages device placement itself; moving the whole
            # pipeline to the GPU first would need exactly the VRAM that
            # offloading is meant to save.
            pipe.enable_model_cpu_offload()
        pipe.enable_vae_slicing()
    else:
        pipe.to(device)
    return pipe


def _format_progress_status(current_step, total_steps, its_per_sec):
    """Render the status-bar text for one progress update."""
    progress_percent = (current_step / total_steps) * 100
    steps_remaining = total_steps - current_step
    # A speed of 0 means "not measurable yet"; show a zero ETA instead of dividing by it.
    eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
    eta_minutes, eta_seconds_rem = divmod(int(eta_seconds), 60)
    return (
        f"Generating... {progress_percent:.0f}% ({current_step}/{total_steps}) | "
        f"{its_per_sec:.2f}it/s | "
        f"ETA: {eta_minutes:02d}:{eta_seconds_rem:02d}"
    )


def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
    """Generate an image or a video and stream status updates to the UI.

    This is a generator. Every yielded value is a dict keyed by the
    module-level UI components ``output_image``, ``output_video`` and
    ``status_textbox``, carrying the new value for each listed component.

    Args:
        model_key: Key of ``available_models``. Keys containing ``"Video"``
            select the text-to-video path, all others the image path.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt; only used for image generation.
        steps: Number of inference steps (converted with ``int``).
        cfg_scale: Guidance scale; only used for image generation.
        width: Image width in pixels; ignored for video (fixed at 576).
        height: Image height in pixels; ignored for video (fixed at 320).
        seed: Seed for the random generator (converted with ``int``).
        num_frames: Number of video frames; only used for video generation.

    Yields:
        dict: Component-to-value updates as described above.
    """
    device, torch_dtype = _resolve_device_and_dtype()
    # UI number fields may deliver floats, but manual_seed() requires an int.
    seed = int(seed)
    total_steps = int(steps)

    # --- Model loading (only when the selection changed) ---
    if model_state.get("loaded_model_name") != model_key:
        yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
        if model_state.get("current_pipe") is not None:
            # Release the previous pipeline before loading the next one.
            # Both entries are reset together so a failed load below cannot
            # leave a stale model name pointing at a missing pipeline.
            model_state["current_pipe"] = None
            model_state["loaded_model_name"] = None
            gc.collect()
            torch.cuda.empty_cache()
        model_state["loaded_model_name"] = None
        model_state["current_pipe"] = _load_pipeline(model_key, device, torch_dtype)
        model_state["loaded_model_name"] = model_key
        print(f"✅ Model loaded on {device.upper()}.")

    pipe = model_state["current_pipe"]
    generator = torch.Generator(device).manual_seed(seed)

    # --- Video generation: plain status updates, no per-step progress ---
    if "Video" in model_key:
        yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
        video_frames = pipe(
            prompt=prompt, num_inference_steps=total_steps, height=320, width=576,
            num_frames=int(num_frames), generator=generator,
        ).frames
        # squeeze() drops size-1 axes (presumably the batch axis - TODO confirm)
        # so that iterating the array yields one frame at a time.
        frames_array = np.squeeze(np.array(video_frames))
        frames_uint8 = (frames_array * 255).astype(np.uint8)
        video_path = f"video_{seed}.mp4"
        imageio.mimsave(video_path, list(frames_uint8), fps=12)
        yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
        return

    # --- Image generation with live progress ---
    # The pipeline runs in a worker thread and reports through this queue:
    # (step, speed) tuples while running, then either the final PIL image or
    # None to signal an error.
    progress_queue = Queue()

    def run_pipe():
        start_time = time.monotonic()

        def progress_callback(_pipeline, step, timestep, callback_kwargs):
            # Called by the pipeline at the end of every denoising step.
            elapsed_time = time.monotonic() - start_time
            its_per_sec = (step + 1) / elapsed_time if elapsed_time > 0 else 0.0
            progress_queue.put((step + 1, its_per_sec))
            return callback_kwargs

        try:
            final_image = pipe(
                prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=total_steps,
                guidance_scale=float(cfg_scale), width=int(width), height=int(height),
                generator=generator,
                callback_on_step_end=progress_callback,
            ).images[0]
            progress_queue.put(final_image)
        except Exception as e:
            print(f"An error occurred in the generation thread: {e}")
            progress_queue.put(None)

    thread = threading.Thread(target=run_pipe)
    thread.start()

    yield {status_textbox: "Generating..."}
    while True:
        try:
            update = progress_queue.get(timeout=1.0)
        except QueueEmpty:
            # Keep waiting while the worker is running. If it has exited but
            # queued something right after the timeout, loop again to pick
            # that up rather than misreporting a failure.
            if thread.is_alive() or not progress_queue.empty():
                continue
            print("⚠️ Generation thread finished unexpectedly.")
            yield {status_textbox: "Generation failed. Check console for details."}
            break

        if isinstance(update, Image.Image):
            yield {output_image: update, status_textbox: f"Generation complete! Seed: {seed}"}
            break
        if isinstance(update, tuple):
            current_step, its_per_sec = update
            yield {status_textbox: _format_progress_status(current_step, total_steps, its_per_sec)}
        elif update is None:
            yield {status_textbox: "Error during generation. Check console."}
            break
    thread.join()
# --- GRADIO UI ---
# --- GRADIO UI ---
# NOTE(review): the nesting of rows/columns/accordion below is reconstructed
# from the order of the components - confirm it matches the intended layout.
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("# The Generative Media Suite")
    gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
    # Holds the seed actually used for a run (the typed one or a random one).
    seed_state = gr.State(-1)

    # The initial widget state is derived from the default model with the
    # same rules update_ui_on_model_change applies, because that handler only
    # runs when the selection changes and never for the initial selection.
    model_choices = list(available_models.keys())
    default_model = model_choices[0]
    default_is_video = "Video" in default_model
    default_is_turbo = "Turbo" in default_model

    with gr.Row():
        with gr.Column(scale=2):
            model_selector = gr.Radio(label="Select Model", choices=model_choices, value=default_model)
            prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="An astronaut riding a horse on Mars, cinematic...")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, deformed, watermark, text")
            with gr.Accordion("Settings", open=True):
                steps_slider = gr.Slider(
                    1, 100, 1 if default_is_turbo else 30, step=1,
                    label="Inference Steps", interactive=not default_is_turbo,
                )
                cfg_slider = gr.Slider(
                    0.0, 15.0, 0.0 if default_is_turbo else 7.5, step=0.5,
                    label="Guidance Scale (CFG)", interactive=not default_is_turbo,
                )
                with gr.Row():
                    width_slider = gr.Slider(256, 1024, 768, step=64, label="Width", visible=not default_is_video)
                    height_slider = gr.Slider(256, 1024, 768, step=64, label="Height", visible=not default_is_video)
                num_frames_slider = gr.Slider(12, 48, 24, step=4, label="Video Frames", visible=default_is_video)
                seed_input = gr.Number(-1, label="Seed (-1 for random)")
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=3):
            output_image = gr.Image(label="Image Result", interactive=False, height="60vh", visible=not default_is_video)
            output_video = gr.Video(label="Video Result", interactive=False, height="60vh", visible=default_is_video)
            status_textbox = gr.Textbox(label="Status", interactive=False)

    def update_ui_on_model_change(model_key):
        """Return the component updates that adapt the controls to *model_key*.

        "Turbo" models lock the step and guidance sliders to 1 and 0.0;
        "Video" models swap the width/height sliders and the image output for
        the frame-count slider and the video output.
        """
        is_video = "Video" in model_key
        is_turbo = "Turbo" in model_key
        return {
            steps_slider: gr.update(interactive=not is_turbo, value=1 if is_turbo else 30),
            cfg_slider: gr.update(interactive=not is_turbo, value=0.0 if is_turbo else 7.5),
            width_slider: gr.update(visible=not is_video),
            height_slider: gr.update(visible=not is_video),
            num_frames_slider: gr.update(visible=is_video),
            output_image: gr.update(visible=not is_video),
            output_video: gr.update(visible=is_video),
        }

    def resolve_seed(requested_seed):
        """Return the seed to use: a random one for -1 or an empty field."""
        if requested_seed is None or requested_seed == -1:
            return random.randint(0, 2**32 - 1)
        # The number field may deliver a float; the generator needs an int.
        return int(requested_seed)

    model_selector.change(
        update_ui_on_model_change,
        model_selector,
        [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video],
    )

    # Two-stage click: first fix the seed (so the status line can report the
    # value that was really used), then run the streaming generation.
    click_event = generate_button.click(
        fn=resolve_seed,
        inputs=seed_input,
        outputs=seed_state,
        queue=False,
    ).then(
        fn=generate_media_live_progress,
        inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
        outputs=[output_image, output_video, status_textbox],
    )

demo.launch()