import os  # For file operations

import gradio as gr
import torch  # For device management
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# --- Configuration and Model Loading ---
# You can choose a different model here if you have access to more powerful ones.
# Larger models need sufficient VRAM (GPU memory); on CPU, smaller models or
# quantization may be necessary.
MODEL_NAME = "google/flan-t5-large"  # 'large' performs slightly better than 'base' and is still manageable.
# If you have a powerful GPU, consider "google/flan-t5-xl" or even "google/flan-t5-xxl".
# For even larger models, consider model.to(torch.bfloat16) or bitsandbytes 4-bit loading if available.

try:
    # Use the GPU if available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Load in half precision (float16) on GPU to save VRAM,
    # or in 8-bit/4-bit via libraries like bitsandbytes (requires installation).
    if device == "cuda":
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)

    model.eval()  # Set model to evaluation mode
    print(f"Model '{MODEL_NAME}' loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please check your internet connection, the model name, and available resources (RAM/VRAM).")
    # Degrade gracefully if loading fails; the UI reports the error on use.
    tokenizer, model = None, None
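
# The 8-bit loading path mentioned above, as a minimal sketch: an alternative
# loader that is defined here but never called by the app. It assumes the
# optional `bitsandbytes` and `accelerate` packages are installed;
# `load_model_8bit` and `quant_config` are illustrative names, not part of
# the app's flow.
def load_model_8bit():
    from transformers import BitsAndBytesConfig
    quant_config = BitsAndBytesConfig(load_in_8bit=True)  # Quantize weights to 8-bit at load time
    return AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto",  # Let accelerate place layers on available devices
    )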

# --- Prompt Engineering Functions (more structured) ---
def create_arabic_prompt(topic, style):
    # Arabic counterparts of the English prompts below: a descriptive article,
    # a short social post with emojis/hashtags, or a scene-by-scene video script.
    if style == "Blog Post (Descriptive)":
        return f"اكتب مقالاً احترافياً بأسلوب شخصي عن: {topic}. ركز على التفاصيل، الوصف الجذاب، قدم نصائح عملية. اجعل النص منسقاً بفقرات وعناوين فرعية."
    elif style == "Social Media Post (Short & Catchy)":
        return f"اكتب منشوراً قصيراً وجذاباً ومثيراً للتفاعل عن: {topic}. أضف 2-3 إيموجي مناسبة واقترح 4 هاشتاغات شائعة. ابدأ بسؤال أو جملة جذابة."
    else:  # Video Script (Storytelling)
        return f"اكتب سيناريو فيديو احترافي ومقنع عن: {topic}. اجعل الأسلوب قصصي وسردي، مقسماً إلى مشاهد رئيسية، مع اقتراح لقطات بصرية (B-roll) وأصوات (SFX) لكل مشهد. ركز على إثارة المشاعر."


def create_english_prompt(topic, style):
    if style == "Blog Post (Descriptive)":
        return f"Write a detailed and professional blog post about: {topic}. Focus on personal insights, vivid descriptions, and practical advice. Structure it with clear paragraphs and subheadings."
    elif style == "Social Media Post (Short & Catchy)":
        return f"Write a short, catchy, and engaging social media post about: {topic}. Include 2-3 relevant emojis and suggest 4 trending hashtags. Start with a hook question or statement."
    else:  # Video Script (Storytelling)
        return f"Write a professional, compelling video script about: {topic}. Make it emotionally engaging and story-driven, divided into key scenes, with suggested visual shots (B-roll) and sound effects (SFX) for each scene."


# --- Content Generation Function ---
@torch.no_grad()  # Disable gradient tracking during inference to save memory
def generate_content(topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty):
    if tokenizer is None or model is None:
        return "⚠️ Error: Model not loaded. Please check the console for details."
    if not topic:
        return "⚠️ Please enter a topic to generate content."

    # Output-token budgets per requested length. (Flan-T5 was trained with a
    # 512-token context, so the input prompt is also truncated to 512 below.)
    if length_choice == "Short":
        max_new_tokens = 150
        min_new_tokens = 50
    elif length_choice == "Medium":
        max_new_tokens = 300
        min_new_tokens = 100
    else:  # Long
        max_new_tokens = 450  # Effectively the maximum for Flan-T5
        min_new_tokens = 150

    # Map the UI sliders directly onto generation parameters.
    temperature = creativity   # Higher -> more creative, less predictable
    top_p = detail_level       # Higher -> more diverse/detailed vocabulary
    no_repeat_ngram_size = int(diversity_penalty)  # Higher -> less repetition

    # Build the prompt.
    if lang_choice == "Arabic":
        prompt = create_arabic_prompt(topic, style_choice)
    else:  # English
        prompt = create_english_prompt(topic, style_choice)

    # Reinforce the instructions when the user explicitly asks for high detail or creativity.
    if detail_level > 0.7:
        prompt += " Ensure comprehensive coverage and rich descriptions."
    if creativity > 0.8:
        prompt += " Be highly creative and imaginative in your writing."

    try:
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            num_beams=5,        # Beam search for better quality
            do_sample=True,     # Enable sampling for creativity
            temperature=temperature,
            top_p=top_p,
            top_k=50,           # Sample only from the 50 most likely tokens
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=1.0, # Adjust to favor shorter or longer outputs
            early_stopping=True,
        )
        content = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return content
    except RuntimeError as e:
        if "out of memory" in str(e):
            return "⚠️ Generation failed: out of memory. Try a shorter length, a smaller model, or restart the application if on GPU."
        return f"⚠️ Generation failed due to a runtime error: {str(e)}"
    except Exception as e:
        return f"⚠️ An unexpected error occurred during generation: {str(e)}"
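
# A quick smoke test of generate_content without the UI; a sketch only, using
# the same default values the sliders and radios start with below. Uncomment
# to run once the model has loaded:
#
#     print(generate_content(
#         "The Future of AI in Healthcare",  # topic
#         "Blog Post (Descriptive)",         # style_choice
#         "English",                         # lang_choice
#         "Medium",                          # length_choice
#         0.7,                               # creativity (temperature)
#         0.9,                               # detail_level (top-p)
#         2,                                 # diversity_penalty (no-repeat n-gram size)
#     ))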

# --- Gradio Interface ---
# Custom CSS for a more polished look
custom_css = """
h1, h2, h3 { color: #4B0082; } /* Dark Purple */
.gradio-container {
    background-color: #F8F0FF; /* Light Lavender */
    font-family: 'Segoe UI', sans-serif;
}
.gr-button {
    background-color: #8A2BE2; /* Blue Violet */
    color: white;
    border-radius: 10px;
    padding: 10px 20px;
    font-size: 1.1em;
}
.gr-button:hover { background-color: #9370DB; /* Medium Purple */ }
.gr-text-input, .gr-textarea {
    border: 1px solid #DDA0DD; /* Plum */
    border-radius: 8px;
    padding: 10px;
}
.gradio-radio input:checked + label {
    background-color: #DA70D6 !important; /* Orchid */
    color: white !important;
}
.gradio-radio label {
    border: 1px solid #DDA0DD;
    border-radius: 8px;
    padding: 8px 15px;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as iface:
    gr.Markdown("# ✨ AI Content Creation Studio")
    gr.Markdown("## Generate professional blogs, social media posts, or video scripts in seconds!")

    with gr.Row():
        with gr.Column(scale=2):
            topic = gr.Textbox(
                label="Topic / الموضوع",
                placeholder="e.g., The Future of AI in Healthcare / مثال: مستقبل الذكاء الاصطناعي في الرعاية الصحية",
                lines=2,
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    creativity = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.7, step=0.1,
                        label="Creativity (Temperature)",
                        info="Higher values lead to more creative, less predictable text. Lower values are more focused.",
                    )
                    detail_level = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                        label="Detail Level (Top-p Sampling)",
                        info="Higher values allow a more diverse and detailed vocabulary. Lower values prune less likely words more aggressively.",
                    )
                with gr.Row():
                    diversity_penalty = gr.Slider(
                        minimum=1, maximum=5, value=2, step=1,
                        label="Repetition Penalty (N-gram)",
                        info="Higher values reduce the chance of repeating the same phrases or words. Set to 1 for no penalty.",
                    )
        with gr.Column(scale=1):
            with gr.Group():
                style_choice = gr.Radio(
                    ["Blog Post (Descriptive)", "Social Media Post (Short & Catchy)", "Video Script (Storytelling)"],
                    label="Content Style / نوع المحتوى",
                    value="Blog Post (Descriptive)",
                    interactive=True,
                )
            with gr.Group():
                lang_choice = gr.Radio(
                    ["English", "Arabic"],
                    label="Language / اللغة",
                    value="English",
                    interactive=True,
                )
            with gr.Group():
                length_choice = gr.Radio(
                    ["Short", "Medium", "Long"],
                    label="Content Length / طول النص",
                    value="Medium",
                    interactive=True,
                )
                gr.Markdown("*(Note: 'Long' is relative to model capabilities, max ~450 tokens)*")

    btn = gr.Button("🚀 Generate Content", variant="primary")
    output = gr.Textbox(label="Generated Content", lines=20, interactive=True)

    # Write valid generated content to a text file and enable the download button;
    # warning messages (prefixed with ⚠️) leave the button disabled.
    def download_file(content):
        if content and not content.startswith("⚠️"):  # Only offer a file for valid content
            file_path = "generated_content.txt"
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            return gr.DownloadButton(value=file_path, interactive=True)
        return gr.DownloadButton(interactive=False)  # Nothing valid to download

    download_button = gr.DownloadButton("⬇️ Download Content", interactive=False)

    # Event handlers
    btn.click(
        fn=generate_content,
        inputs=[topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty],
        outputs=output,
    )
    # Refresh the download button whenever the output changes.
    output.change(fn=download_file, inputs=[output], outputs=[download_button])

if __name__ == "__main__":
    iface.launch()
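
# Standard launch options worth knowing: iface.launch(share=True) additionally
# creates a temporary public Gradio link, and iface.launch(server_name="0.0.0.0")
# exposes the app on the local network.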