import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch # Import torch for device management
import os # For file operations
# --- Configuration and Model Loading ---
# You can choose a different model here if you have access to more powerful ones.
# For larger models, ensure you have sufficient VRAM (GPU memory).
# For CPU, smaller models might be necessary or use quantization.
MODEL_NAME = "google/flan-t5-large" # Changed to 'large' for slightly better performance than 'base' and still manageable.
# If you have a powerful GPU, consider "google/flan-t5-xl" or even "google/flan-t5-xxl"
# For even larger models, consider using model.to(torch.bfloat16) or bitsandbytes for 4-bit loading if available.
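# A minimal 4-bit loading sketch (optional, not used by default; assumes `bitsandbytes`
# and `accelerate` are installed). Uncomment to use it in place of the float16 load below:
#
# from transformers import BitsAndBytesConfig
# quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
# model = AutoModelForSeq2SeqLM.from_pretrained(
#     MODEL_NAME, quantization_config=quant_config, device_map="auto"
# )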
try:
    # Determine the device to use (GPU if available, else CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Load the model in half precision (float16) to save VRAM on GPU;
    # on CPU, keep the default float32 weights.
    # Or load in 8-bit/4-bit via libraries like bitsandbytes (requires installation).
    if device == "cuda":
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    model.eval()  # Set model to evaluation mode
    print(f"Model '{MODEL_NAME}' loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please check your internet connection, the model name, and available resources (RAM/VRAM).")
    # Fall back to None so the UI can still start and report a clear error
    tokenizer, model = None, None
# --- Prompt Engineering Functions (more structured) ---
def create_arabic_prompt(topic, style):
    # The Arabic prompts mirror the English ones: a personal, well-structured blog
    # post; a short, catchy social post with emojis and hashtags; or a scene-by-scene
    # video script with B-roll and SFX suggestions.
    if style == "Blog Post (Descriptive)":
        return f"اكتب مقالاً احترافياً بأسلوب شخصي عن: {topic}. ركز على التفاصيل، الوصف الجذاب، قدم نصائح عملية. اجعل النص منسقاً بفقرات وعناوين فرعية."
    elif style == "Social Media Post (Short & Catchy)":
        return f"اكتب منشوراً قصيراً وجذاباً ومثيراً للتفاعل عن: {topic}. أضف 2-3 إيموجي مناسبة واقترح 4 هاشتاغات شائعة. ابدأ بسؤال أو جملة جذابة."
    else:  # Video Script (Storytelling)
        return f"اكتب سيناريو فيديو احترافي ومقنع عن: {topic}. اجعل الأسلوب قصصي وسردي، مقسماً إلى مشاهد رئيسية، مع اقتراح لقطات بصرية (B-roll) وأصوات (SFX) لكل مشهد. ركز على إثارة المشاعر."
def create_english_prompt(topic, style):
    if style == "Blog Post (Descriptive)":
        return f"Write a detailed and professional blog post about: {topic}. Focus on personal insights, vivid descriptions, and practical advice. Structure it with clear paragraphs and subheadings."
    elif style == "Social Media Post (Short & Catchy)":
        return f"Write a short, catchy, and engaging social media post about: {topic}. Include 2-3 relevant emojis and suggest 4 trending hashtags. Start with a hook question or statement."
    else:  # Video Script (Storytelling)
        return f"Write a professional, compelling video script about: {topic}. Make it emotionally engaging and story-driven, divided into key scenes, with suggested visual shots (B-roll) and sound effects (SFX) for each scene."
# --- Content Generation Function ---
@torch.no_grad()  # Disable gradient calculations for inference to save memory
def generate_content(topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty):
    if tokenizer is None or model is None:
        return "⚠️ Error: Model not loaded. Please check the console for details."
    if not topic:
        return "⚠️ Please enter a topic to generate content."
    # Token budgets per desired length. Flan-T5's encoder context window is 512
    # tokens, so the input prompt is truncated to that length below.
    if length_choice == "Short":
        max_new_tokens = 150
        min_new_tokens = 50
    elif length_choice == "Medium":
        max_new_tokens = 300
        min_new_tokens = 100
    else:  # Long
        max_new_tokens = 450  # Effective practical ceiling for Flan-T5
        min_new_tokens = 150
    # Map the UI sliders directly onto generation parameters
    temperature = creativity  # Higher -> more random, creative output
    top_p = detail_level  # Higher -> broader, more diverse vocabulary
    no_repeat_ngram_size = diversity_penalty  # Higher -> less verbatim repetition
    # Build the prompt
    if lang_choice == "Arabic":
        prompt = create_arabic_prompt(topic, style_choice)
    else:  # English
        prompt = create_english_prompt(topic, style_choice)
    # Reinforce the instructions when the user asks for high detail or creativity
    if detail_level > 0.7:
        prompt += " Ensure comprehensive coverage and rich descriptions."
    if creativity > 0.8:
        prompt += " Be highly creative and imaginative in your writing."
    try:
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            num_beams=5,  # Beam search for better quality
            do_sample=True,  # Enable sampling for creativity
            temperature=temperature,
            top_p=top_p,
            top_k=50,  # Sample only from the 50 most likely tokens
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=1.0,  # >1.0 favors longer outputs, <1.0 shorter ones
            early_stopping=True
        )
        content = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return content
    except RuntimeError as e:
        if "out of memory" in str(e):
            return "⚠️ Generation failed: Out of memory. Try a shorter length, a less complex model, or restart the application if on GPU."
        return f"⚠️ Generation failed due to a runtime error: {str(e)}"
    except Exception as e:
        return f"⚠️ An unexpected error occurred during generation: {str(e)}"
# --- Gradio Interface ---
# Custom CSS for a more polished look
custom_css = """
h1, h2, h3 { color: #4B0082; } /* Dark Purple */
.gradio-container {
background-color: #F8F0FF; /* Light Lavender */
font-family: 'Segoe UI', sans-serif;
}
.gr-button {
background-color: #8A2BE2; /* Blue Violet */
color: white;
border-radius: 10px;
padding: 10px 20px;
font-size: 1.1em;
}
.gr-button:hover {
background-color: #9370DB; /* Medium Purple */
}
.gr-text-input, .gr-textarea {
border: 1px solid #DDA0DD; /* Plum */
border-radius: 8px;
padding: 10px;
}
.gradio-radio input:checked + label {
background-color: #DA70D6 !important; /* Orchid */
color: white !important;
}
.gradio-radio label {
border: 1px solid #DDA0DD;
border-radius: 8px;
padding: 8px 15px;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as iface:
    gr.Markdown("# ✨ AI Content Creation Studio")
    gr.Markdown("## Generate professional blogs, social media posts, or video scripts in seconds!")
    with gr.Row():
        with gr.Column(scale=2):
            topic = gr.Textbox(
                label="Topic / الموضوع",
                placeholder="e.g., The Future of AI in Healthcare / مثال: مستقبل الذكاء الاصطناعي في الرعاية الصحية",
                lines=2
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    creativity = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.7, step=0.1,
                        label="Creativity (Temperature)",
                        info="Higher values lead to more creative, less predictable text. Lower values are more focused."
                    )
                    detail_level = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                        label="Detail Level (Top-p Sampling)",
                        info="Higher values allow for more diverse and detailed vocabulary. Lower values prune less likely words."
                    )
                with gr.Row():
                    diversity_penalty = gr.Slider(
                        minimum=1, maximum=5, value=2, step=1,
                        label="Repetition Penalty (N-gram)",
                        info="Higher values reduce the chance of repeating the same phrases or words. Set to 1 for no penalty."
                    )
        with gr.Column(scale=1):
            with gr.Group():
                style_choice = gr.Radio(
                    ["Blog Post (Descriptive)", "Social Media Post (Short & Catchy)", "Video Script (Storytelling)"],
                    label="Content Style / نوع المحتوى",
                    value="Blog Post (Descriptive)",
                    interactive=True
                )
            with gr.Group():
                lang_choice = gr.Radio(
                    ["English", "Arabic"],
                    label="Language / اللغة",
                    value="English",
                    interactive=True
                )
            with gr.Group():
                length_choice = gr.Radio(
                    ["Short", "Medium", "Long"],
                    label="Content Length / طول النص",
                    value="Medium",
                    interactive=True
                )
gr.Markdown("*(Note: 'Long' is relative to model capabilities, max ~450 words)*")
    btn = gr.Button("🚀 Generate Content", variant="primary")
    output = gr.Textbox(label="Generated Content", lines=20, interactive=True)
    # Download button logic: return a component update so the button both receives
    # the file and becomes clickable only when there is valid content.
    def download_file(content):
        if content and not content.startswith("⚠️"):  # Only provide a file if content is valid
            file_path = "generated_content.txt"
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            return gr.DownloadButton(value=file_path, interactive=True)
        return gr.DownloadButton(interactive=False)  # Nothing valid to download
    download_button = gr.DownloadButton("⬇️ Download Content", interactive=False)
    # Event handlers
    btn.click(
        fn=generate_content,
        inputs=[topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty],
        outputs=output
    )
    # Enable the download button only when there's valid content
    output.change(fn=download_file, inputs=[output], outputs=[download_button])
if __name__ == "__main__":
    iface.launch()
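# A hedged alternative launch (assumes you want request queuing and a temporary
# public link when running locally; on Hugging Face Spaces, plain launch() suffices):
#
# if __name__ == "__main__":
#     iface.queue().launch(share=True)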