#!/usr/bin/env python # coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import mimetypes import os import re from typing import Optional import gradio as gr from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types from smolagents.agents import ActionStep from smolagents.memory import MemoryStep from smolagents.utils import _is_package_available class GradioUI: def __init__(self, agent, file_upload_folder: str = "./uploads"): self.agent = agent self.file_upload_folder = file_upload_folder if not os.path.exists(self.file_upload_folder): os.makedirs(self.file_upload_folder) def interact_with_agent(self, prompt, messages): messages.append(gr.ChatMessage(role="user", content=prompt)) yield messages for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False): messages.append(msg) yield messages yield messages def upload_file( self, file, file_uploads_log, allowed_file_types=[ "application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain", ], ): if file is None: return gr.Textbox("No file uploaded", visible=True), file_uploads_log try: mime_type, _ = mimetypes.guess_type(file.name) except Exception as e: return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log if mime_type not in allowed_file_types: return gr.Textbox("File type disallowed", visible=True), file_uploads_log original_name = os.path.basename(file.name) sanitized_name = re.sub(r"[^\w\-.]", "_", original_name) type_to_ext = {} for ext, t in mimetypes.types_map.items(): if t not in type_to_ext: type_to_ext[t] = ext name_without_ext = ".".join(sanitized_name.split(".")[:-1]) ext = type_to_ext.get(mime_type, "") if not ext.startswith("."): ext = "." + ext if ext else "" sanitized_name = f"{name_without_ext}{ext}" file_path = os.path.join(self.file_upload_folder, sanitized_name) with open(file_path, "wb") as f: f.write(file.read()) file_uploads_log.append(file_path) return gr.Textbox(f"File uploaded: {sanitized_name}", visible=True), file_uploads_log def build_ui(self): chatbot = gr.Chatbot() msg = gr.Textbox(placeholder="Ask something...", label="Your message") clear = gr.Button("Clear") file_uploads_log = [] with gr.Blocks() as demo: with gr.Row(): with gr.Column(): chatbot.render() msg.render() clear.render() msg.submit(self.interact_with_agent, [msg, chatbot], chatbot) clear.click(lambda: None, None, chatbot, queue=False) return demo def pull_messages_from_step(step_log: MemoryStep): if isinstance(step_log, ActionStep): step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else "" yield gr.ChatMessage(role="assistant", content=f"**{step_number}**") if hasattr(step_log, "model_output") and step_log.model_output is not None: model_output = step_log.model_output.strip() model_output = re.sub(r"```\s*", "```", model_output) model_output = re.sub(r"\s*```", "```", model_output) model_output = re.sub(r"```\s*\n\s*", "```", model_output) model_output = model_output.strip() yield gr.ChatMessage(role="assistant", content=model_output) if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None: first_tool_call = step_log.tool_calls[0] used_code = first_tool_call.name == "python_interpreter" parent_id = f"call_{len(step_log.tool_calls)}" args = first_tool_call.arguments content = str(args.get("answer", str(args))) if isinstance(args, dict) else str(args).strip() if used_code: content = re.sub(r"```.*?\n", "", content) content = re.sub(r"\s*\s*", "", content).strip() if not content.startswith("```python"): content = f"```python\n{content}\n```"