Spaces:
Sleeping
Sleeping
import gradio as gr | |
from diffusers import StableDiffusionPipeline | |
import torch | |
import re | |
def split_into_sentences(text):
    """
    Split the input text into individual sentences.

    Each sentence becomes one scene for image generation. A sentence boundary
    is any run of whitespace (spaces, tabs, newlines) that directly follows
    '.', '!' or '?'. Surrounding whitespace is ignored and empty pieces are
    discarded, so an empty or whitespace-only story yields an empty list.

    Parameters:
    - text (str | None): The story text.

    Returns:
    - list[str]: The non-empty sentences, in order, each keeping its
      terminating punctuation.
    """
    if not text:
        return []
    # \s+ (not just ' +') so sentences separated by line breaks in the
    # multi-line story box are also split apart.
    pieces = re.split(r'(?<=[.!?])\s+', text.strip())
    return [piece for piece in pieces if piece]
def generate_comic_strip(story):
    """
    Generates a comic strip from the input story.

    The story is split into sentences with ``split_into_sentences``. Each of
    the first three non-blank sentences is used as a prompt for the global
    Stable Diffusion pipeline ``pipe``, with 20 inference steps, 256x256
    output and guidance scale 7.5.

    Parameters:
    - story (str): The user's story prompt. ``None`` or blank text yields an
      empty strip.

    Returns:
    - comic_strip (list): One entry per rendered scene, in story order. An
      entry is the pipeline's first output image or, if generation for that
      scene raised, a placeholder image URL (str). If the model is not
      loaded, a single-element list holding a placeholder URL is returned.
    """
    if pipe is None:
        return ["https://via.placeholder.com/512x512.png?text=Model+Not+Loaded"]

    # Limit the number of scenes to prevent excessive image generation.
    max_scenes = 3
    # Drop blank fragments (empty story, stray whitespace) *before* applying
    # the limit, so an empty prompt is never sent to the model and never
    # uses up one of the scene slots.
    scenes = [s.strip() for s in split_into_sentences(story or "") if s.strip()]
    scenes = scenes[:max_scenes]

    comic_strip = []
    for number, scene in enumerate(scenes, start=1):
        try:
            image = pipe(
                scene,
                num_inference_steps=20,  # reduced steps for faster generation
                height=256,              # reduced resolution
                width=256,               # reduced resolution
                guidance_scale=7.5,
            ).images[0]
        except Exception as e:
            # Best effort: a failure on one scene must not abort the whole
            # strip, so substitute a placeholder and keep going.
            print(f"Error generating image for scene {number}: {e}")
            comic_strip.append("https://via.placeholder.com/512x512.png?text=Image+Unavailable")
        else:
            comic_strip.append(image)
    return comic_strip
def main():
    """
    Build and launch the Gradio web UI for the GenArt Narrative app.

    The UI has a single multi-line story textbox as input. The list returned
    by ``generate_comic_strip`` is shown in a three-column gallery. Two
    example stories are offered, and flagging is turned off.
    """
    story_box = gr.Textbox(
        label="Story Prompt",
        placeholder="Enter your short story here...",
        lines=5,
    )
    comic_gallery = gr.Gallery(
        label="Generated Comic Strip",
        columns=3,
        object_fit="contain",
        height="auto",
    )
    demo_examples = [
        ["A young wizard discovers a hidden magical forest and befriends a talking owl."],
        ["An astronaut lands on a distant planet and encounters alien life forms."],
    ]
    app = gr.Interface(
        fn=generate_comic_strip,
        inputs=story_box,
        outputs=comic_gallery,
        title="GenArt Narrative",
        description="Transform your short stories into engaging comic strips using AI-powered image generation.",
        examples=demo_examples,
        # Disable flagging of outputs.
        # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
        # confirm against the pinned Gradio version before upgrading.
        allow_flagging="never",
        theme="default",
    )
    app.launch()
# Load the model once, at import time, so every request reuses the same
# pipeline instead of reloading the weights. `pipe` stays None if loading
# fails; callers check for that.
pipe = None
try:
    print("Loading Stable Diffusion model...")
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base",  # model repository id
        torch_dtype=torch.float32,  # float32 because the pipeline runs on CPU below
        low_cpu_mem_usage=True,     # reduce peak RAM while loading weights
        safety_checker=None,        # do not load/run the safety checker
        # force_download is deliberately NOT set: forcing it re-downloads the
        # full model weights on every start instead of reusing the local
        # Hugging Face cache, which makes every startup slow and wasteful.
    )
    pipe = pipe.to("cpu")
    pipe.enable_attention_slicing()  # compute attention in slices to lower peak memory
    print("Model loaded successfully.")
except Exception as e:
    # Top-level boundary: report the failure and leave pipe as None so the
    # rest of the app can detect that the model is unavailable.
    print(f"Error loading model: {e}")
    pipe = None
if __name__ == "__main__":
    # Start the web UI only when the model actually loaded. Otherwise report
    # the failure (details were printed during loading) and finish without
    # launching.
    if pipe is None:
        print("Failed to load the model. Please check the error messages above.")
    else:
        main()