import os

from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langgraph.graph import END

from .prompt import SYSTEM_PROMPT, CONTEXT_PROMPT, QUESTION_PROMPT

load_dotenv()
# The chat model name is supplied via the environment (e.g. in a .env file).
MODEL = os.getenv("MODEL")


@tool
def user_query(query: str):
    """
    Call this tool to retrieve conversation context for the user's query.
    The query must be unambiguous and concise, carrying enough context from
    the message history to stand on its own.
    """
    return query


@tool
def completed(**kwargs):
    """
    Call this tool when all medical questions have been completed.
    """
    return True


tools_by_name = {
    "user_query": user_query,
    "completed": completed,
}


def medical_route(state):
    """Route to the RAG tool node when the last message requests a tool call."""
    if not state["messages"]:
        return END
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "rag_tool_node"
    return END
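

# Hedged sketch: medical_route above returns "rag_tool_node", but no such node
# is defined in this module. Assuming the graph needs a node that executes the
# pending tool calls through tools_by_name, it could look like the following;
# the node name and state shape are assumptions, not part of the original file.
def rag_tool_node(state):
    outputs = []
    for tool_call in state["messages"][-1].tool_calls:
        # Run the requested tool with the model-supplied arguments.
        result = tools_by_name[tool_call["name"]].invoke(tool_call["args"])
        outputs.append(
            ToolMessage(
                content=str(result),
                name=tool_call["name"],
                tool_call_id=tool_call["id"],
            )
        )
    return {**state, "messages": outputs}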


class MedicalQuestionAgent:
    def __init__(self, questions=None):
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", SYSTEM_PROMPT),
            ("system", QUESTION_PROMPT),
            ("system", CONTEXT_PROMPT),
            MessagesPlaceholder(variable_name="messages"),
        ])
        self.llm = ChatOpenAI(model=MODEL, temperature=0, streaming=True)
        self.chain = self.prompt | self.llm.bind_tools([user_query, completed])
        # Avoid the mutable-default pitfall: build a fresh list per instance.
        self.questions = list(questions) if questions else []

    def __call__(self, state):
        result = self.chain.invoke({**state, "questions": self.questions})
        return {**state, "messages": [result]}
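

# Hedged usage sketch: one way this agent might be wired into a LangGraph
# graph. StateGraph, MessagesState, and the exact node names are assumptions
# consistent with medical_route above, not code shipped with this module.
if __name__ == "__main__":
    from langgraph.graph import START, StateGraph, MessagesState

    graph = StateGraph(MessagesState)
    graph.add_node("agent", MedicalQuestionAgent(questions=["Do you have any allergies?"]))
    graph.add_node("rag_tool_node", rag_tool_node)
    graph.add_edge(START, "agent")
    # medical_route returns either "rag_tool_node" or END, which LangGraph
    # uses directly as the next node when no path map is given.
    graph.add_conditional_edges("agent", medical_route)
    graph.add_edge("rag_tool_node", "agent")
    app = graph.compile()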