# --- START OF FILE media.py (FINAL WITH LIVE PROGRESS & FIXES) ---
"""Gradio front-end for text-to-image and text-to-video generation.

One diffusers pipeline is kept loaded at a time; switching models unloads
the previous pipeline first. Image generation runs in a worker thread so
per-step progress can be streamed to the status box.
"""

# --- LIBRARIES ---
import gc
import os
import random
import threading
import time
from queue import Queue, Empty as QueueEmpty

import gradio as gr
import imageio
import numpy as np
import torch
from diffusers import AutoPipelineForText2Image, TextToVideoSDPipeline, EulerAncestralDiscreteScheduler
from huggingface_hub import login
from PIL import Image

# --- DYNAMIC HARDWARE DETECTION ---
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
    print("✅ GPU detected. Using CUDA.")
else:
    device = "cpu"
    torch_dtype = torch.float32
    print("⚠️ No GPU detected. Using CPU.")

HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    print("✅ Found HF_TOKEN secret. Logging in...")
    try:
        login(token=HF_TOKEN)
        print("✅ Hugging Face Authentication successful.")
    except Exception as e:
        print(f"❌ Hugging Face login failed: {e}")
else:
    # This message will show when you run the app locally, which is fine.
    print("⚠️ No HF_TOKEN secret found. This is normal for local testing.")
    print(" The deployed app will use the secret you set on Hugging Face.")

# --- CONFIGURATION & STATE ---
# Maps the label shown in the UI to the Hugging Face model id. Labels that
# contain "Video" are routed to the text-to-video pipeline, labels that
# contain "Turbo" lock the steps/CFG sliders in the UI.
available_models = {
    "Fast Image (SDXL Turbo)": "stabilityai/sdxl-turbo",
    "Quality Image (SDXL)": "stabilityai/stable-diffusion-xl-base-1.0",
    "Photorealism (Juggernaut)": "RunDiffusion/Juggernaut-XL-v9",
    "Video (Damo-Vilab)": "damo-vilab/text-to-video-ms-1.7b",
}

# The single pipeline currently held in memory and the UI label it was
# loaded for. Both are None when nothing is loaded.
model_state = {
    "current_pipe": None,
    "loaded_model_name": None,
}


def resolve_seed(seed_value):
    """Return an integer seed, drawing a random one for ``-1`` or ``None``.

    The seed field may deliver a float (e.g. ``42.0``) or ``None`` when it is
    cleared; ``torch.Generator.manual_seed`` needs an ``int``.
    """
    if seed_value is None or seed_value == -1:
        return random.randint(0, 2**32 - 1)
    return int(seed_value)


def _unload_current_pipeline():
    """Drop the loaded pipeline (if any) and release its memory."""
    pipe_to_delete = model_state.get("current_pipe")
    # Reset BOTH entries before doing anything that can fail, so a failed
    # unload/load never leaves a model name pointing at a missing pipeline.
    model_state["current_pipe"] = None
    model_state["loaded_model_name"] = None

    if pipe_to_delete is not None:
        # Explicitly move the model to CPU before deleting to free VRAM.
        print("Offloading previous model to CPU...")
        pipe_to_delete.to("cpu")
        del pipe_to_delete
        print("Previous model deleted.")

    # Explicitly run garbage collection and empty CUDA cache.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _load_pipeline(model_key):
    """Build and place the diffusers pipeline for the UI label *model_key*."""
    model_id = available_models[model_key]
    is_video = "Video" in model_key

    if is_video:
        pipe = TextToVideoSDPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
    else:
        pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch_dtype, variant="fp16")
        # NOTE(review): the original file's indentation was lost; the scheduler
        # swap is assumed to apply to image pipelines only — confirm.
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    if device == "cuda" and not is_video:
        # CPU offload manages device placement itself; moving the pipeline to
        # the GPU first would only load everything into VRAM and move it back.
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)

    if device == "cuda":
        # NOTE(review): assumed to apply to every CUDA pipeline — confirm.
        pipe.enable_vae_slicing()

    return pipe


