import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer import os # --- Konfiguration --- MODEL_ID = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B" HF_TOKEN = os.getenv("HF_TOKEN") # Optional: Für private Modelle oder Zugriffsbeschränkungen # --- Lade Modell und Tokenizer (explizit auf CPU) --- print(f"Lade Tokenizer: {MODEL_ID}") # Stelle sicher, dass trust_remote_code=True gesetzt ist, da Qwen3 dies oft benötigt tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN) if tokenizer.pad_token is None: print("pad_token nicht gesetzt, verwende eos_token als pad_token.") tokenizer.pad_token = tokenizer.eos_token print(f"Lade Modell: {MODEL_ID} auf CPU. Dies kann einige Zeit dauern...") try: model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, device_map="cpu", trust_remote_code=True, token=HF_TOKEN ) except Exception as e: print(f"Fehler beim Laden mit bfloat16 ({e}), versuche float32...") model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True, token=HF_TOKEN ) model.eval() print("Modell und Tokenizer erfolgreich geladen.") # --- Vorhersagefunktion für das ChatInterface --- def predict(message, history): messages_for_template = [] for user_msg, ai_msg in history: # history ist jetzt eine Liste von Listen/Tupeln messages_for_template.append({"role": "user", "content": user_msg}) messages_for_template.append({"role": "assistant", "content": ai_msg}) messages_for_template.append({"role": "user", "content": message}) try: prompt = tokenizer.apply_chat_template( messages_for_template, tokenize=False, add_generation_prompt=True ) except Exception as e: print(f"Fehler beim Anwenden des Chat-Templates: {e}") prompt_parts = [] for turn in messages_for_template: prompt_parts.append(f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>") prompt = "\n".join(prompt_parts) + "\n<|im_start|>assistant\n" inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to("cpu") generation_kwargs = { "max_new_tokens": 512, "temperature": 0.7, "top_p": 0.9, "top_k": 50, "do_sample": True, "pad_token_id": tokenizer.eos_token_id, } print("Generiere Antwort...") with torch.no_grad(): outputs = model.generate(**inputs, **generation_kwargs) response_ids = outputs[0][inputs.input_ids.shape[-1]:] response = tokenizer.decode(response_ids, skip_special_tokens=True) print(f"Antwort: {response}") return response # --- Gradio UI --- with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek Qwen3 8B (CPU)") as demo: gr.Markdown( """ # DeepSeek Qwen3 8B Chat (CPU) Dies ist eine Demo des `deepseek-ai/DeepSeek-R1-0528-Qwen3-8B` Modells, das auf einer CPU läuft. **Achtung:** Antworten können aufgrund der CPU-Inferenz **sehr langsam** sein (mehrere Minuten pro Antwort sind möglich). Bitte habe Geduld. """ ) chatbot_interface = gr.ChatInterface( fn=predict, chatbot=gr.Chatbot( height=600, label="Chat", show_label=False, # bubble_full_width=False, # Entfernt, da veraltet # type="messages" # Wichtig, um die Warnung zu beheben, aber history-Format in predict() muss passen # Da predict bereits die history als [[user, ai], [user, ai]] erwartet (Standard für ChatInterface), # lassen wir type hier weg, damit es mit dem Format von predict harmoniert. # Wenn predict `history` als [{"role": "user", ...}, {"role": "assistant", ...}] erwarten würde, # dann wäre `type="messages"` hier richtig. # Da die Warnung sich auf die Standardeinstellung bezieht, die bald "messages" sein wird, # und unsere predict-Funktion bereits das "tuples"-Format verarbeitet, ist das OK für jetzt. # Man könnte predict anpassen, um das "messages" Format direkt zu verarbeiten, wenn man type="messages" setzt. ), textbox=gr.Textbox( placeholder="Stelle mir eine Frage...", container=False, scale=7 ), examples=[ ["Hallo, wer bist du?"], ["Was ist die Hauptstadt von Frankreich?"], ["Schreibe ein kurzes Gedicht über KI."] ], # Entferne die nicht unterstützten Button-Argumente: # retry_btn="Wiederholen", # undo_btn="Letzte entfernen", # clear_btn="Chat löschen", ) gr.Markdown("Modell von [deepseek-ai](https://huggingface.co/deepseek-ai) auf Hugging Face.") if __name__ == "__main__": demo.launch()