import gradio as gr
import spaces
from transformers import pipeline, AutoTokenizer
import torch
from typing import List, Dict, Optional

# Global caches for loaded pipelines and tokenizers
model_cache = {}
tokenizer_cache = {}

# Available models
AVAILABLE_MODELS = {
    "Daedalus-1-2B": "NoemaResearch/Daedalus-1-2B",
    "Daedalus-1-8B": "NoemaResearch/Daedalus-1-8B",
}

# Models that need special token handling for repetition issues
MODELS_NEEDING_SPECIAL_HANDLING = {"Daedalus-1-8B"}
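# Note: exposing another checkpoint should only require a new entry above, e.g. the
# hypothetical
#   "Daedalus-1-XB": "NoemaResearch/Daedalus-1-XB",
# plus adding its display name to MODELS_NEEDING_SPECIAL_HANDLING if it also loops.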

def initialize_model(model_name):
    global model_cache, tokenizer_cache

    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Model {model_name} not found in available models")

    model_id = AVAILABLE_MODELS[model_name]

    # Check if the model is already cached
    if model_id not in model_cache:
        try:
            # Load the tokenizer separately to handle the chat template properly
            tokenizer_cache[model_id] = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )
            model_cache[model_id] = pipeline(
                "text-generation",
                model=model_id,
                tokenizer=tokenizer_cache[model_id],
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
        except Exception:
            # Fall back to CPU if GPU loading fails
            tokenizer_cache[model_id] = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )
            model_cache[model_id] = pipeline(
                "text-generation",
                model=model_id,
                tokenizer=tokenizer_cache[model_id],
                torch_dtype=torch.float32,
                device_map="cpu",
                trust_remote_code=True
            )

    return model_cache[model_id], tokenizer_cache[model_id]
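
# Usage note (sketch): thanks to the module-level caches above, only the first call
# per model pays the download/load cost; later calls return the cached objects:
#
#   pipe, tok = initialize_model("Daedalus-1-2B")  # loads and caches
#   pipe, tok = initialize_model("Daedalus-1-2B")  # served from model_cache / tokenizer_cache
#
# The example is kept in comments so importing this file never triggers a download.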

def format_conversation_with_template(messages: List[Dict], tokenizer) -> str:
    """Apply the tokenizer's chat template, falling back to manual formatting."""
    # Prefer the tokenizer's built-in chat template when one is defined
    if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template:
        try:
            formatted = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            return formatted
        except Exception as e:
            print(f"Chat template application failed: {e}")
            # Fall through to manual formatting below

    # Manual fallback formatting using the model's special tokens
    bos_token = "<[begin▁of▁sentence]>"
    eos_token = "<[end▁of▁sentence]>"

    # Start with the system message
    formatted = f"{bos_token}system\nYou are an AI Coding model called Daedalus, developed by Noema Research{eos_token}"

    # Add each message in the conversation
    for msg in messages:
        role = msg.get('role', 'user')
        content = msg.get('content', '').strip()
        formatted += f"{bos_token}{role}\n{content}{eos_token}"

    # Add the generation prompt
    formatted += f"{bos_token}assistant\n"
    return formatted
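
# Example (sketch): with no usable chat template, the manual fallback above turns
#
#   [{"role": "user", "content": "Hello"}]
#
# into one unbroken string of the form (split across lines here for readability):
#
#   <[begin▁of▁sentence]>system\nYou are an AI Coding model called Daedalus, developed by Noema Research<[end▁of▁sentence]>
#   <[begin▁of▁sentence]>user\nHello<[end▁of▁sentence]>
#   <[begin▁of▁sentence]>assistant\n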

@spaces.GPU  # ZeroGPU: request a GPU for the duration of each generation call
def generate_response(message, history, model_name, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response using the selected model."""
    try:
        model_pipe, tokenizer = initialize_model(model_name)
    except Exception as e:
        return f"Error loading model {model_name}: {str(e)}"

    # Rebuild the conversation history as chat messages
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    try:
        # Format the conversation using the chat template
        formatted_prompt = format_conversation_with_template(messages, tokenizer)

        # Different generation parameters depending on the model
        if model_name in MODELS_NEEDING_SPECIAL_HANDLING:
            # The 8B model needs special token handling to prevent repetition
            stop_tokens = [
                "<[end▁of▁sentence]>",    # EOS token
                "<[begin▁of▁sentence]>",  # BOS token (shouldn't appear mid-generation)
                "user\n",                 # Stop if the model tries to continue the conversation
                "system\n",               # Stop if the model tries to add system messages
                "\nuser",                 # Alternative format
                "\nsystem"                # Alternative format
            ]
            response = model_pipe(
                formatted_prompt,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=1,  # PAD token ID from the model config
                eos_token_id=2,  # EOS token ID from the model config
                bos_token_id=0,  # BOS token ID from the model config
                return_full_text=False,
                repetition_penalty=1.1,       # Discourage loops
                stop_sequence=stop_tokens[0]  # Primary stop token
            )
        else:
            # 2B model - standard generation without special handling
            response = model_pipe(
                formatted_prompt,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                return_full_text=False,
                repetition_penalty=1.05  # Light repetition penalty
            )

        if isinstance(response, list) and len(response) > 0:
            generated_text = response[0]['generated_text']
        else:
            generated_text = str(response)

        # Clean up the response
        assistant_response = str(generated_text).strip()

        # Apply different cleanup depending on the model
        if model_name in MODELS_NEEDING_SPECIAL_HANDLING:
            # More aggressive cleanup for the 8B model
            stop_tokens = [
                "<[end▁of▁sentence]>", "<[begin▁of▁sentence]>",
                "user\n", "system\n", "\nuser", "\nsystem"
            ]
            for stop_token in stop_tokens:
                if stop_token in assistant_response:
                    assistant_response = assistant_response.split(stop_token)[0].strip()

            # Additional cleanup for common repetition patterns
            lines = assistant_response.split('\n')
            cleaned_lines = []
            for line in lines:
                if line.strip() and not line.strip().startswith(('user', 'assistant', 'system')):
                    cleaned_lines.append(line)
            assistant_response = '\n'.join(cleaned_lines).strip()
        else:
            # Standard cleanup for the 2B model
            if assistant_response.startswith("assistant\n"):
                assistant_response = assistant_response[len("assistant\n"):].strip()

        return assistant_response if assistant_response else "I apologize, but I couldn't generate a proper response. Please try again."

    except Exception as e:
        return f"Error generating response: {str(e)}"
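
# Example (sketch): generate_response can also be exercised outside the UI, e.g.
#
#   reply = generate_response(
#       "Write a Python function that reverses a string.",
#       history=[],                    # no prior turns
#       model_name="Daedalus-1-2B",
#       max_length=256,
#   )
#
# Kept in comments because a real call downloads and loads model weights.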

def create_interface():
    with gr.Blocks(title="Daedalus Chat", theme=gr.themes.Base(primary_hue="green")) as demo:
        gr.Markdown("""
        # 🟢 Daedalus Chat Interface

        Chat with **Daedalus models** by Noema Research.
        """)

        # Model selection dropdown
        model_dropdown = gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            value="Daedalus-1-2B",  # Default to the 2B model
            label="Select Model",
            info="Choose between Daedalus-1-2B (faster) or Daedalus-1-8B (more capable)"
        )

        chatbot = gr.Chatbot(
            height=400,
            placeholder="Start chatting with Daedalus...",
            label="Chat"
        )

        msg = gr.Textbox(
            placeholder="Type your message here...",
            label="Message",
            lines=2
        )

        with gr.Row():
            submit_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear Chat", variant="secondary")

        with gr.Accordion("Advanced Settings", open=False):
            max_length = gr.Slider(
                minimum=200,
                maximum=4096,  # Reduced from 8192 to prevent memory issues
                value=1024,    # Reduced default from 2048
                step=50,
                label="Max New Tokens",
                info="Maximum number of new tokens to generate"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness in generation"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.1,
                label="Top P",
                info="Controls diversity via nucleus sampling"
            )

        def user_message(message, history):
            return "", history + [[message, None]]

        def bot_response(history, selected_model, max_len, temp, top_p):
            if history:
                last_user_msg = history[-1][0]
                bot_message = generate_response(
                    last_user_msg,
                    history[:-1],
                    selected_model,  # Use the model chosen in the dropdown
                    max_len,
                    temp,
                    top_p
                )
                history[-1][1] = bot_message
            return history

        msg.submit(user_message, [msg, chatbot], [msg, chatbot]).then(
            bot_response, [chatbot, model_dropdown, max_length, temperature, top_p], chatbot
        )
        submit_btn.click(user_message, [msg, chatbot], [msg, chatbot]).then(
            bot_response, [chatbot, model_dropdown, max_length, temperature, top_p], chatbot
        )
        clear_btn.click(lambda: None, None, chatbot, queue=False)

        gr.Markdown("""
        ---
        ### About Daedalus Models

        **Daedalus-1-2B:** Faster, lightweight model for quick responses and basic coding tasks.

        **Daedalus-1-8B:** More capable model with advanced reasoning, fine-tuned for structured outputs,
        debugging, and long-context reasoning (up to ~64K tokens).

        Both models are optimized for:
        - Conversational AI
        - Code generation & debugging
        - Structured JSON/function outputs
        - Multi-step reasoning
        """)

    return demo
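
# Note (assumption): ZeroGPU Spaces commonly enable request queuing before launch,
# e.g. demo.queue() ahead of demo.launch(), so concurrent chat requests are
# serialized while a GPU is attached. The app below launches directly.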

# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)
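
# Assumed runtime dependencies for this Space (versions not pinned here):
# gradio, spaces, transformers, torch, and accelerate (required for the
# device_map="auto" pipeline loading above).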