def _frames_to_uint8(frames):
    """Normalise pipeline video output to a ``uint8`` array of frames.

    Accepts either a batched array ``(batch, frames, H, W, C)`` or a plain
    sequence of frames, with float values in ``[0, 1]`` or ``uint8`` values.
    Only the first video of a batch is kept.
    """
    frame_array = np.asarray(frames)
    if frame_array.ndim == 5:
        frame_array = frame_array[0]
    if frame_array.dtype != np.uint8:
        # Clip first so out-of-range floats cannot wrap around after the cast.
        frame_array = (np.clip(frame_array, 0.0, 1.0) * 255).astype(np.uint8)
    return frame_array


def _format_progress(current_step, total_steps, its_per_sec):
    """Return the status line for an in-progress image generation."""
    progress_percent = (current_step / total_steps) * 100
    steps_remaining = total_steps - current_step
    eta_seconds = steps_remaining / its_per_sec if its_per_sec > 0 else 0
    eta_minutes, eta_seconds_rem = divmod(int(eta_seconds), 60)
    return (
        f"Generating... {progress_percent:.0f}% ({current_step}/{total_steps}) | "
        f"{its_per_sec:.2f}it/s | "
        f"ETA: {eta_minutes:02d}:{eta_seconds_rem:02d}"
    )


# --- THE FINAL GENERATION FUNCTION WITH LIVE PROGRESS & FIXES ---
def generate_media_live_progress(model_key, prompt, negative_prompt, steps, cfg_scale, width, height, seed, num_frames):
    """Generate an image or a video and stream status updates to the UI.

    This is a generator used as a Gradio event handler. Every yielded value
    is a dict keyed by the output components (``output_image``,
    ``output_video``, ``status_textbox``); components that are omitted keep
    their current value.

    Args:
        model_key: A key of ``available_models``.
        prompt: Text prompt.
        negative_prompt: Negative prompt (used for image models only).
        steps: Number of inference steps.
        cfg_scale: Guidance scale (used for image models only).
        width: Image width in pixels (image models only).
        height: Image height in pixels (image models only).
        seed: Seed for the random generator; ``-1``/``None`` picks a random one.
        num_frames: Number of frames (video model only).
    """
    seed = resolve_seed(seed)

    # --- Model Loading & Cleanup ---
    if model_state.get("loaded_model_name") != model_key:
        yield {output_image: None, output_video: None, status_textbox: f"Loading {model_key}..."}
        try:
            _unload_current_pipeline()
            loaded_pipe = _load_pipeline(model_key)
        except Exception as e:
            print(f"An error occurred while loading '{model_key}': {e}")
            yield {status_textbox: f"Error loading {model_key}: {e}"}
            return
        model_state["current_pipe"] = loaded_pipe
        model_state["loaded_model_name"] = model_key
        print(f"✅ Model '{model_key}' loaded on {device.upper()}.")

    pipe = model_state["current_pipe"]
    generator = torch.Generator(device).manual_seed(seed)

    # --- Generation Logic ---
    if "Video" in model_key:
        yield {output_image: None, output_video: None, status_textbox: "Generating video..."}
        try:
            video_frames = pipe(
                prompt=prompt,
                num_inference_steps=int(steps),
                height=320,
                width=576,
                num_frames=int(num_frames),
                generator=generator,
            ).frames

            # Stream frames to disk one at a time instead of building a clip in memory.
            video_path = f"video_{seed}.mp4"
            with imageio.get_writer(video_path, fps=12) as writer:
                for frame in _frames_to_uint8(video_frames):
                    writer.append_data(frame)

            yield {output_image: None, output_video: video_path, status_textbox: f"Video saved! Seed: {seed}"}
        except Exception as e:
            print(f"An error occurred during video generation: {e}")
            yield {status_textbox: f"Error during video generation: {e}"}
        return

    # --- Image Generation with Live Progress ---
    # The worker thread reports through this queue with ("progress", ...),
    # ("result", image) or ("error", message) tuples.
    progress_queue = Queue()

    def run_pipe():
        start_time = time.time()

        def progress_callback(pipe, step, timestep, callback_kwargs):
            elapsed_time = time.time() - start_time
            if elapsed_time > 0:
                its_per_sec = (step + 1) / elapsed_time
                progress_queue.put(("progress", (step + 1, its_per_sec)))
            return callback_kwargs

        try:
            final_image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(cfg_scale),
                width=int(width),
                height=int(height),
                generator=generator,
                callback_on_step_end=progress_callback,
            ).images[0]
            progress_queue.put(("result", final_image))
        except Exception as e:
            print(f"An error occurred in the generation thread: {e}")
            progress_queue.put(("error", str(e)))

    thread = threading.Thread(target=run_pipe)
    thread.start()

    total_steps = int(steps)
    yield {status_textbox: "Generating..."}

    while True:
        try:
            update_type, payload = progress_queue.get(timeout=1.0)
        except QueueEmpty:
            if thread.is_alive():
                continue
            # The worker may have queued its final message right after the
            # timeout expired; look once more before declaring failure.
            try:
                update_type, payload = progress_queue.get_nowait()
            except QueueEmpty:
                print("⚠️ Generation thread finished unexpectedly.")
                yield {status_textbox: "Generation failed. Check console for details."}
                break

        if update_type == "result":
            yield {output_image: payload, status_textbox: f"Generation complete! Seed: {seed}"}
            break
        elif update_type == "progress":
            current_step, its_per_sec = payload
            yield {status_textbox: _format_progress(current_step, total_steps, its_per_sec)}
        elif update_type == "error":
            yield {status_textbox: f"Error: {payload}. Check console."}
            break

    thread.join()
    print("Generation thread joined.")


