import gc
import threading
from datetime import datetime

import gradio as gr
import torch
from transformers import pipeline

import spaces  # Import spaces early to enable ZeroGPU support

# ------------------------------
# Global Cancellation Event
# ------------------------------
cancel_event = threading.Event()

# ------------------------------
# Qwen3 Model Definitions
# ------------------------------
MODELS = {
    "Qwen3-8B":   {"repo_id": "Qwen/Qwen3-8B",   "description": "Qwen3-8B - Largest model with highest capabilities"},
    "Qwen3-4B":   {"repo_id": "Qwen/Qwen3-4B",   "description": "Qwen3-4B - Good balance of performance and efficiency"},
    "Qwen3-1.7B": {"repo_id": "Qwen/Qwen3-1.7B", "description": "Qwen3-1.7B - Smaller model for faster responses"},
    "Qwen3-0.6B": {"repo_id": "Qwen/Qwen3-0.6B", "description": "Qwen3-0.6B - Ultra-lightweight model"},
}

# Global cache for pipelines to avoid re-loading.
PIPELINES = {}


def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.
    Tries bfloat16, falls back to float16 or float32 if unsupported.
    """
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=repo,
                tokenizer=repo,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue
    # Final fallback: let transformers pick the default dtype.
    pipe = pipeline(
        task="text-generation",
        model=repo,
        tokenizer=repo,
        trust_remote_code=True,
        device_map="auto",
    )
    PIPELINES[model_name] = pipe
    return pipe


def format_conversation(history, system_prompt):
    """
    Flatten chat history and system prompt into a single prompt string.
    """
    prompt = system_prompt.strip() + "\n"
    for user_msg, assistant_msg in history:
        prompt += "User: " + user_msg.strip() + "\n"
        if assistant_msg:  # might be None or empty
            prompt += "Assistant: " + assistant_msg.strip() + "\n"
    prompt += "Assistant: "
    return prompt


@spaces.GPU  # Request a ZeroGPU slot for the duration of the generation call
def generate_response(user_input, history, system_prompt, model_name,
                      max_tokens, temperature, top_k, top_p, repeat_penalty):
    """
    Generate a complete response (non-streaming).
    """
    cancel_event.clear()
    # Include the latest user message in the prompt; format_conversation
    # skips the (still empty) assistant slot of the appended pair.
    conversation = format_conversation(history + [(user_input, None)], system_prompt)
    try:
        pipe = load_pipeline(model_name)
        output = pipe(
            conversation,
            max_new_tokens=max_tokens,
            do_sample=True,  # required so temperature/top_k/top_p take effect
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repeat_penalty,
            return_full_text=False,
        )[0]["generated_text"]
        # Return the updated history
        history.append((user_input, output))
        return history
    except Exception as e:
        history.append((user_input, f"Error: {e}"))
        return history
    finally:
        gc.collect()


def cancel_generation():
    cancel_event.set()
    return "Generation cancelled."


def get_default_system_prompt():
    today = datetime.now().strftime("%Y-%m-%d")
    return f"""You are Qwen3, a helpful and friendly AI assistant created by Alibaba Cloud. Today is {today}.
Be concise, accurate, and helpful in your responses."""


# CSS for improved visual style
css = """
.gradio-container {
    background-color: #f5f7fb !important;
}
.qwen-header {
    background: linear-gradient(90deg, #0099FF, #0066CC);
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
    text-align: center;
    color: white;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.qwen-container {
    border-radius: 10px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
    background: white;
    padding: 20px;
    margin-bottom: 20px;
}
.controls-container {
    background: #f0f4fa;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 15px;
}
.model-select {
    border: 2px solid #0099FF !important;
    border-radius: 8px !important;
}
.button-primary {
    background-color: #0099FF !important;
    color: white !important;
}
.button-secondary {
    background-color: #6c757d !important;
    color: white !important;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.8em;
    color: #666;
}
"""


# Extract just the model name from a "Name - description" dropdown selection.
def get_model_name(full_selection):
    return full_selection.split(" - ")[0]


# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(title="Qwen3 Chat", css=css) as demo:
    gr.HTML("""
        <div class="qwen-header">
            <h1>🤖 Qwen3 Chat</h1>
            <p>Interact with Alibaba Cloud's Qwen3 language models</p>
        </div>
""") with gr.Row(): with gr.Column(scale=3): with gr.Group(elem_classes="qwen-container"): model_dd = gr.Dropdown( label="Select Qwen3 Model", choices=[f"{k} - {v['description']}" for k, v in MODELS.items()], value=f"{list(MODELS.keys())[0]} - {MODELS[list(MODELS.keys())[0]]['description']}", elem_classes="model-select" ) with gr.Group(elem_classes="controls-container"): gr.Markdown("### ⚙️ Generation Parameters") sys_prompt = gr.Textbox(label="System Prompt", lines=5, value=get_default_system_prompt()) with gr.Row(): max_tok = gr.Slider(64, 1024, value=512, step=32, label="Max Tokens") with gr.Row(): temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature") p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P") with gr.Row(): k = gr.Slider(1, 100, value=40, step=1, label="Top-K") rp = gr.Slider(1.0, 2.0, value=1.1, step=0.1, label="Repetition Penalty") clear_btn = gr.Button("Clear Chat", elem_classes="button-secondary") with gr.Column(scale=7): chatbot = gr.Chatbot() with gr.Row(): txt = gr.Textbox( show_label=False, placeholder="Type your message here...", lines=2 ) submit_btn = gr.Button("Send", variant="primary", elem_classes="button-primary") gr.HTML(""" """) # Define event handlers def user_input(user_message, history): return "", history + [(user_message, None)] def bot_response(history, sys_prompt, model, max_tok, temp, k, p, rp): user_message = history[-1][0] bot_message = generate_response( user_message, history[:-1], sys_prompt, get_model_name(model), max_tok, temp, k, p, rp )[-1][1] history[-1] = (user_message, bot_message) return history # Connect everything submit_btn.click( user_input, [txt, chatbot], [txt, chatbot], queue=False ).then( bot_response, [chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp], [chatbot] ) txt.submit( user_input, [txt, chatbot], [txt, chatbot], queue=False ).then( bot_response, [chatbot, sys_prompt, model_dd, max_tok, temp, k, p, rp], [chatbot] ) clear_btn.click(lambda: None, None, chatbot, queue=False) if __name__ == "__main__": demo.launch()