First_agent_template / Gradio_UI.py
Tbaberca's picture
Update Gradio_UI.py
7f095bd verified
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mimetypes
import os
import re
from typing import Optional
import gradio as gr
from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
from smolagents.agents import ActionStep
from smolagents.memory import MemoryStep
from smolagents.utils import _is_package_available
class GradioUI:
    """Minimal Gradio chat front-end around an agent.

    Messages typed by the user are forwarded to the agent through
    ``stream_to_gradio`` (defined elsewhere in this module) and the streamed
    replies are shown in a Chatbot. ``upload_file`` validates and stores
    uploaded files in ``file_upload_folder``.
    """

    # MIME types accepted by ``upload_file`` when no explicit whitelist is
    # given (PDF, DOCX, plain text). A tuple, so it cannot be mutated by callers.
    _DEFAULT_ALLOWED_FILE_TYPES = (
        "application/pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/plain",
    )

    def __init__(self, agent, file_upload_folder: str = "./uploads"):
        """Store the agent and make sure the upload folder exists.

        Args:
            agent: Agent passed to ``stream_to_gradio`` for every user message.
            file_upload_folder: Directory where ``upload_file`` writes files;
                created (including missing parents) if it does not exist.
        """
        self.agent = agent
        self.file_upload_folder = file_upload_folder
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(self.file_upload_folder, exist_ok=True)

    def interact_with_agent(self, prompt, messages):
        """Append the user prompt and stream the agent's replies into the chat.

        This is a generator used as a Gradio event handler: each ``yield``
        pushes the current message list to the Chatbot component.

        Args:
            prompt: Text submitted by the user.
            messages: Current chat history (list of ``gr.ChatMessage``). It may
                be ``None`` after the "Clear" button reset the chatbot value.

        Yields:
            The updated message list after each appended message.
        """
        # The Clear handler in build_ui sets the chatbot value to None.
        if messages is None:
            messages = []
        messages.append(gr.ChatMessage(role="user", content=prompt))
        yield messages
        for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
            messages.append(msg)
            yield messages
        yield messages

    def upload_file(
        self,
        file,
        file_uploads_log,
        allowed_file_types=None,
    ):
        """Validate an uploaded file and save it into ``file_upload_folder``.

        Args:
            file: Uploaded file object with a ``.name`` path. Its bytes are
                taken from ``.read()`` when it has one; otherwise the file at
                ``.name`` is copied.
            file_uploads_log: List of saved file paths; the new path is
                appended to it in place on success.
            allowed_file_types: Iterable of accepted MIME types. ``None``
                (the default) means PDF, DOCX and plain text.

        Returns:
            A ``(gr.Textbox, file_uploads_log)`` pair; the textbox carries a
            status or error message.
        """
        if allowed_file_types is None:
            allowed_file_types = self._DEFAULT_ALLOWED_FILE_TYPES
        if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log

        try:
            mime_type, encoding = mimetypes.guess_type(file.name)
        except Exception as e:
            return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log

        # guess_type reports e.g. "notes.txt.gz" as ("text/plain", "gzip"); such
        # a file is really a compressed archive, so any encoding is rejected.
        if mime_type not in allowed_file_types or encoding is not None:
            return gr.Textbox("File type disallowed", visible=True), file_uploads_log

        # Keep only the final path component and replace every character other
        # than word characters, '-' and '.' so the name is safe on disk.
        original_name = os.path.basename(file.name)
        sanitized_name = re.sub(r"[^\w\-.]", "_", original_name)

        # Normalise the extension to the canonical one for the detected type.
        name_without_ext = os.path.splitext(sanitized_name)[0]
        ext = mimetypes.guess_extension(mime_type) or ""
        sanitized_name = f"{name_without_ext}{ext}"

        file_path = os.path.join(self.file_upload_folder, sanitized_name)
        if hasattr(file, "read"):
            # File-like upload object: write out its contents.
            with open(file_path, "wb") as f:
                f.write(file.read())
        else:
            # Path-like upload value without read() (presumably what newer
            # Gradio versions pass; confirm): copy the file at ``.name``.
            import shutil

            shutil.copy(file.name, file_path)

        file_uploads_log.append(file_path)
        return gr.Textbox(f"File uploaded: {sanitized_name}", visible=True), file_uploads_log

    def build_ui(self):
        """Build the Gradio app: a chat window, a message box and a Clear button.

        Returns:
            The ``gr.Blocks`` instance containing the wired-up components.
        """
        # interact_with_agent yields gr.ChatMessage objects, so the Chatbot
        # must use the "messages" data format rather than the tuples format.
        chatbot = gr.Chatbot(type="messages")
        msg = gr.Textbox(placeholder="Ask something...", label="Your message")
        clear = gr.Button("Clear")
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    chatbot.render()
                    msg.render()
                    clear.render()
            msg.submit(self.interact_with_agent, [msg, chatbot], chatbot)
            # Resets the chat display to None; interact_with_agent handles that.
            clear.click(lambda: None, None, chatbot, queue=False)
        return demo
def pull_messages_from_step(step_log: MemoryStep):
if isinstance(step_log, ActionStep):
step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else ""
yield gr.ChatMessage(role="assistant", content=f"**{step_number}**")
if hasattr(step_log, "model_output") and step_log.model_output is not None:
model_output = step_log.model_output.strip()
model_output = re.sub(r"```\s*<end_code>", "```", model_output)
model_output = re.sub(r"<end_code>\s*```", "```", model_output)
model_output = re.sub(r"```\s*\n\s*<end_code>", "```", model_output)
model_output = model_output.strip()
yield gr.ChatMessage(role="assistant", content=model_output)
if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None:
first_tool_call = step_log.tool_calls[0]
used_code = first_tool_call.name == "python_interpreter"
parent_id = f"call_{len(step_log.tool_calls)}"
args = first_tool_call.arguments
content = str(args.get("answer", str(args))) if isinstance(args, dict) else str(args).strip()
if used_code:
content = re.sub(r"```.*?\n", "", content)
content = re.sub(r"\s*<end_code>\s*", "", content).strip()
if not content.startswith("```python"):
content = f"```python\n{content}\n```"