Spaces:
Running
Running
import logging | |
import mimetypes | |
import os | |
import re | |
import shutil | |
import gradio as gr | |
from gradio_pdf import PDF | |
from huggingface_hub import login | |
from smolagents.gradio_ui import _process_action_step, _process_final_answer_step | |
from smolagents.memory import ActionStep, FinalAnswerStep, MemoryStep, PlanningStep | |
from smolagents.models import ChatMessageStreamDelta | |
# from smolagents import CodeAgent, InferenceClientModel | |
from src.insurance_assistants.agents import manager_agent | |
from src.insurance_assistants.consts import ( | |
PRIMARY_HEADING, | |
PROJECT_ROOT_DIR, | |
PROMPT_PREFIX, | |
) | |
# load_dotenv(override=True) | |
# Setup logger | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
class UI: | |
"""A one-line interface to launch your agent in Gradio""" | |
def __init__(self, file_upload_folder: str | None = None): | |
self.file_upload_folder = file_upload_folder | |
if self.file_upload_folder is not None: | |
if not os.path.exists(file_upload_folder): | |
os.mkdir(file_upload_folder) | |
def pull_messages_from_step( | |
self, step_log: MemoryStep, skip_model_outputs: bool = False | |
): | |
"""Extract ChatMessage objects from agent steps with proper nesting. | |
Args: | |
step_log: The step log to display as gr.ChatMessage objects. | |
skip_model_outputs: If True, skip the model outputs when creating the gr.ChatMessage objects: | |
This is used for instance when streaming model outputs have already been displayed. | |
""" | |
if isinstance(step_log, ActionStep): | |
yield from _process_action_step(step_log, skip_model_outputs) | |
elif isinstance(step_log, PlanningStep): | |
pass | |
# yield from _process_planning_step(step_log, skip_model_outputs) | |
elif isinstance(step_log, FinalAnswerStep): | |
yield from _process_final_answer_step(step_log) | |
else: | |
raise ValueError(f"Unsupported step type: {type(step_log)}") | |
def stream_to_gradio( | |
self, | |
agent, | |
task: str, | |
task_images: list | None = None, | |
reset_agent_memory: bool = False, | |
additional_args: dict | None = None, | |
): | |
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" | |
intermediate_text = "" | |
for step_log in agent.run( | |
task, | |
images=task_images, | |
stream=True, | |
reset=reset_agent_memory, | |
additional_args=additional_args, | |
): | |
# Track tokens if model provides them | |
if getattr(agent.model, "last_input_token_count", None) is not None: | |
if isinstance(step_log, (ActionStep, PlanningStep)): | |
step_log.input_token_count = agent.model.last_input_token_count | |
step_log.output_token_count = agent.model.last_output_token_count | |
if isinstance(step_log, MemoryStep): | |
intermediate_text = "" | |
for message in self.pull_messages_from_step( | |
step_log, | |
# If we're streaming model outputs, no need to display them twice | |
skip_model_outputs=getattr(agent, "stream_outputs", False), | |
): | |
yield message | |
elif isinstance(step_log, ChatMessageStreamDelta): | |
intermediate_text += step_log.content or "" | |
yield intermediate_text | |
def interact_with_agent(self, prompt, messages, session_state, api_key): | |
# Get or create session-specific agent | |
if not api_key or not api_key.startswith("hf"): | |
raise ValueError("Incorrect HuggingFace Inference API Key") | |
# Login to Hugging Face with the provided API key | |
login(token=api_key) | |
if "agent" not in session_state: | |
# session_state["agent"] = CodeAgent(tools=[], model=InfenceClientModel()) | |
session_state["agent"] = manager_agent | |
session_state["agent"].system_prompt = ( | |
session_state["agent"].system_prompt + PROMPT_PREFIX | |
) | |
# Adding monitoring | |
try: | |
# log the existence of agent memory | |
has_memory = hasattr(session_state["agent"], "memory") | |
logger.info(f"Agent has memory: {has_memory}") | |
if has_memory: | |
logger.info(f"Memory type: {type(session_state['agent'].memory)}") | |
messages.append(gr.ChatMessage(role="user", content=prompt)) | |
yield messages | |
for msg in self.stream_to_gradio( | |
agent=session_state["agent"], | |
task=prompt, | |
reset_agent_memory=False, | |
): | |
messages.append(msg) | |
yield messages | |
yield messages | |
except Exception as e: | |
logger.info(f"Error in interaction: {str(e)}") | |
raise | |
def upload_file( | |
self, | |
file, | |
file_uploads_log, | |
allowed_file_types=[ | |
"application/pdf", | |
"application/vnd.openxmlformats-officedocument.wordprocessingml.document", | |
"text/plain", | |
], | |
): | |
""" | |
Handle file uploads, default allowed types are .pdf, .docx, and .txt | |
""" | |
if file is None: | |
return gr.Textbox("No file uploaded", visible=True), file_uploads_log | |
try: | |
mime_type, _ = mimetypes.guess_type(file.name) | |
except Exception as e: | |
return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log | |
if mime_type not in allowed_file_types: | |
return gr.Textbox("File type disallowed", visible=True), file_uploads_log | |
# Sanitize file name | |
original_name = os.path.basename(file.name) | |
sanitized_name = re.sub( | |
r"[^\w\-.]", "_", original_name | |
) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores | |
type_to_ext = {} | |
for ext, t in mimetypes.types_map.items(): | |
if t not in type_to_ext: | |
type_to_ext[t] = ext | |
# Ensure the extension correlates to the mime type | |
sanitized_name = sanitized_name.split(".")[:-1] | |
sanitized_name.append("" + type_to_ext[mime_type]) | |
sanitized_name = "".join(sanitized_name) | |
# Save the uploaded file to the specified folder | |
file_path = os.path.join( | |
self.file_upload_folder, os.path.basename(sanitized_name) | |
) | |
shutil.copy(file.name, file_path) | |
return gr.Textbox( | |
f"File uploaded: {file_path}", visible=True | |
), file_uploads_log + [file_path] | |
def log_user_message(self, text_input, file_uploads_log): | |
return ( | |
text_input | |
+ ( | |
f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}" | |
if len(file_uploads_log) > 0 | |
else "" | |
), | |
gr.Textbox( | |
value="", | |
interactive=False, | |
placeholder="Please wait while the agent answers your question", | |
), | |
gr.Button(interactive=False), | |
) | |
def list_pdfs(self, dir=PROJECT_ROOT_DIR / "data/policy_wordings"): | |
file_names = [f.name for f in dir.iterdir()] | |
return file_names | |
def interrupt_agent(self, session_state): | |
if "agent" not in session_state: | |
session_state["agent"] = manager_agent | |
agent = session_state["agent"] | |
agent.interrupt() | |
return | |
def display_pdf(self, pdf_selector): | |
return PDF( | |
value=(f"{PROJECT_ROOT_DIR}/data/policy_wordings/{pdf_selector}"), | |
label="PDF Viewer", | |
show_label=True, | |
) | |
def launch(self, **kwargs): | |
with gr.Blocks(fill_height=True) as demo: | |
gr.Markdown(value=PRIMARY_HEADING) | |
def layout(request: gr.Request): | |
# Render layout with sidebar | |
with gr.Blocks( | |
fill_height=True, | |
): | |
file_uploads_log = gr.State([]) | |
with gr.Sidebar(): | |
gr.Markdown( | |
value="""#### <span style="color:red"> The `interrupt` button doesn't stop the process instantaneously.</span> | |
<span style="color:green">You can continue to use the application upon pressing the interrupt button.</span> | |
<span style="color:violet">PRECISE PROMPT = ACCURATE RESULTS.</span> | |
""" | |
) | |
with gr.Group(): | |
api_key = gr.Textbox( | |
placeholder="Enter your HuggingFace Inference API KEY HERE", | |
label="π€ Inference API Key", | |
show_label=True, | |
type="password", | |
) | |
gr.Markdown( | |
value="**Your question, please...**", container=True | |
) | |
text_input = gr.Textbox( | |
lines=3, | |
label="Your question, please...", | |
container=False, | |
placeholder="Enter your prompt here and press Shift+Enter or press `Run`", | |
) | |
run_btn = gr.Button(value="Run", variant="primary") | |
agent_interrup_btn = gr.Button( | |
value="Interrupt", variant="stop" | |
) | |
# If an upload folder is provided, enable the upload feature | |
if self.file_upload_folder is not None: | |
upload_file = gr.File(label="Upload a file") | |
upload_status = gr.Textbox( | |
label="Upload Status", | |
interactive=False, | |
visible=False, | |
) | |
upload_file.change( | |
fn=self.upload_file, | |
inputs=[upload_file, file_uploads_log], | |
outputs=[upload_status, file_uploads_log], | |
) | |
gr.HTML("<br><br><h4><center>Powered by:</center></h4>") | |
with gr.Row(): | |
gr.HTML("""<div style="display: flex; align-items: center; gap: 8px; font-family: system-ui, -apple-system, sans-serif;"> | |
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png" style="width: 32px; height: 32px; object-fit: contain;" alt="logo"> | |
<a target="_blank" href="https://github.com/huggingface/smolagents"><b>huggingface/smolagents</b></a> | |
</div>""") | |
# Add session state to store session-specific data | |
session_state = gr.State({}) | |
# Initialize empty state for each session | |
stored_messages = gr.State([]) | |
chatbot = gr.Chatbot( | |
label="Health Insurance Agent", | |
type="messages", | |
avatar_images=( | |
None, | |
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png", | |
), | |
resizeable=False, | |
scale=1, | |
elem_id="Insurance-Agent", | |
) | |
with gr.Group(): | |
gr.Markdown("### π PDF Viewer") | |
pdf_choices = self.list_pdfs() | |
pdf_selector = gr.Dropdown( | |
choices=pdf_choices, | |
label="Select a PDF", | |
info="Choose one", | |
show_label=True, | |
interactive=True, | |
) | |
pdf_viewer = PDF( | |
label="PDF Viewer", | |
show_label=True, | |
) | |
pdf_selector.change( | |
fn=self.display_pdf, inputs=pdf_selector, outputs=pdf_viewer | |
) | |
text_input.submit( | |
fn=self.log_user_message, | |
inputs=[text_input, file_uploads_log], | |
outputs=[stored_messages, text_input, run_btn], | |
).then( | |
fn=self.interact_with_agent, | |
# Include session_state in function calls | |
inputs=[stored_messages, chatbot, session_state, api_key], | |
outputs=[chatbot], | |
).then( | |
fn=lambda: ( | |
gr.Textbox( | |
interactive=True, | |
placeholder="Enter your prompt here or press `Run`", | |
), | |
gr.Button(interactive=True), | |
), | |
inputs=None, | |
outputs=[text_input, run_btn], | |
) | |
run_btn.click( | |
fn=self.log_user_message, | |
inputs=[text_input, file_uploads_log], | |
outputs=[stored_messages, text_input, run_btn], | |
).then( | |
fn=self.interact_with_agent, | |
# Include session_state in function calls | |
inputs=[stored_messages, chatbot, session_state, api_key], | |
outputs=[chatbot], | |
).then( | |
fn=lambda: ( | |
gr.Textbox( | |
interactive=True, | |
placeholder="Enter your prompt here or press `Run`", | |
), | |
gr.Button(interactive=True), | |
), | |
inputs=None, | |
outputs=[text_input, run_btn], | |
) | |
agent_interrup_btn.click( | |
fn=self.interrupt_agent, | |
inputs=[session_state], | |
) | |
demo.queue(max_size=4).launch(debug=False, **kwargs) | |
# if __name__ == "__main__": | |
# UI().launch( | |
# share=True, | |
# allowed_paths=[(PROJECT_ROOT_DIR / "data/policy_wordings").as_posix()], | |
# ) | |