Spaces:
Running
Running
import streamlit as st | |
import torch | |
from diffusers import StableDiffusionPipeline, StableDiffusionUpscalePipeline | |
from PIL import Image | |
import io | |
# Browser-tab title and full-width layout for the whole app.
st.set_page_config(
    layout="wide",
    page_title="AI Image Studio",
)
# ----------------- Model Loader -----------------
@st.cache_resource
def load_models():
    """Load the text-to-image and x4-upscaler diffusion pipelines.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching both pipelines would be re-read from disk on every click.
    ``st.cache_resource`` keeps the loaded objects for the lifetime of the
    server process. Note that a pipeline which failed to load is cached as
    ``None`` too, so a restart is needed to retry.

    Returns:
        dict with keys ``"generation"`` and ``"upscale"``. A value is ``None``
        when that pipeline could not be loaded; the exception is reported to
        the user with ``st.error``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is used only on CUDA; on CPU the pipelines stay in float32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    # key -> (pipeline class, Hugging Face model id)
    pipeline_specs = {
        "generation": (StableDiffusionPipeline, "runwayml/stable-diffusion-v1-5"),
        "upscale": (StableDiffusionUpscalePipeline, "stabilityai/stable-diffusion-x4-upscaler"),
    }

    loaded = {}
    for key, (pipeline_cls, model_id) in pipeline_specs.items():
        try:
            loaded[key] = pipeline_cls.from_pretrained(model_id, torch_dtype=dtype).to(device)
        except Exception as e:
            # Best-effort: keep the app usable with whichever model did load.
            st.error(f"Error loading {key} model: {e}")
            loaded[key] = None
    return loaded
# Pipelines used by every action below (entries may be None on load failure).
models = load_models()

# ----------------- UI -----------------
st.title("🎨 AI Image Studio")
option = st.sidebar.radio(
    "Choose Action",
    [
        "Generate Image",
        "Upscale Image",
        "History Gallery",
    ],
)
# ----------------- Generate Image -----------------
# Prompts longer than this many characters are cut before generation.
MAX_PROMPT_CHARS = 500
# Denoising steps used for every generation request.
GENERATION_STEPS = 20


def _generate_into_session(prompt, label, spinner_text):
    """Generate an image from *prompt* and record it in session state.

    Over-long prompts are truncated with a visible warning. On success the
    image, the (possibly truncated) prompt and a caption derived from *label*
    are stored in ``st.session_state``, and ``(label, image)`` is appended to
    the gallery. Pipeline exceptions propagate to the caller, which reports
    them. Returns the generated image.
    """
    if len(prompt) > MAX_PROMPT_CHARS:
        prompt = prompt[:MAX_PROMPT_CHARS]
        st.warning(f"Prompt too long. Truncated to {MAX_PROMPT_CHARS} characters.")
    with st.spinner(spinner_text):
        image = models["generation"](
            prompt=prompt, num_inference_steps=GENERATION_STEPS
        ).images[0]
    st.session_state["last_image"] = image
    st.session_state["last_prompt"] = prompt
    st.session_state["last_caption"] = f"{label} Image"
    if "gallery" not in st.session_state:
        st.session_state["gallery"] = []
    st.session_state["gallery"].append((label, image))
    return image


if option == "Generate Image":
    st.subheader("Generate a New Image")
    prompt = st.text_area("Write your prompt:", placeholder="e.g., A futuristic city at sunset, cinematic lighting")
    if st.button("Generate"):
        if not prompt.strip():
            st.warning("Please write a prompt before generating.")
        elif models["generation"] is None:
            st.error("Image generation model not available.")
        else:
            try:
                _generate_into_session(prompt, "Generated", "Creating image...")
            except Exception as e:
                st.error(f"Error generating image: {e}")

    # Reserve a slot so the latest result renders above the follow-up actions,
    # even though it is filled only after those actions have been handled.
    result_area = st.container()

    # Follow-up actions must live OUTSIDE the "Generate" branch: st.button is
    # True only on the rerun caused by its own click, so a button nested under
    # another button's branch is never rendered on that rerun and never fires.
    if st.session_state.get("last_prompt"):
        col1, col2 = st.columns(2)
        with col1:
            regenerate_clicked = st.button("Regenerate")
        with col2:
            if st.button("Modify Prompt"):
                st.session_state["modify"] = True

        if regenerate_clicked:
            try:
                _generate_into_session(st.session_state["last_prompt"], "Regenerated", "Regenerating...")
            except Exception as e:
                st.error(f"Error regenerating: {e}")

        # "modify" persists across reruns until a modified image is produced,
        # so the editor stays visible while the user edits the prompt.
        if st.session_state.get("modify"):
            new_prompt = st.text_area("Modify your previous prompt:", st.session_state["last_prompt"])
            if st.button("Generate from Modified Prompt"):
                if not new_prompt.strip():
                    st.warning("Please write a prompt before generating.")
                else:
                    try:
                        _generate_into_session(new_prompt, "Modified", "Generating from modified prompt...")
                        st.session_state.pop("modify", None)
                    except Exception as e:
                        st.error(f"Error generating modified image: {e}")

    # Rendered from session state so the image and its download button survive
    # reruns (including the rerun triggered by clicking the download button).
    if "last_image" in st.session_state:
        with result_area:
            image = st.session_state["last_image"]
            st.image(
                image,
                caption=st.session_state.get("last_caption", "Generated Image"),
                use_container_width=True,
            )
            buf = io.BytesIO()
            image.save(buf, format="PNG")
            st.download_button("Download Image", buf.getvalue(), "generated.png", "image/png")
# ----------------- Upscale Image ----------------- | |
elif option == "Upscale Image": | |
st.subheader("Upscale Your Image") | |
uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) | |
if st.button("Upscale"): | |
if uploaded_file is None and "last_image" not in st.session_state: | |
st.warning("Please upload an image or generate one first!") | |
elif models["upscale"] is None: | |
st.error("Upscaler model not loaded.") | |
else: | |
try: | |
with st.spinner("Upscaling..."): | |
img = Image.open(uploaded_file) if uploaded_file else st.session_state["last_image"] | |
# Always provide dummy prompt | |
upscaled_image = models["upscale"](prompt="upscale", image=img).images[0] | |
st.image(upscaled_image, caption="Upscaled Image", use_container_width=True) | |
buf = io.BytesIO() | |
upscaled_image.save(buf, format="PNG") | |
st.download_button("Download Upscaled Image", buf.getvalue(), "upscaled.png", "image/png") | |
if "gallery" not in st.session_state: | |
st.session_state["gallery"] = [] | |
st.session_state["gallery"].append(("Upscaled", upscaled_image)) | |
except Exception as e: | |
st.error(f"Error upscaling image: {e}") | |
# ----------------- History Gallery ----------------- | |
elif option == "History Gallery": | |
st.subheader("📸 Your Image Gallery") | |
if "gallery" in st.session_state and st.session_state["gallery"]: | |
for idx, (label, img) in enumerate(st.session_state["gallery"]): | |
st.image(img, caption=f"{label} #{idx+1}", use_container_width=True) | |
else: | |
st.info("No images in gallery yet. Generate or upscale to see history.") | |