# Al-Atlas-LLM / app.py
# Hugging Face Space app file.
# Last commit: 37d4ad6 (verified) by BounharAbdelaziz — "added disclaimer about chat".
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import spaces
import torch
from datasets import load_dataset
from huggingface_hub import CommitScheduler
from pathlib import Path
import uuid
import json
import time
from datetime import datetime
import logging
# Configure logging
# Log to both a local file and stdout so the Space console and app.log agree.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("darija-llm")

# Prefer the first CUDA device when available; otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.info(f'Using device: {device}')

# token
# Hugging Face access token; raises KeyError at startup if the TOKEN secret is unset.
token = os.environ['TOKEN']

# Load the pretrained model and tokenizer
MODEL_NAME = "atlasia/Al-Atlas-0.5B"  # "atlasia/Al-Atlas-LLM-mid-training" # "BounharAbdelaziz/Al-Atlas-LLM-0.5B" #"atlasia/Al-Atlas-LLM"
logger.info(f"Loading model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)  # , token=token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=token).to(device)
logger.info("Model loaded successfully")

# Fix tokenizer padding
# Causal-LM tokenizers often ship without a pad token; reuse EOS so batched
# generation has a valid pad_token_id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Set pad token
    logger.info("Set pad_token to eos_token")

# Predefined examples
# Each row matches the gr.Examples inputs:
# [prompt, max_length, temperature, top_p, top_k, num_beams, repetition_penalty]
examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر اللي كيركز", 256, 0.7, 0.9, 100, 4, 1.5],
    ["المستقبل ديال الذكاء الصناعي فالمغرب", 256, 0.7, 0.9, 100, 4, 1.5],
    [" المطبخ المغربي", 256, 0.7, 0.9, 100, 4, 1.5],
    ["الماكلة المغربية كتعتبر من أحسن الماكلات فالعالم", 256, 0.7, 0.9, 100, 4, 1.5],
]

# Define the file where to save the data
# Per-process UUID filename so concurrent replicas never write to the same file.
submit_file = Path("user_submit/") / f"data_{uuid.uuid4()}.json"
feedback_file = submit_file
# Create directory if it doesn't exist
submit_file.parent.mkdir(exist_ok=True, parents=True)
logger.info(f"Created feedback file: {feedback_file}")

# Periodically push everything under user_submit/ to the dataset repo
# (CommitScheduler's `every` is in minutes).
scheduler = CommitScheduler(
    repo_id="atlasia/atlaset_inference_ds",
    repo_type="dataset",
    folder_path=submit_file.parent,
    path_in_repo="data",
    every=5,
    token=token
)
logger.info(f"Initialized CommitScheduler for repo: atlasia/atlaset_inference_ds")

# Track usage statistics
# In-memory counters; reads/writes elsewhere are guarded by scheduler.lock.
usage_stats = {
    "total_generations": 0,
    "total_tokens_generated": 0,
    "start_time": time.time()
}
@spaces.GPU
def generate_text(prompt, max_length=256, temperature=0.7, top_p=0.9, top_k=150, num_beams=8, repetition_penalty=1.5, progress=gr.Progress()):
    """Generate a Darija continuation of `prompt` with the loaded causal LM.

    Args:
        prompt: User text to continue; an empty/whitespace prompt short-circuits
            with an error message instead of calling the model.
        max_length: Cap on TOTAL sequence length (prompt + generated tokens).
        temperature, top_p, top_k: Sampling hyperparameters.
        num_beams: Requested beam count; values above 4 are clamped to 1 to
            keep latency down (original speed trade-off, preserved here).
        repetition_penalty: Penalty applied to repeated tokens.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        (generated_text, status_message) — the decoded output (includes the
        prompt, since the full sequence is decoded) and a bilingual summary.
    """
    if not prompt.strip():
        logger.warning("Empty prompt submitted")
        return "", "الرجاء إدخال نص للتوليد (Please enter text to generate)"
    logger.info(f"Generating text for prompt: '{prompt[:50]}...' (length: {len(prompt)})")
    logger.info(f"Parameters: max_length={max_length}, temp={temperature}, top_p={top_p}, top_k={top_k}, beams={num_beams}, rep_penalty={repetition_penalty}")
    start_time = time.time()
    # Start progress
    progress(0, desc="تجهيز النموذج (Preparing model)")
    # Tokenize input (returns input_ids + attention_mask, both moved to the model device)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    progress(0.1, desc="تحليل النص (Tokenizing)")
    # Generate text with optimized parameters for speed
    progress(0.2, desc="توليد النص (Generating text)")
    # Reduce beam search or use greedy/sampled decoding for large beam requests.
    effective_beams = 1 if num_beams > 4 else num_beams
    # inference_mode: generation needs no autograd state; saves memory and time.
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_length=max_length,  # NOTE: counts prompt tokens too, not just new ones
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            num_beams=effective_beams,
            top_k=top_k,
            # early_stopping only applies to beam search; passing it with a
            # single beam triggers a warning from transformers.
            early_stopping=effective_beams > 1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,  # Ensure KV cache is used
        )
    # Decode output
    progress(0.9, desc="معالجة النتائج (Processing results)")
    result = tokenizer.decode(output[0], skip_special_tokens=True)
    # Update stats
    generation_time = time.time() - start_time
    token_count = len(output[0])  # full sequence length, prompt included
    # Reuse the CommitScheduler lock to serialize counter updates across requests.
    with scheduler.lock:
        usage_stats["total_generations"] += 1
        usage_stats["total_tokens_generated"] += token_count
    logger.info(f"Generated {token_count} tokens in {generation_time:.2f}s")
    logger.info(f"Result: '{result[:50]}...' (length: {len(result)})")
    # Save feedback with additional metadata
    save_feedback(
        prompt,
        result,
        {
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
            "generation_time": generation_time,
            "token_count": token_count,
            "timestamp": datetime.now().isoformat()
        }
    )
    progress(1.0, desc="اكتمل (Complete)")
    return result, f"تم توليد {token_count} رمز في {generation_time:.2f} ثانية (Generated {token_count} tokens in {generation_time:.2f} seconds)"
