"""Streamlit chat app that drives a Hugging Face ``transformers.Agent``
backed by the hosted Inference API, with user-selectable tools."""

import streamlit as st
import os
import requests
import base64
import io
import time
from PIL import Image
from pydub import AudioSegment
import IPython
import soundfile as sf
from transformers import load_tool, Agent
import torch


class ToolLoader:
    """Load and hold a list of Hugging Face tools resolved by hub name."""

    def __init__(self, tool_names):
        self.tools = [load_tool(tool_name) for tool_name in tool_names]


class CustomHfAgent(Agent):
    """Agent whose text generation goes through a Hugging Face Inference
    API endpoint instead of a local model.

    Parameters
    ----------
    url_endpoint : str
        Full URL of the hosted inference endpoint.
    token : str
        Value sent as the ``Authorization`` header (e.g. ``"Bearer hf_..."``
        — TODO confirm the expected header format against the endpoint).
    input_params : dict or None
        Optional generation overrides; only ``max_new_tokens`` is read.
    """

    def __init__(self, url_endpoint, token, chat_prompt_template=None,
                 run_prompt_template=None, additional_tools=None,
                 input_params=None):
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )
        self.url_endpoint = url_endpoint
        self.token = token
        self.input_params = input_params

    def generate_one(self, prompt, stop):
        """POST *prompt* to the endpoint and return the generated text,
        trimming any trailing stop sequence.

        Retries once per 429 (rate-limit) response after a short sleep;
        raises ``ValueError`` on any other non-200 status.
        """
        headers = {"Authorization": self.token}
        # input_params defaults to None; guard before .get() to avoid
        # an AttributeError when no overrides were supplied.
        params = self.input_params or {}
        max_new_tokens = params.get("max_new_tokens", 192)
        parameters = {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
            "stop": stop,
            "padding": True,
            "truncation": True,
        }
        inputs = {
            "inputs": prompt,
            "parameters": parameters,
        }
        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            # BUG FIX: the original retried via the non-existent
            # self._generate_one(prompt), dropping `stop` — retry this
            # same method with both arguments instead.
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Errors {inputs} {response.status_code}: {response.json()}")
        result = response.json()[0]["generated_text"]
        # Strip a trailing stop sequence, if the model emitted one.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result


def load_tools(tool_names):
    """Resolve each hub tool name to a loaded tool instance."""
    return [load_tool(tool_name) for tool_name in tool_names]


# Define the tool names to load
tool_names = [
    "Chris4K/random-character-tool",
    "Chris4K/text-generation-tool",
    # Add other tool names as needed
]

# Create tool loader instance
tool_loader = ToolLoader(tool_names)


def handle_submission(user_message, selected_tools):
    """Run the agent on *user_message* with the user's selected tools and
    return its (arbitrarily typed) response."""
    agent = CustomHfAgent(
        url_endpoint="https://api-inference.huggingface.co/models/bigcode/starcoder",
        token=os.environ['HF_token'],
        additional_tools=selected_tools,
        input_params={"max_new_tokens": 192},
    )
    response = agent.run(user_message)
    print("Agent Response\n {}".format(response))
    return response


st.title("Hugging Face Agent and tools")

# Persist the conversation across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# One checkbox per loadable tool; index order matches tool_loader.tools.
tool_checkboxes = [
    st.checkbox(f"{tool.name} --- {tool.description} ")
    for tool in tool_loader.tools
]

with st.chat_message("assistant"):
    st.markdown("Hello there! How can I assist you today?")

if user_message := st.chat_input("Enter message"):
    st.chat_message("user").markdown(user_message)
    st.session_state.messages.append({"role": "user", "content": user_message})

    selected_tools = [
        tool_loader.tools[idx]
        for idx, checkbox in enumerate(tool_checkboxes)
        if checkbox
    ]

    response = handle_submission(user_message, selected_tools)

    with st.chat_message("assistant"):
        # BUG FIX: type checks are ordered so that `in` / subscripting is
        # only attempted on types that support it — the original applied
        # `"emojified_text" in response` to images/audio (TypeError) and
        # indexed strings with a string key (TypeError).
        if response is None:
            st.warning("The agent's response is None. Please try again.")
        elif isinstance(response, dict) and "emojified_text" in response:
            st.markdown(f"Emojified Text: {response['emojified_text']}")
        elif isinstance(response, Image.Image):
            st.image(response)
        elif isinstance(response, AudioSegment):
            st.audio(response)
        elif isinstance(response, str) and "audio" in response:
            # Presumably a data-URI style "...,<base64>" payload — the
            # audio bytes follow the first comma. TODO confirm format.
            audio_data = base64.b64decode(response.split(",")[1])
            audio = AudioSegment.from_file(io.BytesIO(audio_data))
            st.audio(audio)
        elif isinstance(response, str):
            st.markdown(response)
        elif isinstance(response, int):
            st.markdown(response)
        else:
            st.warning("Unrecognized response type. Please try again.")

    st.session_state.messages.append({"role": "assistant", "content": response})