ALVHB95 commited on
Commit
d3abd97
·
1 Parent(s): 41a18cc
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -31,7 +31,7 @@ from pydantic import BaseModel
31
  import shutil
32
 
33
  # Cell 1: Image Classification Model
34
- image_pipeline = pipeline(task="image-classification", model="microsoft/resnet-50")
35
 
36
  def predict_image(input_img):
37
  predictions = image_pipeline(input_img)
@@ -68,10 +68,11 @@ vectordb = Chroma.from_documents(
68
  persist_directory=persist_directory
69
  )
70
  # define retriever
71
- retriever = vectordb.as_retriever(search_kwargs={"k": 3},search_type="mmr")
72
  prompt_template = """
73
  Your name is AngryGreta and you are a recycling chatbot with the objective to anwer questions from user in English or Spanish /
74
  Use the following pieces of context to answer the question if the question is related with recycling /
 
75
  Answer in the same language of the question /
76
  Always say "thanks for asking!" at the end of the answer /
77
  If the context is not relevant, please answer the question by using your own knowledge about the topic.
@@ -101,7 +102,7 @@ llm = HuggingFaceHub(
101
  },
102
  )
103
 
104
- memory = ConversationBufferMemory(llm=llm, memory_key="chat_history", input_key='question', output_key='answer')
105
 
106
  qa_chain = ConversationalRetrievalChain.from_llm(
107
  llm = llm,
@@ -110,17 +111,14 @@ qa_chain = ConversationalRetrievalChain.from_llm(
110
  verbose = True,
111
  combine_docs_chain_kwargs={'prompt': qa_prompt},
112
  get_chat_history = lambda h : h,
113
- rephrase_question = False
 
114
  )
115
 
116
- def chat_interface(question):
117
 
118
  result = qa_chain.invoke({"question": question})
119
-
120
- # Extract only the answer from the result
121
- answer = result.get('answer', None)
122
-
123
- return answer
124
 
125
  chatbot_gradio_app = gr.ChatInterface(
126
  fn=chat_interface,
 
31
  import shutil
32
 
33
  # Cell 1: Image Classification Model
34
+ image_pipeline = pipeline(task="image-classification", model="rocioadlc/TrashNet_ResNet152V2")
35
 
36
  def predict_image(input_img):
37
  predictions = image_pipeline(input_img)
 
68
  persist_directory=persist_directory
69
  )
70
  # define retriever
71
+ retriever = vectordb.as_retriever(search_type="mmr")
72
  prompt_template = """
73
  Your name is AngryGreta and you are a recycling chatbot with the objective to anwer questions from user in English or Spanish /
74
  Use the following pieces of context to answer the question if the question is related with recycling /
75
+ No more than two chunks of context /
76
  Answer in the same language of the question /
77
  Always say "thanks for asking!" at the end of the answer /
78
  If the context is not relevant, please answer the question by using your own knowledge about the topic.
 
102
  },
103
  )
104
 
105
+ memory = ConversationBufferMemory(llm=llm, memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
106
 
107
  qa_chain = ConversationalRetrievalChain.from_llm(
108
  llm = llm,
 
111
  verbose = True,
112
  combine_docs_chain_kwargs={'prompt': qa_prompt},
113
  get_chat_history = lambda h : h,
114
+ rephrase_question = False,
115
+ output_key = 'answer'
116
  )
117
 
118
+ def chat_interface(question,history):
119
 
120
  result = qa_chain.invoke({"question": question})
121
+ return result['answer']
 
 
 
 
122
 
123
  chatbot_gradio_app = gr.ChatInterface(
124
  fn=chat_interface,