Spaces:
Running
Running
#!/usr/bin/env python | |
# coding=utf-8 | |
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import mimetypes | |
import os | |
import re | |
from typing import Optional | |
import gradio as gr | |
from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types | |
from smolagents.agents import ActionStep | |
from smolagents.memory import MemoryStep | |
from smolagents.utils import _is_package_available | |
class GradioUI:
    """Gradio chat front-end for a smolagents agent.

    Args:
        agent: The agent to chat with; it is handed to ``stream_to_gradio`` on
            every user turn.
        file_upload_folder: Directory that uploaded files are copied into. It is
            created (including parents) on construction if missing.
    """

    # MIME types accepted by ``upload_file`` by default (.pdf, .docx, .txt).
    # A tuple so the shared default value cannot be mutated between calls.
    DEFAULT_ALLOWED_FILE_TYPES = (
        "application/pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/plain",
    )

    def __init__(self, agent, file_upload_folder: str = "./uploads"):
        self.agent = agent
        self.file_upload_folder = file_upload_folder
        # exist_ok=True replaces the racy "check exists, then create" sequence.
        os.makedirs(self.file_upload_folder, exist_ok=True)

    def interact_with_agent(self, prompt, messages):
        """Run one user turn and stream the agent's replies into the chat history.

        Args:
            prompt: The user's message text.
            messages: Current chat history in Gradio "messages" format; it is
                appended to in place.

        Yields:
            The history after adding the user message, after each message
            produced by ``stream_to_gradio``, and once more at the end.
        """
        messages.append(gr.ChatMessage(role="user", content=prompt))
        yield messages
        # NOTE(review): reset_agent_memory=False presumably keeps earlier turns in
        # the agent's memory; stream_to_gradio is defined elsewhere — confirm.
        for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
            messages.append(msg)
            yield messages
        yield messages

    @staticmethod
    def _sanitize_filename(filename, mime_type):
        """Return a filesystem-safe basename for *filename* with an extension for *mime_type*.

        Directory components are dropped, and every character that is not a word
        character, ``-`` or ``.`` becomes ``_``. The extension is replaced by the
        first extension ``mimetypes.types_map`` registers for *mime_type*. If no
        extension is registered, the original (sanitized) suffix is kept.
        """
        sanitized = re.sub(r"[^\w\-.]", "_", os.path.basename(filename))
        stem, original_ext = os.path.splitext(sanitized)
        # First registered extension per MIME type. Rebuilt per call so that
        # runtime mimetypes.add_type() registrations are honoured.
        type_to_ext = {}
        for ext, registered_type in mimetypes.types_map.items():
            type_to_ext.setdefault(registered_type, ext)
        ext = type_to_ext.get(mime_type, original_ext)
        if ext and not ext.startswith("."):
            ext = "." + ext
        return f"{stem}{ext}"

    def upload_file(
        self,
        file,
        file_uploads_log,
        allowed_file_types=DEFAULT_ALLOWED_FILE_TYPES,
    ):
        """Validate an uploaded file and copy it into ``self.file_upload_folder``.

        Args:
            file: The value Gradio passes for an upload; only its ``.name``
                attribute (the path of the uploaded file on disk) is used. May be None.
            file_uploads_log: List of previously saved paths; the new path is
                appended in place.
            allowed_file_types: Container of accepted MIME types, checked against
                the type guessed from the file name.

        Returns:
            A ``(gr.Textbox, file_uploads_log)`` tuple; the textbox carries a
            status or error message.
        """
        import shutil

        if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log
        try:
            mime_type, _ = mimetypes.guess_type(file.name)
        except Exception as e:
            return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
        if mime_type not in allowed_file_types:
            return gr.Textbox("File type disallowed", visible=True), file_uploads_log

        sanitized_name = self._sanitize_filename(file.name, mime_type)
        file_path = os.path.join(self.file_upload_folder, sanitized_name)
        # Gradio >= 4 passes a path string (with a ``.name`` attribute) rather than
        # an open file object, so copy from disk instead of calling ``file.read()``.
        shutil.copy(file.name, file_path)
        file_uploads_log.append(file_path)
        return gr.Textbox(f"File uploaded: {sanitized_name}", visible=True), file_uploads_log

    def build_ui(self):
        """Build the chat interface and return the (not yet launched) ``gr.Blocks`` app."""
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    # type="messages" is required because interact_with_agent appends
                    # gr.ChatMessage objects; the legacy "tuples" format expects
                    # [user, bot] pairs instead.
                    chatbot = gr.Chatbot(type="messages")
                    msg = gr.Textbox(placeholder="Ask something...", label="Your message")
                    clear = gr.Button("Clear")
            # Event listeners are registered inside the Blocks context.
            msg.submit(self.interact_with_agent, [msg, chatbot], chatbot)
            clear.click(lambda: None, None, chatbot, queue=False)
        return demo
def pull_messages_from_step(step_log: MemoryStep): | |
if isinstance(step_log, ActionStep): | |
step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else "" | |
yield gr.ChatMessage(role="assistant", content=f"**{step_number}**") | |
if hasattr(step_log, "model_output") and step_log.model_output is not None: | |
model_output = step_log.model_output.strip() | |
model_output = re.sub(r"```\s*<end_code>", "```", model_output) | |
model_output = re.sub(r"<end_code>\s*```", "```", model_output) | |
model_output = re.sub(r"```\s*\n\s*<end_code>", "```", model_output) | |
model_output = model_output.strip() | |
yield gr.ChatMessage(role="assistant", content=model_output) | |
if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None: | |
first_tool_call = step_log.tool_calls[0] | |
used_code = first_tool_call.name == "python_interpreter" | |
parent_id = f"call_{len(step_log.tool_calls)}" | |
args = first_tool_call.arguments | |
content = str(args.get("answer", str(args))) if isinstance(args, dict) else str(args).strip() | |
if used_code: | |
content = re.sub(r"```.*?\n", "", content) | |
content = re.sub(r"\s*<end_code>\s*", "", content).strip() | |
if not content.startswith("```python"): | |
content = f"```python\n{content}\n```" | |