import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch # Import torch for device management

# --- Configuration and Model Loading ---
# You can choose a different model here if you have access to more powerful ones.
# For larger models, ensure you have sufficient VRAM (GPU memory).
# For CPU, smaller models might be necessary or use quantization.
MODEL_NAME = "google/flan-t5-large" # 'large' gives better output than 'base' while remaining manageable on most hardware.
# If you have a powerful GPU, consider "google/flan-t5-xl" or even "google/flan-t5-xxl"
# For even larger models, consider using model.to(torch.bfloat16) or bitsandbytes for 4-bit loading if available.
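# A minimal 4-bit loading sketch (an assumption: requires the optional `bitsandbytes`
# and `accelerate` packages); uncomment to use in place of the float16 load below:
# from transformers import BitsAndBytesConfig
# quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
# model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, quantization_config=quant_config, device_map="auto")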

try:
    # Determine the device to use (GPU if available, else CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Load model with half-precision (float16) to save VRAM if on GPU
    # Or load in 8-bit/4-bit if using libraries like bitsandbytes (requires installation)
    if device == "cuda":
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    
    model.eval() # Set model to evaluation mode
    print(f"Model '{MODEL_NAME}' loaded successfully.")

except Exception as e:
    print(f"Error loading model: {e}")
    print("Please check your internet connection, model name, and available resources (RAM/VRAM).")
    # Exit or handle gracefully if model loading fails
    tokenizer, model = None, None

# --- Prompt Engineering Functions (more structured) ---

def create_arabic_prompt(topic, style):
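    # (English gloss) These Arabic prompts mirror create_english_prompt below: a
    # descriptive, personal-style article with subheadings; a short, catchy social
    # post with 2-3 emojis and 4 hashtags; and a scene-by-scene video script with
    # suggested B-roll shots and SFX.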
    if style == "Blog Post (Descriptive)":
        return f"اكتب مقالاً احترافياً بأسلوب شخصي عن: {topic}. ركز على التفاصيل، الوصف الجذاب، قدم نصائح عملية. اجعل النص منسقاً بفقرات وعناوين فرعية."
    elif style == "Social Media Post (Short & Catchy)":
        return f"اكتب منشوراً قصيراً وجذاباً ومثيراً للتفاعل عن: {topic}. أضف 2-3 إيموجي مناسبة واقترح 4 هاشتاغات شائعة. ابدأ بسؤال أو جملة جذابة."
    else:  # Video Script (Storytelling)
        return f"اكتب سيناريو فيديو احترافي ومقنع عن: {topic}. اجعل الأسلوب قصصي وسردي، مقسماً إلى مشاهد رئيسية، مع اقتراح لقطات بصرية (B-roll) وأصوات (SFX) لكل مشهد. ركز على إثارة المشاعر."

def create_english_prompt(topic, style):
    if style == "Blog Post (Descriptive)":
        return f"Write a detailed and professional blog post about: {topic}. Focus on personal insights, vivid descriptions, and practical advice. Structure it with clear paragraphs and subheadings."
    elif style == "Social Media Post (Short & Catchy)":
        return f"Write a short, catchy, and engaging social media post about: {topic}. Include 2-3 relevant emojis and suggest 4 trending hashtags. Start with a hook question or statement."
    else:  # Video Script (Storytelling)
        return f"Write a professional, compelling video script about: {topic}. Make it emotionally engaging and story-driven, divided into key scenes, with suggested visual shots (B-roll) and sound effects (SFX) for each scene."

# --- Content Generation Function ---

@torch.no_grad() # Disable gradient calculations for inference to save memory
def generate_content(topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty):
    if tokenizer is None or model is None:
        return "⚠️ Error: Model not loaded. Please check the console for details."

    if not topic:
        return "⚠️ Please enter a topic to generate content."

    # Output length based on the desired setting.
    # Note: Flan-T5 inputs are truncated at 512 tokens (see the tokenizer call below);
    # output length is controlled separately via max_new_tokens.
    if length_choice == "Short":
        max_new_tokens = 150
        min_new_tokens = 50
    elif length_choice == "Medium":
        max_new_tokens = 300
        min_new_tokens = 100
    else:  # Long
        max_new_tokens = 450 # Practical upper bound for Flan-T5 output quality
        min_new_tokens = 150

    # Map the UI sliders onto generation parameters
    temperature = creativity  # Higher -> more random, creative output
    top_p = detail_level  # Nucleus sampling: higher keeps a wider slice of the vocabulary
    no_repeat_ngram_size = int(diversity_penalty)  # Must be an int; higher -> less repetition

    # Build the prompt
    if lang_choice == "Arabic":
        prompt = create_arabic_prompt(topic, style_choice)
    else:  # English
        prompt = create_english_prompt(topic, style_choice)

    # Append optional emphasis instructions in the same language as the prompt
    if detail_level > 0.7: # Only if the user explicitly wants high detail
        prompt += (" احرص على تغطية شاملة وأوصاف غنية." if lang_choice == "Arabic"
                   else " Ensure comprehensive coverage and rich descriptions.")
    if creativity > 0.8:
        prompt += (" كن مبدعاً وخيالياً في كتابتك." if lang_choice == "Arabic"
                   else " Be highly creative and imaginative in your writing.")

    try:
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
        
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            num_beams=5, # Beam search; combined with do_sample this gives beam-sample decoding
            do_sample=True, # Enable sampling for creativity
            temperature=temperature,
            top_p=top_p,
            top_k=50, # Consider top 50 words
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=1.0, # Adjust to control output length
            early_stopping=True
        )
        content = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return content
    except RuntimeError as e:
        if "out of memory" in str(e):
            return "⚠️ Generation failed: Out of memory. Try a shorter length, a less complex model, or restart the application if on GPU."
        return f"⚠️ Generation failed due as runtime error: {str(e)}"
    except Exception as e:
        return f"⚠️ An unexpected error occurred during generation: {str(e)}"

# --- Gradio Interface ---

# Custom CSS for a more polished look
custom_css = """
h1, h2, h3 { color: #4B0082; } /* Dark Purple */
.gradio-container {
    background-color: #F8F0FF; /* Light Lavender */
    font-family: 'Segoe UI', sans-serif;
}
.gr-button {
    background-color: #8A2BE2; /* Blue Violet */
    color: white;
    border-radius: 10px;
    padding: 10px 20px;
    font-size: 1.1em;
}
.gr-button:hover {
    background-color: #9370DB; /* Medium Purple */
}
.gr-text-input, .gr-textarea {
    border: 1px solid #DDA0DD; /* Plum */
    border-radius: 8px;
    padding: 10px;
}
.gradio-radio input:checked + label {
    background-color: #DA70D6 !important; /* Orchid */
    color: white !important;
}
.gradio-radio label {
    border: 1px solid #DDA0DD;
    border-radius: 8px;
    padding: 8px 15px;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as iface:
    gr.Markdown("# ✨ AI Content Creation Studio")
    gr.Markdown("## Generate professional blogs, social media posts, or video scripts in seconds!")
    
    with gr.Row():
        with gr.Column(scale=2):
            topic = gr.Textbox(
                label="Topic / الموضوع", 
                placeholder="e.g., The Future of AI in Healthcare / مثال: مستقبل الذكاء الاصطناعي في الرعاية الصحية",
                lines=2
            )
            
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    creativity = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.7, step=0.1, 
                        label="Creativity (Temperature)", 
                        info="Higher values lead to more creative, less predictable text. Lower values are more focused."
                    )
                    detail_level = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1, 
                        label="Detail Level (Top-p Sampling)", 
                        info="Higher values allow for more diverse and detailed vocabulary. Lower values prune less likely words."
                    )
                with gr.Row():
                    diversity_penalty = gr.Slider(
                        minimum=1, maximum=5, value=2, step=1, 
                        label="Repetition Penalty (N-gram)", 
                        info="Higher values reduce the chance of repeating the same phrases or words. Set to 1 for no penalty."
                    )

        with gr.Column(scale=1):
            with gr.Group():
                style_choice = gr.Radio(
                    ["Blog Post (Descriptive)", "Social Media Post (Short & Catchy)", "Video Script (Storytelling)"],
                    label="Content Style / نوع المحتوى",
                    value="Blog Post (Descriptive)",
                    interactive=True
                )
            with gr.Group():
                lang_choice = gr.Radio(
                    ["English", "Arabic"], 
                    label="Language / اللغة",
                    value="English",
                    interactive=True
                )
            with gr.Group():
                length_choice = gr.Radio(
                    ["Short", "Medium", "Long"], 
                    label="Content Length / طول النص", 
                    value="Medium",
                    interactive=True
                )
                gr.Markdown("*(Note: 'Long' is relative to model capabilities, max ~450 words)*")

    btn = gr.Button("🚀 Generate Content", variant="primary")
    
    output = gr.Textbox(label="Generated Content", lines=20, interactive=True)
    
    # Download button logic: write valid content to a file and enable the button
    def download_file(content):
        if content and not content.startswith("⚠️"): # Only offer a file when the content is valid
            file_path = "generated_content.txt"
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            return gr.DownloadButton(value=file_path, interactive=True)
        return gr.DownloadButton(interactive=False) # Disable when there is nothing to download

    download_button = gr.DownloadButton("⬇️ Download Content", interactive=False)

    # Event handlers
    btn.click(
        fn=generate_content, 
        inputs=[topic, style_choice, lang_choice, length_choice, creativity, detail_level, diversity_penalty], 
        outputs=output
    )
    
    # Enable download button only when there's valid content
    output.change(fn=download_file, inputs=[output], outputs=[download_button])
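
# Quick smoke test without the UI (a hypothetical example topic; uncomment to try):
# print(generate_content("The future of renewable energy", "Blog Post (Descriptive)",
#                        "English", "Short", 0.7, 0.9, 2))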

if __name__ == "__main__":
    iface.launch()