# Hugging Face Space by farouk1 — commit d4d58c1 "Create app.py" (verified).
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch # Import torch for device management
import os # For file operations
# --- Configuration and Model Loading ---
# Any Seq2Seq checkpoint from the Hugging Face Hub can be used here. Larger models
# need more RAM/VRAM (or quantization, e.g. bitsandbytes 8-bit/4-bit loading).
MODEL_NAME = "google/flan-t5-large"  # 'large': better output than 'base' while still manageable.
# With a powerful GPU, consider "google/flan-t5-xl" or even "google/flan-t5-xxl".

# Pick the device outside the try so `device` is always defined for
# generate_content, even when model loading below fails.
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    print(f"Loading model on device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if device == "cuda":
        # T5-family checkpoints (Flan-T5 included) are known to overflow and
        # produce NaN/garbage in float16. Prefer bfloat16, which has the fp32
        # exponent range, when the GPU supports it. Fall back to float16 only
        # to keep VRAM usage low on older GPUs.
        _model_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=_model_dtype).to(device)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    model.eval()  # Inference only: switch off training-time behaviour such as dropout.
    print(f"Model '{MODEL_NAME}' loaded successfully.")
except Exception as e:
    # Best-effort startup: keep the UI alive with tokenizer/model set to None.
    # generate_content checks for None and reports the problem to the user
    # instead of the whole app crashing.
    print(f"Error loading model: {e}")
    print("Please check your internet connection, model name, and available resources (RAM/VRAM).")
    tokenizer, model = None, None
# --- Prompt Engineering Functions (more structured) ---
def create_arabic_prompt(topic, style):
    """Return the Arabic instruction prompt for ``topic`` in the given ``style``.

    Args:
        topic: Subject text, inserted verbatim after "عن: ".
        style: UI style label. The blog and social-media labels select their own
            templates; any other value falls back to the video-script template.

    Returns:
        The prompt string to send to the model.
    """
    # Checked in order; the first label equal to `style` wins.
    labelled_templates = (
        (
            "Blog Post (Descriptive)",
            "اكتب مقالاً احترافياً بأسلوب شخصي عن: {topic}. ركز على التفاصيل، الوصف الجذاب، قدم نصائح عملية. اجعل النص منسقاً بفقرات وعناوين فرعية.",
        ),
        (
            "Social Media Post (Short & Catchy)",
            "اكتب منشوراً قصيراً وجذاباً ومثيراً للتفاعل عن: {topic}. أضف 2-3 إيموجي مناسبة واقترح 4 هاشتاغات شائعة. ابدأ بسؤال أو جملة جذابة.",
        ),
    )
    # Default when no label matches: video script (storytelling).
    fallback_template = "اكتب سيناريو فيديو احترافي ومقنع عن: {topic}. اجعل الأسلوب قصصي وسردي، مقسماً إلى مشاهد رئيسية، مع اقتراح لقطات بصرية (B-roll) وأصوات (SFX) لكل مشهد. ركز على إثارة المشاعر."

    for label, template in labelled_templates:
        if style == label:
            return template.format(topic=topic)
    return fallback_template.format(topic=topic)
def create_english_prompt(topic, style):
    """Return the English instruction prompt for ``topic`` in the given ``style``.

    Args:
        topic: Subject text, inserted verbatim after "about: ".
        style: UI style label. The blog and social-media labels select their own
            templates; any other value falls back to the video-script template.

    Returns:
        The prompt string to send to the model.
    """
    # Checked in order; the first label equal to `style` wins.
    labelled_templates = (
        (
            "Blog Post (Descriptive)",
            "Write a detailed and professional blog post about: {topic}. Focus on personal insights, vivid descriptions, and practical advice. Structure it with clear paragraphs and subheadings.",
        ),
        (
            "Social Media Post (Short & Catchy)",
            "Write a short, catchy, and engaging social media post about: {topic}. Include 2-3 relevant emojis and suggest 4 trending hashtags. Start with a hook question or statement.",
        ),
    )
    # Default when no label matches: video script (storytelling).
    fallback_template = "Write a professional, compelling video script about: {topic}. Make it emotionally engaging and story-driven, divided into key scenes, with suggested visual shots (B-roll) and sound effects (SFX) for each scene."

    for label, template in labelled_templates:
        if style == label:
            return template.format(topic=topic)
    return fallback_template.format(topic=topic)
