| import gradio as gr |
| from transformers import BartTokenizer, BartForConditionalGeneration, AutoModelForSeq2SeqLM, AutoTokenizer |
| import torch |
| from langchain.memory import ConversationBufferMemory |
|
|
| |
# Pick the compute device once at startup: GPU when CUDA is present, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the BART-base tokenizer and seq2seq model, placing the weights on `device`.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(device)
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
# Module-level conversation buffer: accumulates every (input, output) turn so the
# full transcript can be re-injected into each prompt. Shared across all requests.
memory = ConversationBufferMemory()
|
|
| |
def chat_with_bart(input_text):
    """Generate one assistant reply from BART, conditioned on the chat history.

    Args:
        input_text: The user's latest message.

    Returns:
        The decoded model response; the (input, response) pair is also appended
        to the module-level ``memory`` buffer as a side effect.
    """
    # Pull the full transcript accumulated so far out of the LangChain buffer.
    conversation_history = memory.load_memory_variables({})['history']

    # Prepend history so the model sees the whole conversation each turn.
    full_input = f"{conversation_history}\nUser: {input_text}\nAssistant:"

    # Tokenize and move tensors to the model's device; truncate to BART's limit.
    inputs = tokenizer(full_input, return_tensors="pt", max_length=1024, truncation=True).to(device)

    # Inference only — no_grad avoids autograd bookkeeping and saves memory.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            # Fix: forward the attention mask instead of dropping it, so the
            # model does not have to guess padding from the pad-token id.
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=4,
            early_stopping=True,
            no_repeat_ngram_size=3,
            repetition_penalty=1.2,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Persist this turn so the next call's history includes it.
    memory.save_context({"input": input_text}, {"output": response})

    return response
|
|
| |
# Build the two UI components up front, then wire them into a minimal Gradio app.
user_box = gr.Textbox(label="Chat with BART Base")
reply_box = gr.Textbox(label="BART Base's Response")

interface = gr.Interface(
    fn=chat_with_bart,
    inputs=user_box,
    outputs=reply_box,
    title="BART Base Chatbot with Memory",
    description="This is a simple chatbot powered by the BART Base model with conversational memory, using LangChain.",
)

# Start the local web server (blocks until interrupted).
interface.launch()
|
|
|
|
|
|
|
|