import os
from typing import Optional, Union

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image

# Hugging Face model id of the upscaler pipeline loaded by get_pipe().
MODEL_ID = "stabilityai/stable-diffusion-x4-upscaler"

# The pipeline always upscales by this factor; smaller requested scales are
# produced by shrinking the 4x result afterwards.
NATIVE_SCALE = 4.0

# Process-wide pipeline cache, populated on the first get_pipe() call.
PIPE: Optional[StableDiffusionUpscalePipeline] = None


def get_pipe() -> StableDiffusionUpscalePipeline:
    """Return the shared upscaler pipeline, loading it on first use.

    The pipeline is loaded in float16 and moved to CUDA when a GPU is
    available; otherwise it is loaded in float32 and kept on the CPU. The
    instance is cached in the module-level ``PIPE`` so later calls reuse it.
    """
    global PIPE
    if PIPE is not None:
        return PIPE

    # Query CUDA availability once so dtype and device cannot disagree.
    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32
    device = "cuda" if use_cuda else "cpu"

    PIPE = StableDiffusionUpscalePipeline.from_pretrained(
        MODEL_ID, torch_dtype=dtype
    )
    PIPE = PIPE.to(device)
    return PIPE


def _load_rgb(image: Union[str, np.ndarray]) -> Image.Image:
    """Turn a file path or a numpy array into an RGB PIL image."""
    if isinstance(image, str):
        if not os.path.exists(image):
            raise gr.Error(
                "Uploaded image not found. Please re-upload and try again."
            )
        # Image.open is lazy and keeps the file open; convert() yields an
        # independent in-memory copy, so the handle can be closed right away.
        with Image.open(image) as src:
            return src.convert("RGB")

    # Normalize array inputs (e.g. grayscale or RGBA) to RGB so both input
    # kinds reach the pipeline in the same mode.
    return Image.fromarray(image).convert("RGB")


def upscale(image: Union[str, np.ndarray], scale: float = 4.0) -> np.ndarray:
    """Upscale using Stable Diffusion x4 Upscaler and return an RGB numpy array.

    The pipeline inherently performs 4x upscaling; if a smaller scale is
    requested, the 4x result is downscaled to the requested size.

    Args:
        image: Path to an image file (preferred for robustness) or a numpy
            array accepted by ``PIL.Image.fromarray``.
        scale: Requested final scale factor. Falsy values fall back to 4.0 and
            the value is clamped to the range [1.0, 4.0].

    Returns:
        The upscaled image as a numpy array.

    Raises:
        gr.Error: If ``image`` is ``None`` or the given file path does not
            exist.
    """
    if image is None:
        raise gr.Error("No image provided")

    scale = max(1.0, min(float(scale or NATIVE_SCALE), NATIVE_SCALE))
    pipe = get_pipe()
    pil = _load_rgb(image)

    # Run the upscaler (4x).
    # Use neutral prompt and zero guidance for faithful upscaling.
    result = pipe(
        prompt="", image=pil, num_inference_steps=20, guidance_scale=0.0
    )
    out: Image.Image = result.images[0]

    if scale < NATIVE_SCALE:
        # Target size is computed from the *input* dimensions, not the 4x output.
        w, h = pil.size
        target = (int(round(w * scale)), int(round(h * scale)))
        out = out.resize(target, Image.LANCZOS)

    return np.array(out)


demo = gr.Interface(
    fn=upscale,
    inputs=[
        gr.Image(type="filepath", label="image"),
        gr.Slider(1.0, 4.0, step=0.5, value=4.0, label="scale"),
    ],
    outputs=gr.Image(type="numpy", label="output"),
    title="SD x4 Upscaler API",
    description=(
        "Upscale images using Stability AI's Stable Diffusion x4 upscaler. \n"
        "The input 'scale' controls the final size (<4x downscales the 4x result). "
        "Use the 'View API' button on this Space to see the exact /predict schema."
    ),
    allow_flagging="never",
)


if __name__ == "__main__":
    demo.queue().launch()