|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import torch |
|
|
|
|
|
# Hugging Face Hub checkpoint backing the chat app.
model_name = "distilgpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 family tokenizers ship without a padding token, so tokenizing with
# padding=True raises a ValueError. Reuse EOS as the pad token (the usual
# convention for GPT-2 models) so padded tokenization works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)
# Inference only: keep dropout disabled. from_pretrained already returns the
# model in eval mode; this makes that explicit.
model.eval()
|
|
|
|
|
def chat_function(user_input, history):
    """Generate an AI reply to ``user_input`` and record the exchange.

    The conversation so far is flattened into a plain-text transcript of
    ``User:`` / ``AI:`` turns, ending with an open ``AI:`` cue. The
    module-level ``model`` then samples a continuation of that transcript.

    Args:
        user_input: The message the user just sent.
        history: List of ``(user_message, ai_reply)`` pairs from previous
            turns, or ``None`` at the start of a conversation. The list passed
            in is not modified.

    Returns:
        ``(history, history)``: the updated list of pairs, once for the
        ``gr.Chatbot`` display and once for the ``gr.State`` that stores it.
        If ``user_input`` is empty or only whitespace, the history is returned
        unchanged and nothing is generated.
    """
    # Number of tokens to sample per reply. The earlier max_length=100 counted
    # the prompt too, so once the transcript exceeded 100 tokens no new text
    # could be generated at all.
    max_new_tokens = 60
    fallback = "Hmm, I'm not sure what to say!"

    # Work on a copy so the caller's list is never modified in place.
    history = list(history) if history else []
    if not user_input or not user_input.strip():
        return history, history

    turns = [f"User: {user_msg}\nAI: {ai_msg}" for user_msg, ai_msg in history]
    # End with an open "AI:" cue so the model continues as the assistant
    # rather than extending the user's message.
    turns.append(f"User: {user_input}\nAI:")
    prompt = "\n".join(turns)

    # A single sequence needs no padding. GPT-2 tokenizers have no pad token,
    # so padding=True would raise here.
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Keep only the most recent prompt tokens so prompt + reply fit in the
    # model's context window; the oldest turns drop off first. Fall back to
    # 1024 (GPT-2's context length) if the config does not expose it.
    context_size = getattr(model.config, "max_position_embeddings", 1024)
    max_prompt_tokens = max(context_size - max_new_tokens, 1)
    input_ids = input_ids[:, -max_prompt_tokens:]
    attention_mask = attention_mask[:, -max_prompt_tokens:]

    with torch.inference_mode():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Slicing the decoded text by
    # len(prompt) is unreliable because decoding does not always reproduce
    # the prompt character-for-character.
    reply = tokenizer.decode(
        outputs[0, input_ids.shape[-1]:], skip_special_tokens=True
    )
    # The model tends to keep writing further turns; keep only its own reply.
    for marker in ("\nUser:", "\nAI:"):
        reply = reply.split(marker, 1)[0]
    reply = reply.strip() or fallback

    history.append((user_input, reply))
    return history, history
|
|
|
|
|
with gr.Blocks(title="Simple Chat App") as demo:
    gr.Markdown("# Simple AI Chat App")
    gr.Markdown("Chat with an AI powered by DistilGPT-2!")

    chatbot = gr.Chatbot(label="Conversation")
    user_input = gr.Textbox(label="Your message", placeholder="Type here...")
    # Per-session conversation store: list of (user_message, ai_reply) pairs.
    history = gr.State(value=[])
    submit_btn = gr.Button("Send")
    clear_btn = gr.Button("Clear Chat")

    # Send on button click or on Enter in the textbox. Once the reply has
    # been generated, empty the textbox so the next message starts blank.
    for send_event in (submit_btn.click, user_input.submit):
        send_event(
            fn=chat_function,
            inputs=[user_input, history],
            outputs=[chatbot, history],
        ).then(fn=lambda: "", inputs=None, outputs=user_input)

    # Reset the displayed chat, the stored history, and any half-typed text.
    clear_btn.click(
        fn=lambda: ([], [], ""),
        inputs=None,
        outputs=[chatbot, history, user_input],
    )

# Start the server only when run as a script, so importing this module
# does not block on launching the app.
if __name__ == "__main__":
    demo.launch()