# Search-agent runner: answers a user question with web-browsing tools and LLM agents.
import argparse
import logging
import os

from dotenv import load_dotenv
from huggingface_hub import login

from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    SimpleTextBrowser,
    VisitTool,
)
from scripts.visual_qa import visualizer
from smolagents import (
    CodeAgent,
    DuckDuckGoSearchTool,
    LiteLLMModel,
    ToolCallingAgent,
)

# Initialize logging.
# BUG FIX: the formatter was previously created but never attached to any
# handler (and the logger had no handler at all), so it was dead code and
# records fell through to the root logger's default handling.
logger = logging.getLogger("smolagents")
logger.setLevel(logging.INFO)
log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
if not logger.handlers:  # guard: don't stack duplicate handlers on re-import
    _console_handler = logging.StreamHandler()
    _console_handler.setFormatter(log_formatter)
    logger.addHandler(_console_handler)

# Load environment variables from a local .env file (overriding any already
# set) and authenticate with the Hugging Face Hub when a token is available.
load_dotenv(override=True)
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(hf_token)
    logger.info("Logged into Hugging Face Hub.")
else:
    logger.warning("HF_TOKEN not found. Proceeding without authentication.")
# Python modules the manager CodeAgent is allowed to import inside generated code.
AUTHORIZED_IMPORTS = [
    "requests", "zipfile", "os", "pandas", "numpy", "sympy", "json", "bs4",
    "pubchempy", "xml", "yahoo_finance", "Bio", "sklearn", "scipy", "pydub",
    "io", "PIL", "chess", "PyPDF2", "pptx", "torch", "datetime", "fractions", "csv", "string", "secrets",
]

# Browser-like user agent so sites serve normal desktop pages.
USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
)

# Configuration handed to SimpleTextBrowser: page viewport size, where
# downloads land, per-request HTTP settings, and the SerpAPI key (may be
# None if the env var is unset).
BROWSER_CONFIG = {
    "viewport_size": 5120,
    "downloads_folder": "downloads_folder",
    "request_kwargs": {
        "headers": {"User-Agent": USER_AGENT},
        "timeout": 150,
        "max_retries": 2,
    },
    "serpapi_key": os.getenv("SERPAPI_API_KEY"),
}

# Make sure the downloads directory exists before the browser is created.
os.makedirs("./" + BROWSER_CONFIG["downloads_folder"], exist_ok=True)
# Map smolagents tool roles onto roles every chat provider understands.
custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}


def _litellm_config(model_id, api_key_env=None, max_completion_tokens=8192, reasoning_effort=None):
    """Build one MODEL_CONFIGS entry.

    Args:
        model_id: LiteLLM model identifier (e.g. "openai/gpt-4o").
        api_key_env: Name of the environment variable holding the API key;
            None omits the "api_key" field entirely (e.g. the default "o1").
        max_completion_tokens: Completion-token cap for the model.
        reasoning_effort: Optional reasoning-effort hint; only included when set.

    Returns:
        dict with a single "litellm_params" key, shaped exactly like the
        previous hand-written entries.
    """
    params = {
        "model_id": model_id,
        "custom_role_conversions": custom_role_conversions,
        "max_completion_tokens": max_completion_tokens,
    }
    if api_key_env is not None:
        # getenv is evaluated at import time, same as the original literals.
        params["api_key"] = os.getenv(api_key_env)
    if reasoning_effort is not None:
        params["reasoning_effort"] = reasoning_effort
    return {"litellm_params": params}