# --- GRADIO UI ---
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("# The Generative Media Suite")
    gr.Markdown("Create fast images, high-quality images, or short videos. Created by cheeseman182. (note: the speed on the status bar is wrong)")
    # Holds the seed actually used for the run (the random draw when -1 is entered).
    seed_state = gr.State(-1)

    with gr.Row():
        with gr.Column(scale=2):
            model_selector = gr.Radio(label="Select Model", choices=list(available_models.keys()), value=list(available_models.keys())[0])
            prompt_input = gr.Textbox(label="Prompt", lines=4, placeholder="An astronaut riding a horse on Mars, cinematic...")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, deformed, watermark, text, overblown, high contrast, not photorealistic")
            with gr.Accordion("Settings", open=True):
                steps_slider = gr.Slider(1, 100, 30, step=1, label="Inference Steps")
                cfg_slider = gr.Slider(0.0, 15.0, 7.5, step=0.5, label="Guidance Scale (CFG)")
                with gr.Row():
                    width_slider = gr.Slider(256, 1024, 768, step=64, label="Width")
                    height_slider = gr.Slider(256, 1024, 768, step=64, label="Height")
                num_frames_slider = gr.Slider(12, 48, 24, step=4, label="Video Frames", visible=False)
                seed_input = gr.Number(-1, label="Seed (-1 for random)")
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column(scale=3):
            output_image = gr.Image(label="Image Result", interactive=False, height="60vh", visible=True)
            output_video = gr.Video(label="Video Result", interactive=False, height="60vh", visible=False)
            status_textbox = gr.Textbox(label="Status", interactive=False)

    def update_ui_on_model_change(model_key):
        """Adjust control visibility and defaults for the selected model.

        Turbo models lock steps to 1 and CFG to 0.0; video models hide the
        width/height sliders and the image output, and show the frame-count
        slider and the video output instead.
        """
        is_video = "Video" in model_key
        is_turbo = "Turbo" in model_key
        return {
            steps_slider: gr.update(interactive=not is_turbo, value=1 if is_turbo else 30),
            cfg_slider: gr.update(interactive=not is_turbo, value=0.0 if is_turbo else 7.5),
            width_slider: gr.update(visible=not is_video),
            height_slider: gr.update(visible=not is_video),
            num_frames_slider: gr.update(visible=is_video),
            output_image: gr.update(visible=not is_video),
            output_video: gr.update(visible=is_video),
        }

    model_selector.change(
        update_ui_on_model_change,
        model_selector,
        [steps_slider, cfg_slider, width_slider, height_slider, num_frames_slider, output_image, output_video],
    )

    # Resolve the seed first (so the value used is known), then generate.
    generate_button.click(
        fn=resolve_seed,
        inputs=seed_input,
        outputs=seed_state,
        queue=False,
    ).then(
        fn=generate_media_live_progress,
        inputs=[model_selector, prompt_input, negative_prompt_input, steps_slider, cfg_slider, width_slider, height_slider, seed_state, num_frames_slider],
        outputs=[output_image, output_video, status_textbox],
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)