# story / app.py
# Last updated by geethareddy in commit 52844a5 ("Update app.py", verified).
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
import re
def split_into_sentences(text):
    """
    Split the input text into individual sentences (one per comic scene).

    A sentence boundary is a run of whitespace (spaces, tabs or newlines)
    that follows ``.``, ``!`` or ``?``. Each sentence is stripped and empty
    fragments are dropped. Empty fragments come from an empty story or from
    leading/trailing whitespace. Every returned item is therefore a
    non-empty prompt.

    Parameters:
    - text (str): The story text to split. ``None`` or ``""`` is treated as
      an empty story.

    Returns:
    - sentences (list[str]): The non-empty sentences, in original order.
    """
    if not text:
        return []
    # Split on any whitespace after terminal punctuation. The earlier
    # space-only pattern left "One.\nTwo." as a single scene.
    fragments = re.split(r"(?<=[.!?])\s+", text)
    return [fragment.strip() for fragment in fragments if fragment.strip()]
def generate_comic_strip(story):
    """
    Generate a comic strip, one image per scene, from a short story.

    The story is split into sentences. The first few non-blank ones are
    each rendered with the module-level Stable Diffusion pipeline ``pipe``.

    Parameters:
    - story (str): The user's story prompt. ``None`` is treated as empty.

    Returns:
    - comic_strip (list): One entry per scene, in story order.
      - Each entry is normally a generated image (``pipe(...).images[0]``).
      - If generation for a scene raised, that entry is a placeholder image
        URL string instead.
      - If the model is not loaded, a single placeholder URL is returned.
      - A blank story yields an empty list, and no prompt is sent to the model.
    """
    if pipe is None:
        return ["https://via.placeholder.com/512x512.png?text=Model+Not+Loaded"]

    # Cap the number of scenes to bound generation time per request.
    max_scenes = 3
    # Drop blank fragments *before* capping. This way the pipeline never
    # renders an empty prompt, and blanks do not use up scene slots.
    scenes = [
        scene.strip()
        for scene in split_into_sentences(story or "")
        if scene.strip()
    ][:max_scenes]

    comic_strip = []
    for scene_number, scene in enumerate(scenes, start=1):
        try:
            image = pipe(
                scene,
                num_inference_steps=20,  # fewer steps for faster generation
                # 256x256 is presumably below the model's native resolution:
                # it trades image quality for speed on CPU.
                height=256,
                width=256,
                guidance_scale=7.5,
            ).images[0]
        except Exception as e:
            # Best effort: one failed scene must not abort the whole strip,
            # so record the error and substitute a placeholder.
            print(f"Error generating image for scene {scene_number}: {e}")
            comic_strip.append("https://via.placeholder.com/512x512.png?text=Image+Unavailable")
        else:
            comic_strip.append(image)
    return comic_strip
def main():
    """
    Build the GenArt Narrative Gradio UI and launch it.

    A multi-line story textbox feeds ``generate_comic_strip``. The returned
    images are shown in a 3-column gallery. Two example stories are offered,
    and output flagging is disabled.
    """
    story_box = gr.Textbox(
        label="Story Prompt",
        placeholder="Enter your short story here...",
        lines=5,
    )
    comic_gallery = gr.Gallery(
        label="Generated Comic Strip",
        columns=3,
        height="auto",
        object_fit="contain",
    )

    sample_stories = [
        ["A young wizard discovers a hidden magical forest and befriends a talking owl."],
        ["An astronaut lands on a distant planet and encounters alien life forms."],
    ]
    interface_options = dict(
        title="GenArt Narrative",
        description="Transform your short stories into engaging comic strips using AI-powered image generation.",
        examples=sample_stories,
        # NOTE(review): newer Gradio releases replace `allow_flagging` with
        # `flagging_mode`; confirm against the pinned Gradio version.
        allow_flagging="never",
        theme="default",  # other built-in themes can be substituted here
    )

    demo = gr.Interface(
        fn=generate_comic_strip,
        inputs=story_box,
        outputs=comic_gallery,
        **interface_options,
    )
    demo.launch()
def _load_pipeline():
    """
    Load the Stable Diffusion 2.1 (base) pipeline configured for CPU inference.

    Returns:
    - pipeline (StableDiffusionPipeline | None): The ready-to-use pipeline,
      or ``None`` if loading failed for any reason. In that case the error
      is printed.
    """
    print("Loading Stable Diffusion model...")
    try:
        pipeline = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base",
            torch_dtype=torch.float32,  # full precision for CPU inference
            low_cpu_mem_usage=True,     # reduce peak RAM while loading weights
            safety_checker=None,        # disable the safety checker (also speeds up loading)
            # force_download is deliberately not passed. It would bypass the
            # local Hugging Face cache and re-download the weights on every start.
        )
        pipeline = pipeline.to("cpu")
        pipeline.enable_attention_slicing()  # compute attention in slices to lower memory use
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    print("Model loaded successfully.")
    return pipeline


# Load the model once at import time so every request reuses the same pipeline.
# `pipe` is None when loading failed.
pipe = _load_pipeline()
if __name__ == "__main__":
    # Only start the UI when the model is available. Otherwise, point the
    # user at the load error printed above.
    if pipe is None:
        print("Failed to load the model. Please check the error messages above.")
    else:
        main()