import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch  # Import torch for device management
import os  # For file operations

# --- Configuration and Model Loading ---
# You can choose a different model here if you have access to more powerful ones.
# For larger models, ensure you have sufficient VRAM (GPU memory).
# On CPU, smaller models may be necessary, or use quantization.
MODEL_NAME = "google/flan-t5-large"  # 'large' performs somewhat better than 'base' while staying manageable.
# If you have a powerful GPU, consider "google/flan-t5-xl" or even "google/flan-t5-xxl".
# For even larger models, consider model.to(torch.bfloat16) or bitsandbytes for 4-bit loading if available.
try:
    # Determine the device to use (GPU if available, else CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Load the model in half precision (float16) to save VRAM on GPU,
    # or in 8-bit/4-bit via libraries like bitsandbytes (requires installation).
    if device == "cuda":
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    model.eval()  # Set model to evaluation mode
    print(f"Model '{MODEL_NAME}' loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please check your internet connection, the model name, and available resources (RAM/VRAM).")
    # Fall back gracefully if loading fails; generate_content() checks for None.
    tokenizer, model = None, None
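
# As noted above, larger checkpoints can also be loaded in 4-bit with bitsandbytes.
# A minimal sketch (assumes the `bitsandbytes` and `accelerate` packages are
# installed and a CUDA GPU is available; shown for reference, not used by this app):
#
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_compute_dtype=torch.float16,
#   )
#   model = AutoModelForSeq2SeqLM.from_pretrained(
#       MODEL_NAME, quantization_config=quant_config, device_map="auto"
#   )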

# --- Prompt Engineering Functions (more structured) ---
def create_arabic_prompt(topic, style):
    if style == "Blog Post (Descriptive)":
        return f"اكتب مقالاً احترافياً بأسلوب شخصي عن: {topic}. ركز على التفاصيل، الوصف الجذاب، قدم نصائح عملية. اجعل النص منسقاً بفقرات وعناوين فرعية."
    elif style == "Social Media Post (Short & Catchy)":
        return f"اكتب منشوراً قصيراً وجذاباً ومثيراً للتفاعل عن: {topic}. أضف 2-3 إيموجي مناسبة واقترح 4 هاشتاغات شائعة. ابدأ بسؤال أو جملة جذابة."
    else:  # Video Script (Storytelling)
        return f"اكتب سيناريو فيديو احترافي ومقنع عن: {topic}. اجعل الأسلوب قصصي وسردي، مقسماً إلى مشاهد رئيسية، مع اقتراح لقطات بصرية (B-roll) وأصوات (SFX) لكل مشهد. ركز على إثارة المشاعر."

def create_english_prompt(topic, style):
    if style == "Blog Post (Descriptive)":
        return f"Write a detailed and professional blog post about: {topic}. Focus on personal insights, vivid descriptions, and practical advice. Structure it with clear paragraphs and subheadings."
    elif style == "Social Media Post (Short & Catchy)":
        return f"Write a short, catchy, and engaging social media post about: {topic}. Include 2-3 relevant emojis and suggest 4 trending hashtags. Start with a hook question or statement."
    else:  # Video Script (Storytelling)
        return f"Write a professional, compelling video script about: {topic}. Make it emotionally engaging and story-driven, divided into key scenes, with suggested visual shots (B-roll) and sound effects (SFX) for each scene."

# --- Content Generation Function ---
@torch.no_grad()  # Disable gradient tracking during inference to save memory
def generate_content(topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty):
    if tokenizer is None or model is None:
        return "⚠️ Error: Model not loaded. Please check the console for details."
    if not topic:
        return "⚠️ Please enter a topic to generate content."
    # Choose output length from the user's choice and the model's context window.
    # Flan-T5 has a context window of 512 tokens, so lengths should stay within this.
    if length_choice == "Short":
        max_new_tokens = 150
        min_new_tokens = 50
    elif length_choice == "Medium":
        max_new_tokens = 300
        min_new_tokens = 100
    else:  # Long
        max_new_tokens = 450  # Effectively the maximum for Flan-T5
        min_new_tokens = 150
    # Map the UI sliders onto generation parameters
    temperature = creativity  # Direct mapping
    top_p = detail_level  # Direct mapping; higher means more diverse/detailed vocabulary
    no_repeat_ngram_size = int(diversity_penalty)  # generate() expects an int; higher means less repetition
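    # How these map onto decoding, for reference: temperature T rescales the
    # logits before the softmax (p_i ∝ exp(logit_i / T)), so values below 1
    # sharpen the distribution and values above 1 flatten it; top-p sampling
    # keeps the smallest set of tokens whose cumulative probability exceeds p
    # and renormalizes over that set; no_repeat_ngram_size=N forbids any
    # N-gram from appearing twice in the output.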
    # Build the prompt in the requested language
    if lang_choice == "Arabic":
        prompt = create_arabic_prompt(topic, style_choice)
    else:  # English
        prompt = create_english_prompt(topic, style_choice)
    # Strengthen the prompt when the user asks for high detail or creativity.
    # (Note: these suffixes are appended in English even for Arabic prompts.)
    if detail_level > 0.7:  # Only if the user explicitly wants high detail
        prompt += " Ensure comprehensive coverage and rich descriptions."
    if creativity > 0.8:
        prompt += " Be highly creative and imaginative in your writing."
    try:
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            num_beams=5,          # Beam search for better quality
            do_sample=True,       # Enable sampling for creativity (beam-sample mode)
            temperature=temperature,
            top_p=top_p,
            top_k=50,             # Consider only the 50 most likely tokens at each step
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=1.0,   # Adjust to control output length
            early_stopping=True,
        )
        content = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return content
    except RuntimeError as e:
        if "out of memory" in str(e):
            return "⚠️ Generation failed: out of memory. Try a shorter length, a smaller model, or restart the application if on GPU."
        return f"⚠️ Generation failed due to a runtime error: {str(e)}"
    except Exception as e:
        return f"⚠️ An unexpected error occurred during generation: {str(e)}"

# --- Gradio Interface ---
# Custom CSS for a more polished look
custom_css = """
h1, h2, h3 { color: #4B0082; } /* Dark Purple */
.gradio-container {
    background-color: #F8F0FF; /* Light Lavender */
    font-family: 'Segoe UI', sans-serif;
}
.gr-button {
    background-color: #8A2BE2; /* Blue Violet */
    color: white;
    border-radius: 10px;
    padding: 10px 20px;
    font-size: 1.1em;
}
.gr-button:hover {
    background-color: #9370DB; /* Medium Purple */
}
.gr-text-input, .gr-textarea {
    border: 1px solid #DDA0DD; /* Plum */
    border-radius: 8px;
    padding: 10px;
}
.gradio-radio input:checked + label {
    background-color: #DA70D6 !important; /* Orchid */
    color: white !important;
}
.gradio-radio label {
    border: 1px solid #DDA0DD;
    border-radius: 8px;
    padding: 8px 15px;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as iface:
    gr.Markdown("# ✨ AI Content Creation Studio")
    gr.Markdown("## Generate professional blogs, social media posts, or video scripts in seconds!")
    with gr.Row():
        with gr.Column(scale=2):
            topic = gr.Textbox(
                label="Topic / الموضوع",
                placeholder="e.g., The Future of AI in Healthcare / مثال: مستقبل الذكاء الاصطناعي في الرعاية الصحية",
                lines=2,
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    creativity = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.7, step=0.1,
                        label="Creativity (Temperature)",
                        info="Higher values lead to more creative, less predictable text. Lower values are more focused.",
                    )
                    detail_level = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                        label="Detail Level (Top-p Sampling)",
                        info="Higher values allow a more diverse and detailed vocabulary. Lower values prune unlikely words more aggressively.",
                    )
                with gr.Row():
                    diversity_penalty = gr.Slider(
                        minimum=1, maximum=5, value=2, step=1,
                        label="Repetition Penalty (N-gram)",
                        info="Higher values reduce the chance of repeating the same phrases or words. Set to 1 for no penalty.",
                    )
        with gr.Column(scale=1):
            with gr.Group():
                style_choice = gr.Radio(
                    ["Blog Post (Descriptive)", "Social Media Post (Short & Catchy)", "Video Script (Storytelling)"],
                    label="Content Style / نوع المحتوى",
                    value="Blog Post (Descriptive)",
                    interactive=True,
                )
            with gr.Group():
                lang_choice = gr.Radio(
                    ["English", "Arabic"],
                    label="Language / اللغة",
                    value="English",
                    interactive=True,
                )
            with gr.Group():
                length_choice = gr.Radio(
                    ["Short", "Medium", "Long"],
                    label="Content Length / طول النص",
                    value="Medium",
                    interactive=True,
                )
gr.Markdown("*(Note: 'Long' is relative to model capabilities, max ~450 words)*") | |
btn = gr.Button("🚀 Generate Content", variant="primary") | |
output = gr.Textbox(label="Generated Content", lines=20, interactive=True) | |

    # Download button logic: write the text to a file and return an updated
    # DownloadButton so the file is offered (and the button clickable) only
    # when the content is valid.
    def download_file(content):
        if content and not content.startswith("⚠️"):  # Only provide a file for valid content
            file_path = "generated_content.txt"
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            return gr.DownloadButton(value=file_path, interactive=True)
        return gr.DownloadButton(interactive=False)  # Nothing valid to download

    download_button = gr.DownloadButton("⬇️ Download Content", interactive=False)

    # Event handlers
    btn.click(
        fn=generate_content,
        inputs=[topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty],
        outputs=output,
    )
    # Enable the download button only when there is valid content
    output.change(fn=download_file, inputs=[output], outputs=[download_button])

if __name__ == "__main__":
    iface.launch()