# --- Content Generation Function ---
@torch.no_grad()  # Inference only: skip autograd bookkeeping to save memory.
def generate_content(topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty):
    """Generate text about ``topic`` with the module-level tokenizer/model.

    Args:
        topic: Free-text subject. Leading/trailing whitespace is ignored, and an
            empty or whitespace-only topic is rejected.
        style_choice: UI style label, forwarded to the prompt builders.
        lang_choice: "Arabic" selects the Arabic prompt; anything else English.
        length_choice: "Short", "Medium", or anything else (treated as "Long").
            Sets the min/max number of newly generated tokens.
        creativity: Sampling temperature (UI slider range 0.1-1.0).
        detail_level: Nucleus-sampling ``top_p`` (UI slider range 0.1-1.0).
        diversity_penalty: Repetition-penalty strength from the UI slider
            (1 = off, higher = stronger). Translated to ``no_repeat_ngram_size``.

    Returns:
        The generated text. Returns a user-facing message starting with "⚠️"
        when the model is not loaded, the topic is empty, or generation fails.
    """
    if tokenizer is None or model is None:
        return "⚠️ Error: Model not loaded. Please check the console for details."
    if not topic or not str(topic).strip():
        return "⚠️ Please enter a topic to generate content."
    topic = str(topic).strip()

    # (max_new_tokens, min_new_tokens) per length option; unknown -> "Long".
    # These bound only the decoder output. The 512-token truncation further
    # down applies to the encoder input (the prompt).
    length_limits = {"Short": (150, 50), "Medium": (300, 100)}
    max_new_tokens, min_new_tokens = length_limits.get(length_choice, (450, 150))

    temperature = float(creativity)
    top_p = float(detail_level)

    # The UI promises "higher = less repetition" and "1 = no penalty".
    # `no_repeat_ngram_size` works the other way round: 0 disables it, and a
    # *smaller* n bans more (n=1 forbids any token from appearing twice).
    # transformers also requires an int here, while the slider may deliver a
    # float. Map 1 -> off, and 2..5 -> 5-grams..2-grams (never below 2).
    penalty = int(round(float(diversity_penalty)))
    no_repeat_ngram_size = 0 if penalty <= 1 else max(2, 7 - penalty)

    prompt = create_arabic_prompt(topic, style_choice) if lang_choice == "Arabic" else create_english_prompt(topic, style_choice)
    # Extra steering only for strong user preferences.
    if top_p > 0.7:
        prompt += " Ensure comprehensive coverage and rich descriptions."
    if temperature > 0.8:
        prompt += " Be highly creative and imaginative in your writing."

    try:
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            num_beams=5,  # Combined with do_sample: beam-sample decoding.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=50,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=1.0,
            early_stopping=True,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except RuntimeError as e:
        if "out of memory" in str(e):
            if device == "cuda":
                # Release cached allocator blocks so the next request can fit.
                torch.cuda.empty_cache()
            return "⚠️ Generation failed: Out of memory. Try a shorter length, a less complex model, or restart the application if on GPU."
        return f"⚠️ Generation failed due to a runtime error: {e}"
    except Exception as e:
        return f"⚠️ An unexpected error occurred during generation: {e}"
