import gradio as gr
from openai import OpenAI
import os
import base64
import time
import copy
import re
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

from agents import rag_decision
from agents import get_top_k
from agents import get_prescription_text
from prompts import bot_welcome_message, openai_opening_system_message
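
# The helpers imported from agents are defined elsewhere in this repo. Judging
# from how they are called below, their contracts are roughly (a sketch, not
# the actual implementations):
#   rag_decision(query: str) -> bool                 # True if retrieval would help
#   get_top_k(query: str, k: int = 3) -> list[dict]  # each result has a "content" key
#   get_prescription_text(messages: list[dict]) -> str  # reads text off an uploaded prescription image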

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

with gr.Blocks() as demo:

    ################################ AGENTS ###################################
    # Agent1 - RAG Decision Agent (decides whether RAG is needed for the user's query)
    def agent1_rag_decision(query):
        decision = rag_decision(query)
        return decision

    # Agent2 - RAG Retrieval Agent (retrieves the top k relevant documents)
    def agent2_use_rag(query, k=3):
        results = get_top_k(query, k=k)
        return results

    # Agent3 - LLM Agent (gets the query response from the LLM)
    def agent3_llm_agent(messages):
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            temperature=0.7
        )
        return response.choices[0].message.content.strip()
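    # Note: this is a single, non-streaming completion; respond() below
    # simulates streaming by splitting the finished text into sentences.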

    # Agent4 - Vision Agent (extracts the prescription text from an uploaded image)
    def agent4_get_prescription_text(messages):
        """
        OpenAI agent to get the prescription text from the uploaded image.
        """
        prescription_text = get_prescription_text(messages)
        return prescription_text
    ###########################################################################

    def encode_image(image_path):
        with open(image_path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")
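    # Gradio hands uploads to us as local temp-file paths, while the OpenAI
    # vision API wants either a public URL or a base64 data URL, so uploads
    # are inlined as base64 and wrapped in a data URL in add_message below.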

    def load_welcome():
        history = []
        history.append({"role": "system", "content": openai_opening_system_message})
        history.append({"role": "assistant", "content": bot_welcome_message})
        return history

    def clear_and_load():
        # Reset the chat to the welcome message; None clears the input textbox
        return load_welcome(), None

    def add_message(history, message):
        # Build one user message that can carry both text and images, so that
        # agent4_get_prescription_text sees them together. Starting with an
        # empty content list also avoids an IndexError on image-only uploads.
        messages = [{"role": "user", "content": []}]
        if message["text"]:
            messages[0]["content"].append({"type": "text", "text": message["text"]})
        for x in message["files"]:
            encoded_content = encode_image(x)
            messages[0]["content"].append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded_content}"}
            })
            history.append({"role": "user", "content": {"path": x}})
        # Call agent4_get_prescription_text only if the message has an image
        has_image_url = any(item.get("type") == "image_url" for item in messages[0]["content"])
        if has_image_url:
            prescription_text = agent4_get_prescription_text(messages)
            history.append({"role": "system", "content": prescription_text})
        if message["text"]:
            history.append({"role": "user", "content": message["text"]})
        return history, gr.MultimodalTextbox(value=None, interactive=False, file_count="multiple", placeholder="Enter message or upload file...")
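
    # add_message returns a disabled MultimodalTextbox so the user cannot
    # submit again mid-response; the .then() chained after respond in the
    # event wiring below re-enables it.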

    def respond(history):
        if len(history) == 2:
            # Re-insert the system prompt if it is missing from the history
            history.insert(0, {"role": "system", "content": openai_opening_system_message})
        messages = copy.deepcopy(history)
        for i, msg in enumerate(messages):
            if isinstance(msg["content"], str):
                # If the content is a plain string, wrap it in the OpenAI content-part format
                messages[i]["content"] = [{
                    "type": "text",
                    "text": msg["content"]
                }]
            if isinstance(msg["content"], tuple):
                # If the content is a file path, substitute a text placeholder;
                # the image was already handled by agent4 in add_message
                messages[i]["content"] = [{
                    "type": "text",
                    "text": "User Image"
                }]
        clean_messages = []  # OpenAI does not accept metadata or options in messages
        for msg in messages:
            clean_msg = {
                "role": msg["role"],
                "content": msg["content"]
            }
            clean_messages.append(clean_msg)

        ########################### AGENTIC WORKFLOW ##########################
        def _text_of(msg):
            # Pull the plain text out of a message whose content may still be
            # a raw string or the list-of-parts format built above; without
            # this, the "No prescription found" checks would test membership
            # against a list and never match
            content = msg["content"]
            if isinstance(content, str):
                return content
            if isinstance(content, list) and content and "text" in content[0]:
                return content[0]["text"]
            return ""

        # Call Agent1 - the RAG Decision Agent
        # (named need_rag so it does not shadow the imported rag_decision)
        rag_query = ""
        if clean_messages[-1]["role"] == "system" and "No prescription found" in _text_of(clean_messages[-1]):
            # The prescription agent found nothing, so skip the RAG decision
            need_rag = False
        elif clean_messages[-2]["role"] == "system" and "No prescription found" in _text_of(clean_messages[-2]):
            need_rag = False
        else:
            # Flatten the last 10 messages into "role: <message>" lines
            # (a negative slice already copes with histories shorter than 10)
            last_10 = clean_messages[-10:]
            rag_query = "\n".join(
                f"{msg['role']}: {_text_of(msg)}"
                for msg in last_10
            )
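            # rag_query ends up looking something like (illustrative):
            #   system: <opening system prompt>
            #   assistant: <welcome message>
            #   user: What does my prescription say?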
            need_rag = agent1_rag_decision(rag_query)
        if need_rag:
            # Call Agent2 - the RAG Retrieval Agent
            top_k_results = agent2_use_rag(_text_of(clean_messages[-1]), k=5)
            # Append the top k results to the messages as extra system context
            for i, result in enumerate(top_k_results):
                clean_messages.append({
                    "role": "system",
                    "content": f"RAG Retrieved Result-{i+1}: " + result["content"]
                })
        # Call Agent3 - the LLM Agent to get the query response
        # (with or without the RAG context appended above)
        response = agent3_llm_agent(clean_messages)
        #######################################################################
        # Stream the reply into the chat sentence by sentence
        history.append({"role": "assistant", "content": ""})
        # Split by sentence boundaries (naive, but works for most cases)
        chunks = re.split(r'(?<=[.!?]) +', response)
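        # e.g. re.split(r'(?<=[.!?]) +', "Hi there. How can I help?")
        # yields ["Hi there.", "How can I help?"]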
        for chunk in chunks:
            history[-1]["content"] += chunk + " "
            time.sleep(0.3)
            yield history
    ##########################################################################

    gr.Markdown(
        """
        <h1 style='text-align: center; font-size: 1.5em; color: #2c3e50; margin-bottom: 0.2em;'>
        MedScan Diagnostic Services Chatbot (Agentic AI framework powered by OpenAI)
        </h1>
        """
    )
    chatbot = gr.Chatbot(
        type="messages",
        render_markdown=True,
        height=380
    )
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_count="multiple",
        placeholder="Enter message or upload file...",
        show_label=False
    )
    clear = gr.Button("New Chat")
    clear.click(
        clear_and_load,
        inputs=None,
        outputs=[chatbot, chat_input]
    )
    demo.load(load_welcome, None, chatbot, api_name="load_welcome")
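
    # Event chain: submitting the textbox runs add_message (which disables
    # the textbox), respond then streams the reply into the chatbot, and the
    # final .then() re-enables the textbox for the next turn.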
    chat_msg = chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input]
    )
    bot_msg = chat_msg.then(respond, chatbot, chatbot, api_name="bot_response")
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

if __name__ == "__main__":
    demo.launch()