Spaces:
Sleeping
Sleeping
import os | |
from langgraph.graph import StateGraph, END | |
from langgraph.checkpoint.memory import MemorySaver | |
# from .prompt import SYSTEM_PROMPT | |
import asyncio | |
from .agents.supervisor import SupervisorAgent | |
from .agents.verification import VerificationAgent, process_tool, verification_route | |
from .agents.medical import MedicalQuestionAgent, medical_route | |
from .agents.rag import RAGTool | |
from .agents.state.state import GraphState | |
from data.preprocessing.vectorstore.get import retriever | |
from langchain_openai import ChatOpenAI | |
from .upload_pdf.ingest_documents import PDFProcessor | |
# Absolute path to the patient-packet PDF, resolved relative to this module's
# directory rather than the current working directory.
_packet_path = os.path.abspath(
    os.path.join(
        os.path.dirname(__file__),
        '..', 'data', 'combined_forms', 'temp', 'ACTC-Patient-Packet.pdf',
    )
)
pdf_processor = PDFProcessor(file_path=_packet_path)

# Keep only the `.content` of each item returned by extract_questions(); this
# list is what gets handed to MedicalQuestionAgent further down.
questions = [item.content for item in pdf_processor.extract_questions()]
print('QUESTIONS**********************', questions)
# Checkpointer passed to graph.compile() by the __main__ entry point.
memory = MemorySaver()

graph = StateGraph(GraphState)

# --- Nodes -----------------------------------------------------------------
# The supervisor instance is kept in a name because its `route` method is also
# used as the routing callable for the supervisor's conditional edges below.
supervisor = SupervisorAgent()
graph.add_node("supervisor_agent", supervisor)

graph.add_node("verification_agent", VerificationAgent())
graph.add_node("verification_tool_node", process_tool)

graph.add_node("medical_agent", MedicalQuestionAgent(questions=questions))

# The chat model name comes from the MODEL environment variable; a missing
# variable raises KeyError here, at import time.
_rag_llm = ChatOpenAI(model=os.environ["MODEL"])
graph.add_node("rag_tool_node", RAGTool(retriever=retriever, llm=_rag_llm))

# --- Edges -----------------------------------------------------------------
graph.set_entry_point("supervisor_agent")

# Each tool node has an unconditional edge back to the agent it serves.
graph.add_edge("verification_tool_node", "verification_agent")
graph.add_edge("rag_tool_node", "medical_agent")

# No path map is given for the supervisor: whatever `supervisor.route` returns
# is used directly as the routing result.
graph.add_conditional_edges("supervisor_agent", supervisor.route)

# Each agent either ends the run or hands off to its own tool node.
graph.add_conditional_edges(
    "verification_agent",
    verification_route,
    {"__end__": END, "verification_tool_node": "verification_tool_node"},
)
graph.add_conditional_edges(
    "medical_agent",
    medical_route,
    {"__end__": END, "rag_tool_node": "rag_tool_node"},
)
async def run_verfication(app, fields="", values=""):
    """Run an interactive console chat against ``app``, streaming model output.

    Reads lines from stdin until the user types ``quit``. Each line is sent
    through ``app.astream_events`` (event schema ``v2``); for every
    ``on_chat_model_stream`` event whose chunk has non-empty content, that
    content is printed immediately with newlines removed.

    Args:
        app: Object exposing ``astream_events(input, config=..., version=...)``
            as an async iterator of event dicts.
        fields: Value placed under the ``"fields"`` key of the graph input.
            When empty, the demo value ``"full name, birthdate"`` is used.
        values: Value placed under the ``"values"`` key of the graph input.
            When empty, the demo value ``"John Doe, 1990-01-01"`` is used.
    """
    # Previously both arguments were accepted but ignored and the demo strings
    # were always sent. Honour the caller's values, keeping the demo strings
    # as the fallback so calls that pass nothing behave exactly as before.
    fields = fields or "full name, birthdate"
    values = values or "John Doe, 1990-01-01"

    # The same thread_id is used for every turn of this session.
    config = {"configurable": {"thread_id": 1}}

    _input = input('User: ')
    while _input != 'quit':
        payload = {
            "messages": [('user', _input)],
            "fields": fields,
            "values": values,
        }
        async for event in app.astream_events(payload, config=config, version="v2"):
            # Only chat-model token events carry text to display.
            if event['event'] != "on_chat_model_stream":
                continue
            content = event["data"]["chunk"].content
            if content:
                print(content.replace("\n", ""), end="", flush=True)
        _input = input('\nUser: ')
async def run(app):
    """Drive a console chat loop against ``app``, printing one reply per turn.

    Prompts with ``User: `` until the user types ``quit``. Each entered line is
    wrapped in a ``HumanMessage`` and streamed through ``app.astream`` with
    ``stream_mode="messages"``; the content of every ``AIMessageChunk`` seen
    during that turn is concatenated and printed once the stream is exhausted.

    Args:
        app: Object exposing ``astream(input, config=..., stream_mode=...)``
            that yields ``(message, metadata)`` pairs.
    """
    from langchain_core.messages import AIMessageChunk, HumanMessage

    # The same thread_id is used for every turn of this session.
    session_config = {"configurable": {"thread_id": 1}}

    while True:
        user_text = input("User: ")
        if user_text == "quit":
            break

        graph_input = {
            "messages": [HumanMessage(content=user_text)],
            "fields": "full name, birthdate",
            "values": "John Doe, 1990-01-01",
        }
        stream = app.astream(graph_input, config=session_config, stream_mode="messages")

        # Collect only AI chunks; other message types in the stream are skipped.
        pieces = [
            chunk.content
            async for chunk, _meta in stream
            if isinstance(chunk, AIMessageChunk)
        ]
        print('Assistant: ', "".join(pieces))
if __name__ == "__main__":
    # Script entry point: compile the graph with the module-level checkpointer
    # and hand the compiled app to the interactive `run` loop.
    app = graph.compile(checkpointer=memory)
    asyncio.run(run(app))