# Image_Creation/app.py
import streamlit as st
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionUpscalePipeline
from PIL import Image
import io
st.set_page_config(page_title="AI Image Studio", layout="wide")
# ----------------- Model Loader -----------------
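# Note: @st.cache_resource caches the returned pipelines for the lifetime of the
# server process, so the (multi-GB) weights are downloaded and loaded only once
# rather than on every Streamlit rerun.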
@st.cache_resource
def load_models():
    models = {}
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Generation model
    try:
        models["generation"] = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        ).to(device)
    except Exception as e:
        st.error(f"Error loading generation model: {e}")
        models["generation"] = None

    # Upscale model
    try:
        models["upscale"] = StableDiffusionUpscalePipeline.from_pretrained(
            "stabilityai/stable-diffusion-x4-upscaler",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        ).to(device)
    except Exception as e:
        st.error(f"Error loading upscale model: {e}")
        models["upscale"] = None

    return models
models = load_models()
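# "models" holds the "generation" and "upscale" pipelines; either entry may be
# None if loading failed, which the UI branches below check before first use.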
# ----------------- UI -----------------
st.title("🎨 AI Image Studio")
option = st.sidebar.radio("Choose Action", ["Generate Image", "Upscale Image", "History Gallery"])
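# Streamlit re-executes this script top to bottom on every interaction; the radio
# selection simply decides which of the three branches below is rendered.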
# ----------------- Generate Image -----------------
if option == "Generate Image":
    st.subheader("Generate a New Image")
    prompt = st.text_area("Write your prompt:", placeholder="e.g., A futuristic city at sunset, cinematic lighting")

    if st.button("Generate"):
        if not prompt.strip():
            st.warning("Please write a prompt before generating.")
        elif models["generation"] is None:
            st.error("Image generation model not available.")
        else:
            try:
                # Truncate long prompts
                if len(prompt) > 500:
                    prompt = prompt[:500]
                    st.warning("Prompt too long. Truncated to 500 characters.")

                with st.spinner("Creating image..."):
                    image = models["generation"](prompt=prompt, num_inference_steps=20).images[0]

                # Show image
                st.image(image, caption="Generated Image", use_container_width=True)

                # Save in session
                st.session_state["last_image"] = image
                st.session_state["last_prompt"] = prompt
                if "gallery" not in st.session_state:
                    st.session_state["gallery"] = []
                st.session_state["gallery"].append(("Generated", image))

                # Download button
                buf = io.BytesIO()
                image.save(buf, format="PNG")
                st.download_button("Download Image", buf.getvalue(), "generated.png", "image/png")
            except Exception as e:
                st.error(f"Error generating image: {e}")

    # Extra actions: rendered whenever a previous result exists. They live outside
    # the "Generate" block because a button nested inside another button's if-block
    # disappears on the rerun its own click triggers, so its click would be lost.
    if st.session_state.get("last_image") is not None:
        col1, col2 = st.columns(2)
        with col1:
            if st.button("Regenerate"):
                st.session_state["regenerate"] = True
        with col2:
            if st.button("Modify Prompt"):
                st.session_state["modify"] = True
    # Regenerate
    if "regenerate" in st.session_state and st.session_state.get("last_prompt"):
        try:
            with st.spinner("Regenerating..."):
                image = models["generation"](prompt=st.session_state["last_prompt"], num_inference_steps=20).images[0]
            st.image(image, caption="Regenerated Image", use_container_width=True)
            st.session_state["last_image"] = image
            st.session_state["gallery"].append(("Regenerated", image))
            st.session_state.pop("regenerate")
        except Exception as e:
            st.error(f"Error regenerating: {e}")
    # Modify prompt
    if "modify" in st.session_state and st.session_state.get("last_prompt"):
        new_prompt = st.text_area("Modify your previous prompt:", st.session_state["last_prompt"])
        if st.button("Generate from Modified Prompt"):
            try:
                if len(new_prompt) > 500:
                    new_prompt = new_prompt[:500]
                    st.warning("Prompt too long. Truncated to 500 characters.")
                with st.spinner("Generating from modified prompt..."):
                    image = models["generation"](prompt=new_prompt, num_inference_steps=20).images[0]
                st.image(image, caption="Modified Image", use_container_width=True)
                st.session_state["last_image"] = image
                st.session_state["last_prompt"] = new_prompt
                st.session_state["gallery"].append(("Modified", image))
                st.session_state.pop("modify")
            except Exception as e:
                st.error(f"Error generating modified image: {e}")
# ----------------- Upscale Image -----------------
elif option == "Upscale Image":
    st.subheader("Upscale Your Image")
    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

    if st.button("Upscale"):
        if uploaded_file is None and "last_image" not in st.session_state:
            st.warning("Please upload an image or generate one first!")
        elif models["upscale"] is None:
            st.error("Upscaler model not loaded.")
        else:
            try:
                with st.spinner("Upscaling..."):
                    # Convert uploads to RGB so palette/RGBA images don't break the pipeline
                    img = Image.open(uploaded_file).convert("RGB") if uploaded_file else st.session_state["last_image"]
                    # The upscale pipeline expects a text prompt, so pass a generic one
                    upscaled_image = models["upscale"](prompt="upscale", image=img).images[0]
                st.image(upscaled_image, caption="Upscaled Image", use_container_width=True)
                buf = io.BytesIO()
                upscaled_image.save(buf, format="PNG")
                st.download_button("Download Upscaled Image", buf.getvalue(), "upscaled.png", "image/png")
                if "gallery" not in st.session_state:
                    st.session_state["gallery"] = []
                st.session_state["gallery"].append(("Upscaled", upscaled_image))
            except Exception as e:
                st.error(f"Error upscaling image: {e}")
# ----------------- History Gallery -----------------
elif option == "History Gallery":
    st.subheader("📸 Your Image Gallery")
    if "gallery" in st.session_state and st.session_state["gallery"]:
        for idx, (label, img) in enumerate(st.session_state["gallery"]):
            st.image(img, caption=f"{label} #{idx+1}", use_container_width=True)
    else:
        st.info("No images in gallery yet. Generate or upscale to see history.")