# Define the model configurations (custom models intact).
MODEL_CONFIGS = {
    # OPENAI MODELS
    "gpt-3.5-turbo": _litellm_config("openai/gpt-3.5-turbo", "OPENAI_API_KEY"),
    "gpt-3.5-turbo-16k": _litellm_config("openai/gpt-3.5-turbo-16k", "OPENAI_API_KEY", max_completion_tokens=16384),
    "gpt-4o-mini": _litellm_config("openai/gpt-4o-mini", "OPENAI_API_KEY"),
    "chatgpt-4o-latest": _litellm_config("openai/chatgpt-4o-latest", "OPENAI_API_KEY"),
    "gpt-4-turbo": _litellm_config("openai/gpt-4-turbo", "OPENAI_API_KEY"),
    "gpt-4o": _litellm_config("openai/gpt-4o", "OPENAI_API_KEY"),
    "o1-mini": _litellm_config("openai/o1-mini", "OPENAI_API_KEY", reasoning_effort="high"),
    "o1-preview": _litellm_config("openai/o1-preview", "OPENAI_API_KEY", reasoning_effort="high"),
    # HUGGINGFACE MODELS
    "hf-llama-3.1-8B-instruct": _litellm_config("huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct", "HF_TOKEN"),
    "hf-DeepSeek-R1-Distill-Qwen-32B": _litellm_config("huggingface/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", "HF_TOKEN"),
    "hf-Qwen2.5-Coder-32B-Instruct": _litellm_config("huggingface/Qwen/Qwen2.5-Coder-32B-Instruct", "HF_TOKEN"),
    "hf-QwQ-32B-Preview": _litellm_config("huggingface/Qwen/QwQ-32B-Preview", "HF_TOKEN"),
    "hf-Llama-3.1-70B-Instruct": _litellm_config("huggingface/meta-llama/Llama-3.1-70B-Instruct", "HF_TOKEN"),
    # GROQ MODELS
    "groq-llama3-8b-8192": _litellm_config("groq/llama3-8b-8192", "GROQ_API_KEY"),
    "groq-llama3-70b-8192": _litellm_config("groq/llama3-70b-8192", "GROQ_API_KEY"),
    "groq-mixtral-8x7b-32768": _litellm_config("groq/mixtral-8x7b-32768", "GROQ_API_KEY", max_completion_tokens=32768),
    # GEMINI MODELS
    "gemini-pro": _litellm_config("gemini/gemini-pro", "GEMINI_API_KEY"),
    "gemini-1.5-pro": _litellm_config("gemini/gemini-1.5-pro", "GEMINI_API_KEY"),
    "gemini-1.5-flash": _litellm_config("gemini/gemini-1.5-flash", "GEMINI_API_KEY"),
    "gemini-pro-vision": _litellm_config("gemini/gemini-pro-vision", "GEMINI_API_KEY"),
    "gemini-2.0-flash": _litellm_config("gemini/gemini-2.0-flash", "GEMINI_API_KEY"),
    "gemini-2.0-flash-thinking-exp-01-21": _litellm_config("gemini/gemini-2.0-flash-thinking-exp-01-21", "GEMINI_API_KEY"),
    # Default o1 model (no api_key field, matching the original entry)
    "o1": _litellm_config("o1", reasoning_effort="high"),
}
def parse_args(argv=None):
    """Parse command-line arguments for the search agent.

    Args:
        argv: Optional list of argument strings. Defaults to None, which makes
            argparse read sys.argv[1:] — identical to the previous behavior —
            while allowing programmatic/test use with an explicit list.

    Returns:
        argparse.Namespace with ``question`` (str) and ``model_id`` (str,
        default "o1").
    """
    parser = argparse.ArgumentParser(
        description="Run the search agent to answer questions using web browsing tools."
    )
    parser.add_argument(
        "question",
        type=str,
        help="Example: 'How many studio albums did Mercedes Sosa release before 2007?'",
    )
    parser.add_argument(
        "--model-id",
        type=str,
        default="o1",
        help="Model identifier (default: o1)",
    )
    return parser.parse_args(argv)
def create_agent(model_name="o1"):
    """Build the manager CodeAgent (with a browsing sub-agent) for *model_name*.

    Args:
        model_name: Key into MODEL_CONFIGS selecting the LiteLLM model.

    Returns:
        A configured CodeAgent that manages a web-searching ToolCallingAgent.

    Raises:
        ValueError: If ``model_name`` is not a key of MODEL_CONFIGS.
    """
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Model '{model_name}' is not a valid model. Available models are: {list(MODEL_CONFIGS.keys())}")
    # BUG FIX: copy the params so the setdefault() calls below cannot mutate
    # the shared MODEL_CONFIGS registry as a side effect of building an agent.
    model_params = dict(MODEL_CONFIGS[model_name]["litellm_params"])
    model_params.setdefault("custom_role_conversions", custom_role_conversions)
    model_params.setdefault("max_completion_tokens", 8192)
    # (Optional: adjust parameters here to lower temperature for more factual answers.)
    model = LiteLLMModel(**model_params)
    logger.info(f"Initialized LiteLLMModel with model_name={model_name}")

    # Cap on how much text TextInspectorTool will process at once.
    text_limit = 100000
    browser = SimpleTextBrowser(**BROWSER_CONFIG)
    logger.info("Initialized SimpleTextBrowser with custom configuration.")

    # Tools shared by the browsing sub-agent; all page-navigation tools share
    # the single browser instance so they operate on the same page state.
    web_tools = [
        DuckDuckGoSearchTool(),
        VisitTool(browser),
        PageUpTool(browser),
        PageDownTool(browser),
        FinderTool(browser),
        FindNextTool(browser),
        ArchiveSearchTool(browser),
        TextInspectorTool(model, text_limit),
    ]
    logger.info("Initialized web tools for ToolCallingAgent.")

    text_webbrowser_agent = ToolCallingAgent(
        model=model,
        tools=web_tools,
        max_steps=10,
        verbosity_level=2,
        planning_interval=4,
        name="search_agent",
        description=(
            "A team member that will search the internet to answer your question. "
            "Ask all questions that require browsing the web using complete sentences. "
            "Provide as much context as possible, especially if searching within a specific timeframe."
        ),
        provide_run_summary=True,
    )
    logger.info("Initialized ToolCallingAgent.")

    # The manager agent writes/executes code and delegates web research to the
    # sub-agent above.
    manager_agent = CodeAgent(
        model=model,
        tools=[visualizer, TextInspectorTool(model, text_limit)],
        max_steps=12,
        verbosity_level=2,
        additional_authorized_imports=AUTHORIZED_IMPORTS,
        planning_interval=4,
        managed_agents=[text_webbrowser_agent],
    )
    logger.info("Initialized Manager CodeAgent.")
    return manager_agent
def main():
    """CLI entry point: parse arguments, build the agent, run the question."""
    args = parse_args()
    logger.info(f"Received question: {args.question} with model_id={args.model_id}")
    agent = create_agent(model_name=args.model_id)
    answer = agent.run(args.question)
    if isinstance(answer, str):
        print(f"Got this answer: {answer}")
    else:
        # Streaming answers arrive as an iterable of string chunks; join once
        # instead of quadratic repeated concatenation. (Assumes chunks are
        # str, as the original `result += chunk` loop already required.)
        result = "".join(answer)
        print(f"Got this answer: {result}")
    logger.info("Agent has completed processing the question.")


if __name__ == "__main__":
    main()