import gradio as gr import torch from transformers import pipeline import os # --- App Configuration --- TITLE = "✍️ AI Story Outliner" DESCRIPTION = """ Enter a prompt and get 10 unique story outlines from a CPU-friendly AI model. The app uses **DistilGPT-2**, a reliable and lightweight model, to generate creative outlines. **How it works:** 1. Enter your story idea. 2. The AI will generate 10 different story outlines. 3. Each outline has a dramatic beginning and is concise, like a song. """ # --- Example Prompts for Storytelling --- examples = [ ["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."], ["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."], ["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but for what the dragon said when it finally spoke."], ["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"], ["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."] ] # --- Model Initialization --- # This section loads a smaller, stable, and CPU-friendly model that requires no authentication. generator = None model_error = None try: print("Initializing model... This may take a moment.") # Using 'distilgpt2', a stable and widely supported model that does not require a token. # This is much more suitable for a standard CPU environment. generator = pipeline( "text-generation", model="distilgpt2", torch_dtype=torch.float32, # Use float32 for wider CPU compatibility device_map="auto" # Will use GPU if available, otherwise CPU ) print("✅ distilgpt2 model loaded successfully!") except Exception as e: model_error = e print(f"--- 🚨 Error loading model ---") print(f"Error: {model_error}") # --- App Logic --- def generate_stories(prompt: str) -> list[str]: """ Generates 10 story outlines from the loaded model based on the user's prompt. """ print("--- Button clicked. Attempting to generate stories... ---") # If the model failed to load during startup, display that error. if model_error: error_message = f"**Model failed to load during startup.**\n\nPlease check the console logs for details.\n\n**Error:**\n`{str(model_error)}`" print(f"Returning startup error: {error_message}") return [error_message] * 10 if not prompt: # Return a list of 10 empty strings to clear the outputs return [""] * 10 try: # A generic story prompt that works well with models like GPT-2. story_prompt = f""" Story Idea: "{prompt}" Create a short story outline based on this idea. ### 🎬 The Hook A dramatic opening. ### 🎼 The Ballad The main story, told concisely. ### 🔚 The Finale A clear and satisfying ending. --- """ # Parameters for the pipeline to generate 10 diverse results. params = { "max_new_tokens": 200, "num_return_sequences": 10, "do_sample": True, "temperature": 0.9, "top_k": 50, "pad_token_id": generator.tokenizer.eos_token_id } print("Generating text with the model...") # Generate 10 different story variations outputs = generator(story_prompt, **params) print("✅ Text generation complete.") # Extract the generated text. stories = [] for out in outputs: full_text = out['generated_text'] stories.append(full_text) # Ensure we return exactly 10 stories, padding if necessary. while len(stories) < 10: stories.append("Failed to generate a story for this slot.") return stories except Exception as e: # Catch any errors that happen DURING generation and display them in the UI. print(f"--- 🚨 Error during story generation ---") print(f"Error: {e}") runtime_error_message = f"**An error occurred during story generation.**\n\nPlease check the console logs for details.\n\n**Error:**\n`{str(e)}`" return [runtime_error_message] * 10 # --- Gradio Interface --- with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo: gr.Markdown(f"