# ask_candid/tools/question_reformulation.py
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.language_models.llms import LLM
from ask_candid.agents.schema import AgentState


def reformulate_question_using_history(
    state: AgentState,
    llm: LLM,
    focus_on_recommendations: bool = False
) -> AgentState:
"""Transform the query to produce a better query with details from previous messages and emphasize aspects important
for recommendations if needed.
Parameters
----------
state : AgentState
The current state
llm : LLM
focus_on_recommendations : bool, optional
Flag to determine if the reformulation should emphasize recommendation-relevant aspects such as geographies,
cause areas, etc., by default False
Returns
-------
AgentState
The updated state
"""
print("---REFORMULATE THE USER INPUT---")
messages = state["messages"]
question = messages[-1].content
if len(messages[:-1]) > 1: # need to skip the system message
if focus_on_recommendations:
prompt_text = """Given a chat history and the latest user input which might reference context in the chat
history, especially geographic locations, cause areas and/or population groups, formulate a standalone input
which can be understood without the chat history.
Chat history: ```{chat_history}```
User input: ```{question}```
Reformulate the question without adding implications or assumptions about the user's needs or intentions.
Focus solely on clarifying any contextual details present in the original input."""
        else:
            prompt_text = """Given a chat history and the latest user input which might reference context in the chat
history, formulate a standalone input which can be understood without the chat history. Include hints as to
what the user is getting at given the context in the chat history.
Chat history: ```{chat_history}```
User input: ```{question}```
Do NOT answer the question, just reformulate it if needed and otherwise return it as is.
"""
        contextualize_q_prompt = ChatPromptTemplate([
            ("system", prompt_text),
            # Bind the question through a template variable rather than interpolating the raw
            # string, so user input containing curly braces cannot break prompt formatting.
            ("human", "{question}"),
        ])
        rag_chain = contextualize_q_prompt | llm | StrOutputParser()
        new_question = rag_chain.invoke({
            "chat_history": '\n'.join(f"{m.type.upper()}: {m.content}" for m in messages[1:]),
            "question": question
        })
print(f"user asked: '{question}', agent reformulated the question basing on the chat history: {new_question}")
return {"messages": [new_question], "user_input" : question}
return {"messages": [question], "user_input" : question}