File size: 3,795 Bytes
8b1e853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, ToolMessage
from langgraph.graph import END
from .prompt import VERIFICATION_PROMPT, SYSTEM_PROMPT
# Populate os.environ from a local .env file (when present), then read the
# chat model name; MODEL is None when the variable is not set.
load_dotenv()
MODEL = os.environ.get("MODEL")

@tool
def invalid(field: str, value: str, counter=1):
    """
    Call this tool if the user's response is not valid for one of the fields you are verifying.
    """
    # The docstring above doubles as the tool description given to the model,
    # so it is kept verbatim; explanations live in these comments instead.
    # `counter` is the running failure count. process_tool overrides whatever
    # the model supplies with the graph's own counter. From a count of 2
    # onward the model is told to end the conversation; below that it is told
    # to ask again.
    head = f"The user's response for {field} with value {value} is not valid. "
    if counter >= 2:
        return head + "Politely end the conversation and after ask them to call the support number."
    return head + "Indicate to the user that it does not match our records. Please ask the user one more time."

@tool
def completed(**kwargs):
    """
    Call this tool when verification is complete and successful.
    """
    # The docstring above is the tool description shown to the model and is
    # kept verbatim. Any arguments the model passes are accepted and ignored.
    # Note that process_tool builds its own ToolMessage for this tool instead
    # of invoking it.
    done_message = "The verification is complete. Moving on to medical questions."
    return done_message

# Map from the tool name that appears in a model tool call to the tool
# object; process_tool looks tools up here to invoke them.
tools_by_name = dict(
    invalid=invalid,
    completed=completed,
)

def verification_route(state):
    """Choose the next graph node after the verification agent has replied.

    Args:
        state: Graph state mapping; its ``messages`` list is read if present.

    Returns:
        ``"verification_tool_node"`` when the latest message carries tool
        calls, otherwise ``END``. An empty or missing message list also
        routes to ``END``.
    """
    messages = state.get("messages")
    if not messages:
        return END
    # ``tool_calls`` exists on AI messages only; human/tool messages lack the
    # attribute, so treat them as "no tool calls" instead of raising.
    if getattr(messages[-1], "tool_calls", None):
        return "verification_tool_node"
    return END

class VerificationAgent:
    """Graph node that runs one LLM turn of the identity-verification step.

    The prompt stacks ``SYSTEM_PROMPT``, ``VERIFICATION_PROMPT`` and the
    ``fields``/``values`` under verification ahead of the conversation
    ``messages``. The chat model has the ``invalid`` and ``completed`` tools
    bound.
    """

    def __init__(self):
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", SYSTEM_PROMPT),
            ("system", VERIFICATION_PROMPT),
            ("system", "Fields:{fields}"),
            ("system", "Values:{values}"),
            MessagesPlaceholder(variable_name="messages")
        ])
        # MODEL is read from the environment when this module is imported;
        # temperature=0 minimises sampling randomness.
        self.llm = ChatOpenAI(model=MODEL, temperature=0, streaming=True)
        self.chain = self.prompt | self.llm.bind_tools([invalid, completed])

    def __call__(self, state):
        """Run the chain on ``state`` and return the updated state.

        Args:
            state: Graph state mapping. It supplies the prompt variables
                (``fields``, ``values``, ``messages``) and optionally
                ``counter``.

        Returns:
            A new dict: a shallow copy of ``state`` with ``messages`` set to
            ``[result]`` (the model reply). ``counter`` is reset to 0 when it
            was missing/falsy or when the reply has no tool calls; otherwise
            it is carried over unchanged. The input ``state`` is not mutated.
        """
        result = self.chain.invoke(state)
        counter = state.get("counter")
        # NOTE(review): a plain-text reply (no tool calls) also resets the
        # counter. If verification_tool_node routes back to this node, the
        # model's "please try again" reply resets it before the user's second
        # attempt, so process_tool's end-after-two-failures path may never
        # fire. Confirm this is intended.
        if not counter or not result.tool_calls:
            counter = 0
        return {**state, "counter": counter, "messages": [result]}
    
def process_tool(state):
    """Turn the tool calls on the latest message into ``ToolMessage`` replies.

    For each tool call on ``state["messages"][-1]``:

    * ``invalid``: the tool is invoked with the running failure ``counter``.
      While the counter is below 2 the model is told to ask the user once
      more, and the counter is bumped again. Once it is 2 or more, the tool's
      "end the conversation" text is returned and the counter resets to 0.
    * ``completed``: ``next`` is incremented and the model is told to move on
      to the medical questions.
    * any other name: an empty ``ToolMessage`` is returned.

    Args:
        state: Graph state mapping. It reads:

            * ``messages``: the last entry must expose ``tool_calls``.
            * ``counter``: a missing/``None`` value is treated as 0.
            * ``next``: read only for ``completed`` calls, and must be
              present then.

    Returns:
        A shallow copy of ``state`` with the updated ``counter``/``next``,
        and ``messages`` holding one new ``ToolMessage`` per tool call. The
        input mapping itself is not modified.
    """
    import logging

    # Copy first so the caller's state is never mutated; every update is
    # reported through the returned dict instead.
    new_state = {**state}
    last_message = new_state["messages"][-1]
    # Count this tool round as one more attempt; a missing/None counter
    # previously crashed here with TypeError.
    new_state["counter"] = (new_state.get("counter") or 0) + 1
    # NOTE(review): the counter is bumped here *and* again in the "ask again"
    # branch below, and it is shared across all fields being verified.
    # Confirm this is the intended failure accounting.
    messages = []
    for tool_call in last_message.tool_calls:
        name = tool_call["name"]
        if name == "invalid":
            message = tools_by_name[name].invoke(
                {**tool_call["args"], "counter": new_state["counter"]}
            )
            if new_state["counter"] >= 2:
                new_state["counter"] = 0
                content = message
            else:
                new_state["counter"] += 1
                # Unlike the tool's own return text, this reply does not
                # repeat the rejected value.
                content = f"The user's response for {tool_call['args']['field']} is not valid. Indicate to the user that it does not match our records. Please ask the user one more time."
        elif name == "completed":
            new_state["next"] += 1
            logging.getLogger(__name__).debug(
                "Verification completed; next=%s", new_state["next"]
            )
            content = "Verification complete. Prompt the user that we are moving on to medical questions. Do not end with a farwell. Mention that during the next stage the patient can ask any questions they have."
        else:
            content = ""
        messages.append(
            ToolMessage(name=name, tool_call_id=tool_call["id"], content=content)
        )
    new_state["messages"] = messages
    return new_state