# --- Gradio Interface ---
# Custom stylesheet passed to gr.Blocks(css=...) below: purple headings,
# a lavender page background, violet buttons, and plum-bordered inputs/radios.
# NOTE(review): selectors such as .gr-button, .gr-text-input, .gr-textarea and
# .gradio-radio look like class names from older Gradio releases. Confirm they
# still match the DOM of the installed Gradio version; otherwise these rules
# are silently ignored.
custom_css = """
h1, h2, h3 { color: #4B0082; } /* Dark Purple */
.gradio-container {
background-color: #F8F0FF; /* Light Lavender */
font-family: 'Segoe UI', sans-serif;
}
.gr-button {
background-color: #8A2BE2; /* Blue Violet */
color: white;
border-radius: 10px;
padding: 10px 20px;
font-size: 1.1em;
}
.gr-button:hover {
background-color: #9370DB; /* Medium Purple */
}
.gr-text-input, .gr-textarea {
border: 1px solid #DDA0DD; /* Plum */
border-radius: 8px;
padding: 10px;
}
.gradio-radio input:checked + label {
background-color: #DA70D6 !important; /* Orchid */
color: white !important;
}
.gradio-radio label {
border: 1px solid #DDA0DD;
border-radius: 8px;
padding: 8px 15px;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as iface:
    gr.Markdown("# ✨ AI Content Creation Studio")
    gr.Markdown("## Generate professional blogs, social media posts, or video scripts in seconds!")
    with gr.Row():
        # Left column: the topic plus the sampling controls for generate_content.
        with gr.Column(scale=2):
            topic = gr.Textbox(
                label="Topic / الموضوع",
                placeholder="e.g., The Future of AI in Healthcare / مثال: مستقبل الذكاء الاصطناعي في الرعاية الصحية",
                lines=2,
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    creativity = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.7, step=0.1,
                        label="Creativity (Temperature)",
                        info="Higher values lead to more creative, less predictable text. Lower values are more focused.",
                    )
                    detail_level = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                        label="Detail Level (Top-p Sampling)",
                        info="Higher values allow for more diverse and detailed vocabulary. Lower values prune less likely words.",
                    )
                with gr.Row():
                    diversity_penalty = gr.Slider(
                        minimum=1, maximum=5, value=2, step=1,
                        label="Repetition Penalty (N-gram)",
                        info="Higher values reduce the chance of repeating the same phrases or words. Set to 1 for no penalty.",
                    )
        # Right column: categorical choices for style, language and length.
        with gr.Column(scale=1):
            with gr.Group():
                style_choice = gr.Radio(
                    ["Blog Post (Descriptive)", "Social Media Post (Short & Catchy)", "Video Script (Storytelling)"],
                    label="Content Style / نوع المحتوى",
                    value="Blog Post (Descriptive)",
                    interactive=True,
                )
            with gr.Group():
                lang_choice = gr.Radio(
                    ["English", "Arabic"],
                    label="Language / اللغة",
                    value="English",
                    interactive=True,
                )
            with gr.Group():
                length_choice = gr.Radio(
                    ["Short", "Medium", "Long"],
                    label="Content Length / طول النص",
                    value="Medium",
                    interactive=True,
                )
            # generate_content caps "Long" at 450 new *tokens*, not words.
            gr.Markdown("*(Note: 'Long' is relative to model capabilities, max ~450 tokens)*")
    btn = gr.Button("🚀 Generate Content", variant="primary")
    # Editable so users can tweak the text before downloading it.
    output = gr.Textbox(label="Generated Content", lines=20, interactive=True)

    def download_file(content):
        """Write valid generated ``content`` to a .txt file and enable the download button.

        Args:
            content: Current text of the output box.

        Returns:
            A component update for the download button. When ``content`` is
            non-empty and not a "⚠️" error message, it points the button at the
            written file and enables it. Otherwise it clears and disables the
            button.
        """
        if content and not content.startswith("⚠️"):
            # NOTE(review): a fixed filename in the working directory is shared
            # by all sessions. This relies on Gradio copying the returned file
            # into its own cache right away — confirm if concurrency is raised.
            file_path = "generated_content.txt"
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            return gr.update(value=file_path, interactive=True)
        return gr.update(value=None, interactive=False)

    # Disabled until download_file supplies a file. DownloadButton takes the
    # file via `value`; it has no `file_path` parameter.
    download_button = gr.DownloadButton("⬇️ Download Content", value=None, interactive=False)

    # Event handlers
    btn.click(
        fn=generate_content,
        inputs=[topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty],
        outputs=output,
    )
    # Refresh the downloadable file whenever the output text changes, whether
    # from generation or from user edits.
    output.change(fn=download_file, inputs=[output], outputs=[download_button])
def _launch_app():
    """Start the Gradio server that hosts the ``iface`` Blocks app."""
    iface.launch()


if __name__ == "__main__":
    _launch_app()