def save_feedback(input, output, params) -> None:
    """
    Append one generation record (prompt, completion, parameters) as a JSON
    Lines entry, holding the CommitScheduler lock so concurrent Gradio
    requests cannot interleave writes to the shared file.

    NOTE: the first parameter shadows the builtin `input`; the name is kept
    unchanged so existing keyword/positional callers are unaffected.
    """
    logger.info(f"Saving feedback to {feedback_file}")
    with scheduler.lock:
        try:
            # utf-8 + ensure_ascii=False stores the Arabic text readably and
            # compactly instead of \uXXXX escapes; json.loads reads both forms,
            # so downstream consumers are unaffected.
            with feedback_file.open("a", encoding="utf-8") as f:
                f.write(json.dumps({
                    "input": input,
                    "output": output,
                    "params": params
                }, ensure_ascii=False))
                f.write("\n")
            logger.info("Feedback saved successfully")
        except Exception as e:
            logger.error(f"Error saving feedback: {str(e)}")
def get_stats():
    """Snapshot the in-memory usage counters plus uptime for the stats panel."""
    with scheduler.lock:
        elapsed_hours = (time.time() - usage_stats["start_time"]) / 3600
        total = usage_stats["total_generations"]
        hourly_rate = f"{total / elapsed_hours:.1f}" if elapsed_hours > 0 else "N/A"
        snapshot = {
            "Total generations": total,
            "Total tokens generated": usage_stats["total_tokens_generated"],
            "Uptime": f"{int(elapsed_hours)}h {int((elapsed_hours % 1) * 60)}m",
            "Generations per hour": hourly_rate,
            "Last updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
    logger.info(f"Stats requested: {snapshot}")
    return snapshot
def reset_params():
    """Restore the generation sliders to their default values.

    Order matches the UI outputs: max_length, temperature, top_p, top_k,
    num_beams, repetition_penalty (fast-generation defaults).
    """
    logger.info("Parameters reset to defaults")
    defaults = (128, 0.7, 0.9, 50, 1, 1.2)
    return defaults
def thumbs_up_callback(input_text, output_text):
    """Append a positive rating for the prompt/completion pair to a JSONL file."""
    logger.info("Received positive feedback")
    feedback_path = Path("user_submit") / "positive_feedback.jsonl"
    feedback_path.parent.mkdir(exist_ok=True, parents=True)
    # Serialize writes across concurrent users via the CommitScheduler lock.
    with scheduler.lock:
        try:
            record = {
                "input": input_text,
                "output": output_text,
                "rating": "positive",
                "timestamp": datetime.now().isoformat(),
            }
            with feedback_path.open("a") as fh:
                fh.write(json.dumps(record) + "\n")
            logger.info(f"Positive feedback saved to {feedback_path}")
        except Exception as e:
            logger.error(f"Error saving positive feedback: {str(e)}")
    return "شكرا على التقييم الإيجابي!"
def thumbs_down_callback(input_text, output_text, feedback=""):
    """Append a negative rating (with optional free-text reason) to a JSONL file."""
    logger.info(f"Received negative feedback: '{feedback}'")
    feedback_path = Path("user_submit") / "negative_feedback.jsonl"
    feedback_path.parent.mkdir(exist_ok=True, parents=True)
    # Serialize writes across concurrent users via the CommitScheduler lock.
    with scheduler.lock:
        try:
            record = {
                "input": input_text,
                "output": output_text,
                "rating": "negative",
                "feedback": feedback,
                "timestamp": datetime.now().isoformat(),
            }
            with feedback_path.open("a") as fh:
                fh.write(json.dumps(record) + "\n")
            logger.info(f"Negative feedback saved to {feedback_path}")
        except Exception as e:
            logger.error(f"Error saving negative feedback: {str(e)}")
    return "شكرا على ملاحظاتك!"
if __name__ == "__main__":
    logger.info("Starting Moroccan Darija LLM application")
    # Create the Gradio interface
    # CSS hides the Gradio footer and styles the centered header block.
    with gr.Blocks(css="""
    footer {visibility: hidden}
    .center-text {text-align: center; margin: 0 auto; max-width: 900px;}
    .header-text {font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem;}
    .subheader-text {font-size: 1.2rem; margin-bottom: 2rem;}
    .flag-emoji {font-size: 3rem;}
    """) as app:
        # Header + disclaimer: this is a base (continuation) model, not a chat model.
        with gr.Row(elem_classes=["center-text"]):
            gr.Markdown("""
            # 🇲🇦🇲🇦🇲🇦
            # Al-Atlas-0.5B-base
            This is a pretrained model to do text generation in a continuation of text fashion. Do not expect it to behave as a Chat (Instruct) model. The latter is coming soon!
            """)
        with gr.Row():
            # Left column: prompt entry, action buttons, and generation parameters.
            with gr.Column(scale=6):
                prompt_input = gr.Textbox(
                    label="Prompt ",
                    placeholder="اكتب هنا...",
                    lines=4, rtl=True
                )
                with gr.Row():
                    submit_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear")
                    reset_btn = gr.Button("Reset Parameters")
                with gr.Accordion("Generation Parameters", open=False):
                    with gr.Row():
                        with gr.Column():
                            max_length = gr.Slider(8, 4096, value=128, label="Max Length")  # Reduced default
                            temperature = gr.Slider(0.0, 2, value=0.7, label="Temperature")
                            top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
                        with gr.Column():
                            top_k = gr.Slider(1, 10000, value=50, label="Top-k")  # Reduced default
                            num_beams = gr.Slider(1, 20, value=1, label="Number of Beams")  # Reduced default
                            repetition_penalty = gr.Slider(0.0, 100.0, value=1.2, label="Repetition Penalty")  # Reduced default
            # Right column: output, rating buttons, feedback form, usage stats.
            with gr.Column(scale=6):
                output_text = gr.Textbox(label="Generated Text", lines=10, rtl=True)
                generation_info = gr.Markdown("")
                with gr.Row():
                    thumbs_up = gr.Button("👍 ناضي")
                    thumbs_down = gr.Button("👎 عيان")
                # NOTE(review): this accordion starts hidden and no handler below
                # ever toggles its visibility, so the free-text feedback form
                # appears unreachable from the UI — confirm intent.
                with gr.Accordion("Feedback", open=False, visible=False) as feedback_accordion:
                    feedback_text = gr.Textbox(label="Why didn't you like the output?", lines=2, rtl=True)
                    submit_feedback = gr.Button("Submit Feedback")
                feedback_result = gr.Markdown("")
                with gr.Accordion("Usage Statistics", open=False):
                    # Auto-refreshing stats view (re-polls get_stats every 10 s).
                    stats_md = gr.JSON(get_stats, every=10)
                    refresh_stats = gr.Button("Refresh")
        # Examples section with caching
        # cache_examples=True pre-runs generate_text on every example at startup.
        gr.Examples(
            examples=examples,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info],
            fn=generate_text,
            cache_examples=True
        )
        # Button actions
        submit_btn.click(
            generate_text,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info]
        )
        clear_btn.click(
            lambda: ("", ""),
            inputs=None,
            outputs=[prompt_input, output_text]
        )
        reset_btn.click(
            reset_params,
            inputs=None,
            outputs=[max_length, temperature, top_p, top_k, num_beams, repetition_penalty]
        )
        # Feedback system
        thumbs_up.click(
            thumbs_up_callback,
            inputs=[prompt_input, output_text],
            outputs=[feedback_result]
        )
        # Thumbs-down records the rating immediately with an empty reason
        # (the detailed-reason form is in the hidden accordion above).
        thumbs_down.click(
            thumbs_down_callback,
            inputs=[prompt_input, output_text],
            outputs=[feedback_result]
        )
        submit_feedback.click(
            thumbs_down_callback,
            inputs=[prompt_input, output_text, feedback_text],
            outputs=[feedback_result]
        )
        # Stats refresh
        refresh_stats.click(
            get_stats,
            inputs=None,
            outputs=[stats_md]
        )
        # Keyboard shortcuts
        # Pressing Enter in the prompt box triggers the same generation handler.
        prompt_input.submit(
            generate_text,
            inputs=[prompt_input, max_length, temperature, top_p, top_k, num_beams, repetition_penalty],
            outputs=[output_text, generation_info]
        )
    logger.info("Launching Gradio interface")
    app.launch()
    logger.info("Gradio interface closed")