"""GenArt Narrative: turn a short story into a comic strip with Stable Diffusion.

The story is split into sentences, the first few sentences are treated as
scenes, and one image is generated per scene. The result is shown in a Gradio
gallery.
"""

import re

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline

# Model and generation settings, gathered here so they can be tuned in one place.
MODEL_ID = "stabilityai/stable-diffusion-2-1-base"
MAX_SCENES = 3              # Cap on images per request to bound generation time
NUM_INFERENCE_STEPS = 20    # Reduced steps for faster generation
IMAGE_SIZE = 256            # Reduced resolution (used for both height and width)
GUIDANCE_SCALE = 7.5        # Default guidance scale

# Placeholder images returned in place of a generated picture.
MODEL_NOT_LOADED_URL = "https://via.placeholder.com/512x512.png?text=Model+Not+Loaded"
IMAGE_UNAVAILABLE_URL = "https://via.placeholder.com/512x512.png?text=Image+Unavailable"

# A sentence boundary is any run of whitespace (spaces, tabs, newlines) that
# directly follows '.', '!' or '?'. The look-behind keeps the punctuation
# attached to the sentence it terminates.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")


def split_into_sentences(text: str) -> list:
    """Split *text* into individual sentences.

    Each sentence is later used as the prompt for one comic panel.

    Parameters:
    - text (str): The text to split. ``None`` or an empty/whitespace-only
      string is accepted.

    Returns:
    - sentences (list): The non-empty sentences in their original order, each
      keeping its terminal punctuation. Empty when *text* has no content.
    """
    if not text:
        return []
    # Strip first so leading/trailing whitespace cannot yield empty pieces;
    # the filter covers the remaining case of whitespace-only input.
    pieces = _SENTENCE_BOUNDARY.split(text.strip())
    return [piece for piece in pieces if piece]


def generate_comic_strip(story):
    """Generate a comic strip from the input story.

    Parameters:
    - story (str): The user's story prompt.

    Returns:
    - comic_strip (list): One entry per scene, in story order. Each entry is
      the generated image, or a placeholder image URL when generation of that
      scene failed. If the model is not loaded, a single "model not loaded"
      placeholder URL is returned. An empty story yields an empty list.
    """
    if pipe is None:
        return [MODEL_NOT_LOADED_URL]

    # Split the story into sentences to identify key scenes, and limit their
    # number to prevent excessive image generation.
    scenes = split_into_sentences(story)[:MAX_SCENES]

    comic_strip = []
    for scene_number, scene in enumerate(scenes, start=1):
        try:
            image = pipe(
                scene,
                num_inference_steps=NUM_INFERENCE_STEPS,
                height=IMAGE_SIZE,
                width=IMAGE_SIZE,
                guidance_scale=GUIDANCE_SCALE,
            ).images[0]
        except Exception as e:
            # Best effort: one failed scene must not discard the panels that
            # did render, so log it and show a placeholder in its slot.
            print(f"Error generating image for scene {scene_number}: {e}")
            comic_strip.append(IMAGE_UNAVAILABLE_URL)
        else:
            comic_strip.append(image)

    return comic_strip


def main():
    """Set up and launch the Gradio interface for the GenArt Narrative application."""
    # Input component: a textbox for the user to enter their story
    input_text = gr.Textbox(
        lines=5,
        placeholder="Enter your short story here...",
        label="Story Prompt",
    )

    # Output component: a gallery to display the generated comic strip
    output_gallery = gr.Gallery(
        label="Generated Comic Strip",
        columns=3,
        object_fit="contain",
        height="auto",
    )

    iface = gr.Interface(
        fn=generate_comic_strip,
        inputs=input_text,
        outputs=output_gallery,
        title="GenArt Narrative",
        description="Transform your short stories into engaging comic strips using AI-powered image generation.",
        examples=[  # Example inputs for demonstration
            ["A young wizard discovers a hidden magical forest and befriends a talking owl."],
            ["An astronaut lands on a distant planet and encounters alien life forms."],
        ],
        allow_flagging="never",  # Disable flagging of outputs
        theme="default",
    )

    iface.launch()


def _load_pipeline():
    """Load the Stable Diffusion pipeline on CPU; return None if loading fails."""
    try:
        print("Loading Stable Diffusion model...")
        # force_download is deliberately not passed: forcing it re-downloads
        # the full checkpoint on every start instead of reusing the local cache.
        loaded = StableDiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,  # Use float32 for CPU
            low_cpu_mem_usage=True,
            safety_checker=None,        # Disable safety checker to speed up loading
        )
        loaded = loaded.to("cpu")
        loaded.enable_attention_slicing()  # Reduce memory usage
        print("Model loaded successfully.")
        return loaded
    except Exception as e:
        # Start-up boundary: report the failure and let the caller decide what
        # to do with a missing model rather than crashing at import time.
        print(f"Error loading model: {e}")
        return None


# Load the model globally to avoid reloading it for each request
pipe = _load_pipeline()


if __name__ == "__main__":
    if pipe is not None:
        main()
    else:
        print("Failed to load the model. Please check the error messages above.")