ALVHB95 committed on
Commit
9aab6a2
·
verified ·
1 Parent(s): 44b0493

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -18
app.py CHANGED
@@ -2,6 +2,40 @@ import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  # Cell 1: Image Classification Model
6
  image_pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
7
 
@@ -17,26 +51,77 @@ image_gradio_app = gr.Interface(
17
  )
18
 
19
  # Cell 2: Chatbot Model
20
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
21
- chatbot_model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
22
-
23
- def predict_chatbot(input, history=[]):
24
- new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
25
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
26
- history = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
27
- response = tokenizer.decode(history[0]).split("")
28
-
29
- response_tuples = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)]
30
- return response_tuples, history
31
-
32
- chatbot_gradio_app = gr.Interface(
33
- fn=predict_chatbot,
34
- inputs=gr.Textbox(show_label=False, placeholder="Enter text and press enter"),
35
- outputs=gr.Textbox(),
36
- live=True,
37
- title="Chatbot",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
 
40
 
41
  # Combine both interfaces into a single app
42
  gr.TabbedInterface(
 
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+ import os
6
+ import sys
7
+ sys.path.append('../..')
8
+
9
+ import panel as pn # GUI
10
+ pn.extension()
11
+
12
+ from dotenv import load_dotenv, find_dotenv
13
+ _ = load_dotenv(find_dotenv()) # read local .env file
14
+
15
+ #langchain
16
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
17
+ from langchain.embeddings import HuggingFaceEmbeddings
18
+ from langchain.prompts import PromptTemplate
19
+ from langchain.chains import RetrievalQA
20
+ from langchain.prompts import ChatPromptTemplate
21
+ from langchain.schema import StrOutputParser
22
+ from langchain.schema.runnable import Runnable
23
+ from langchain.schema.runnable.config import RunnableConfig
24
+ from langchain.chains import (
25
+ LLMChain, ConversationalRetrievalChain)
26
+ from langchain.vectorstores import Chroma
27
+ from langchain.memory import ConversationBufferMemory
28
+ from langchain.chains import LLMChain
29
+ from langchain.prompts.prompt import PromptTemplate
30
+ from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
31
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
32
+ from langchain.document_loaders import PyPDFDirectoryLoader
33
+
34
+ from langchain_community.llms import HuggingFaceHub
35
+
36
+ from pydantic import BaseModel
37
+ import shutil
38
+
39
  # Cell 1: Image Classification Model
40
  image_pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
41
 
 
51
  )
52
 
53
  # Cell 2: Chatbot Model
54
+
55
# Load the PDF knowledge base. The directory is assembled with os.path.join so
# the separator matches the host OS; the previous hard-coded backslash path
# only resolved on Windows.
# NOTE(review): assumes the folder lives relative to the working directory —
# confirm against the deployed repository layout.
loader = PyPDFDirectoryLoader(os.path.join("TFM_DataScience_APP", "pdfs"))
data = loader.load()

# Split the loaded documents into overlapping chunks before embedding.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
docs = text_splitter.split_documents(data)

# Embedding model used to index the chunks.
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-l6-v2')

# The vector database is rebuilt from scratch on every start: any previously
# persisted files are deleted first (a missing directory is ignored).
persist_directory = 'docs/chroma/'
shutil.rmtree(persist_directory, ignore_errors=True)
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    persist_directory=persist_directory,
)

# Retriever over the vector store, using the "mmr" search type.
retriever = vectordb.as_retriever(search_type="mmr")
74
# MessagesPlaceholder is used below but was missing from the file's imports.
from langchain.prompts import MessagesPlaceholder

# System instructions for the assistant. The placeholders {context},
# {chat_history} and {question} are filled in when the prompt is formatted.
template = """Your name is AngryGreta and you are a recycling chatbot created to help people. Use the following pieces of context to answer the question at the end. Answer in the same language of the question. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer.
CONTEXT: {context}
CHAT HISTORY:
{chat_history}
Question: {question}
Helpful Answer:"""

# Create the chat prompt templates.
# The system prompt is built from `template`; the earlier code referenced an
# undefined name (`prompt_template`) and left the message list and the
# ChatPromptTemplate call unclosed.
system_prompt = SystemMessagePromptTemplate.from_template(template)
qa_prompt = ChatPromptTemplate(
    messages=[
        system_prompt,
        MessagesPlaceholder(variable_name="chat_history"),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
)
88
+
89
+
90
# Hosted text-generation model used to answer questions; the generation
# settings are forwarded to the endpoint through `model_kwargs`.
llm = HuggingFaceHub(
    task="text-generation",
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    model_kwargs=dict(
        max_new_tokens=512,
        top_k=30,
        temperature=0.1,
        repetition_penalty=1.03,
    ),
)
100
# Plain LLM chain built from the same model and QA prompt.
llm_chain = LLMChain(prompt=qa_prompt, llm=llm)

# Conversation memory: history is exposed under "chat_history" as message
# objects, and the chain output stored in it is the "answer" key.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    output_key='answer',
    return_messages=True,
    llm=llm,
)

# Retrieval-augmented conversational chain. The history is handed to the
# prompt unchanged (identity function) and the retrieved documents are
# combined using `qa_prompt`.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    combine_docs_chain_kwargs={'prompt': qa_prompt},
    get_chat_history=lambda history: history,
    verbose=True,
)
112
+
113
def qa_response(question, chat_history):
    """Answer a user message with the conversational retrieval chain.

    Args:
        question: The latest message typed by the user.
        chat_history: Conversation history supplied by the chat UI. It is
            accepted to satisfy the callback signature but is not used or
            modified here, because `qa_chain` is configured with its own
            memory object.

    Returns:
        The chain's "answer" value, or "I don't know." when the result has
        no such key.
    """
    # Send only the new question as the chain input. The earlier code passed
    # the whole history list to `run()` and then called `.get` on the
    # returned string, which fails at runtime.
    response = qa_chain.invoke({"question": question})

    # Extract and return the assistant's answer from the response
    return response.get("answer", "I don't know.")


chatbot_gradio_app = gr.ChatInterface(qa_response)
125
 
126
  # Combine both interfaces into a single app
127
  gr.TabbedInterface(