# SmolAgentsv2 / run.py
# Provenance: downloaded from a Hugging Face Space file page
# (author: CultriX, commit c62316d verified, 11.7 kB).
# The page-header lines were commented out so this file is valid Python.
import argparse
import os
import logging
from dotenv import load_dotenv
from huggingface_hub import login
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
ArchiveSearchTool,
FinderTool,
FindNextTool,
PageDownTool,
PageUpTool,
SimpleTextBrowser,
VisitTool,
)
from scripts.visual_qa import visualizer
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
LiteLLMModel,
ToolCallingAgent,
)
# --- Logging -----------------------------------------------------------------
# Shared "smolagents" logger with a timestamped console format.
logger = logging.getLogger("smolagents")
logger.setLevel(logging.INFO)
log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
# Bug fix: the formatter was created but never attached to a handler, so the
# configured format was never applied to any output.  Attach it through a
# StreamHandler; the guard avoids stacking duplicate handlers on re-import.
if not logger.handlers:
    _console_handler = logging.StreamHandler()
    _console_handler.setFormatter(log_formatter)
    logger.addHandler(_console_handler)

# --- Environment -------------------------------------------------------------
# Load variables from a local .env file (overriding already-set ones), then
# log in to the Hugging Face Hub when a token is available.
load_dotenv(override=True)
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(hf_token)
    logger.info("Logged into Hugging Face Hub.")
else:
    logger.warning("HF_TOKEN not found. Proceeding without authentication.")
# Modules the manager CodeAgent is allowed to import in generated code.
AUTHORIZED_IMPORTS = [
    "requests", "zipfile", "os", "pandas", "numpy", "sympy", "json", "bs4",
    "pubchempy", "xml", "yahoo_finance", "Bio", "sklearn", "scipy", "pydub",
    "io", "PIL", "chess", "PyPDF2", "pptx", "torch", "datetime", "fractions", "csv", "string", "secrets",
]

# Desktop Chrome/Edge user agent sent with every browser request.
USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
)

# Keyword arguments for SimpleTextBrowser (scripts.text_web_browser).
BROWSER_CONFIG = {
    "viewport_size": 5120,
    "downloads_folder": "downloads_folder",
    "request_kwargs": {
        "headers": {"User-Agent": USER_AGENT},
        "timeout": 150,  # presumably seconds per HTTP request — confirm in text_web_browser
        "max_retries": 2,
    },
    "serpapi_key": os.getenv("SERPAPI_API_KEY"),
}
# Ensure the downloads directory exists before the browser is constructed.
# (Idiom fix: the original wrapped the path in f"./{...}" — the f-string and
# "./" prefix were redundant; the resulting directory is the same.)
os.makedirs(BROWSER_CONFIG["downloads_folder"], exist_ok=True)
# Role-name conversions applied when translating smolagents messages for LiteLLM.
custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}


def _make_model_config(model_id, api_key_env=None, max_completion_tokens=8192, **extra_params):
    """Build one MODEL_CONFIGS entry in the shape LiteLLMModel(**params) expects.

    Args:
        model_id: LiteLLM model identifier, e.g. "openai/gpt-4o".
        api_key_env: Name of the environment variable holding the API key.
            When None (the default "o1" entry), no "api_key" key is emitted,
            matching the original hand-written config.
        max_completion_tokens: Completion-token limit for the model.
        **extra_params: Additional litellm params, e.g. reasoning_effort="high".

    Returns:
        dict: {"litellm_params": {...}} with custom_role_conversions included.
    """
    litellm_params = {
        "model_id": model_id,
        "custom_role_conversions": custom_role_conversions,
        "max_completion_tokens": max_completion_tokens,
        **extra_params,
    }
    if api_key_env is not None:
        litellm_params["api_key"] = os.getenv(api_key_env)
    return {"litellm_params": litellm_params}


# Model configurations (custom models intact), deduplicated via _make_model_config.
# Contents are identical to the previous hand-written literals.
MODEL_CONFIGS = {
    # OPENAI MODELS
    "gpt-3.5-turbo": _make_model_config("openai/gpt-3.5-turbo", "OPENAI_API_KEY"),
    "gpt-3.5-turbo-16k": _make_model_config(
        "openai/gpt-3.5-turbo-16k", "OPENAI_API_KEY", max_completion_tokens=16384
    ),
    "gpt-4o-mini": _make_model_config("openai/gpt-4o-mini", "OPENAI_API_KEY"),
    "chatgpt-4o-latest": _make_model_config("openai/chatgpt-4o-latest", "OPENAI_API_KEY"),
    "gpt-4-turbo": _make_model_config("openai/gpt-4-turbo", "OPENAI_API_KEY"),
    "gpt-4o": _make_model_config("openai/gpt-4o", "OPENAI_API_KEY"),
    "o1-mini": _make_model_config("openai/o1-mini", "OPENAI_API_KEY", reasoning_effort="high"),
    "o1-preview": _make_model_config("openai/o1-preview", "OPENAI_API_KEY", reasoning_effort="high"),
    # HUGGINGFACE MODELS
    "hf-llama-3.1-8B-instruct": _make_model_config(
        "huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct", "HF_TOKEN"
    ),
    "hf-DeepSeek-R1-Distill-Qwen-32B": _make_model_config(
        "huggingface/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", "HF_TOKEN"
    ),
    "hf-Qwen2.5-Coder-32B-Instruct": _make_model_config(
        "huggingface/Qwen/Qwen2.5-Coder-32B-Instruct", "HF_TOKEN"
    ),
    "hf-QwQ-32B-Preview": _make_model_config("huggingface/Qwen/QwQ-32B-Preview", "HF_TOKEN"),
    "hf-Llama-3.1-70B-Instruct": _make_model_config(
        "huggingface/meta-llama/Llama-3.1-70B-Instruct", "HF_TOKEN"
    ),
    # GROQ MODELS
    "groq-llama3-8b-8192": _make_model_config("groq/llama3-8b-8192", "GROQ_API_KEY"),
    "groq-llama3-70b-8192": _make_model_config("groq/llama3-70b-8192", "GROQ_API_KEY"),
    "groq-mixtral-8x7b-32768": _make_model_config(
        "groq/mixtral-8x7b-32768", "GROQ_API_KEY", max_completion_tokens=32768
    ),
    # GEMINI MODELS
    "gemini-pro": _make_model_config("gemini/gemini-pro", "GEMINI_API_KEY"),
    "gemini-1.5-pro": _make_model_config("gemini/gemini-1.5-pro", "GEMINI_API_KEY"),
    "gemini-1.5-flash": _make_model_config("gemini/gemini-1.5-flash", "GEMINI_API_KEY"),
    "gemini-pro-vision": _make_model_config("gemini/gemini-pro-vision", "GEMINI_API_KEY"),
    "gemini-2.0-flash": _make_model_config("gemini/gemini-2.0-flash", "GEMINI_API_KEY"),
    "gemini-2.0-flash-thinking-exp-01-21": _make_model_config(
        "gemini/gemini-2.0-flash-thinking-exp-01-21", "GEMINI_API_KEY"
    ),
    # Default o1 model (intentionally carries no "api_key" key)
    "o1": _make_model_config("o1", reasoning_effort="high"),
}
def parse_args(argv=None):
    """Parse command-line arguments for the search-agent runner.

    Args:
        argv: Optional list of argument strings.  Defaults to None, in which
            case argparse reads sys.argv as before (existing callers are
            unaffected); passing a list makes the function testable without
            touching global state.

    Returns:
        argparse.Namespace with ``question`` (str, positional) and
        ``model_id`` (str, defaults to "o1").
    """
    parser = argparse.ArgumentParser(
        description="Run the search agent to answer questions using web browsing tools."
    )
    parser.add_argument(
        "question",
        type=str,
        help="Example: 'How many studio albums did Mercedes Sosa release before 2007?'",
    )
    parser.add_argument(
        "--model-id",
        type=str,
        default="o1",
        help="Model identifier (default: o1)",
    )
    return parser.parse_args(argv)
def create_agent(model_name="o1"):
    """Build the manager CodeAgent wired to a web-browsing search sub-agent.

    Args:
        model_name: Key into MODEL_CONFIGS selecting the LiteLLM model
            (default "o1").

    Returns:
        CodeAgent: the manager agent, whose tools are the visualizer and a
        TextInspectorTool, and which manages the "search_agent"
        ToolCallingAgent equipped with the web-browsing tools.

    Raises:
        ValueError: if model_name is not a key of MODEL_CONFIGS.
    """
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Model '{model_name}' is not a valid model. Available models are: {list(MODEL_CONFIGS.keys())}")
    # Bug fix: work on a copy — the original called setdefault directly on the
    # dict stored in MODEL_CONFIGS, silently mutating module-level config.
    model_params = dict(MODEL_CONFIGS[model_name]["litellm_params"])
    model_params.setdefault("custom_role_conversions", custom_role_conversions)
    model_params.setdefault("max_completion_tokens", 8192)
    # (Optional: adjust parameters here to lower temperature for more factual answers.)
    model = LiteLLMModel(**model_params)
    logger.info(f"Initialized LiteLLMModel with model_name={model_name}")

    # NOTE(review): presumably the max number of characters TextInspectorTool
    # reads from a document — confirm against scripts.text_inspector_tool.
    text_limit = 100000
    browser = SimpleTextBrowser(**BROWSER_CONFIG)
    logger.info("Initialized SimpleTextBrowser with custom configuration.")

    # Tools handed to the browsing sub-agent; all share the one browser instance.
    WEB_TOOLS = [
        DuckDuckGoSearchTool(),
        VisitTool(browser),
        PageUpTool(browser),
        PageDownTool(browser),
        FinderTool(browser),
        FindNextTool(browser),
        ArchiveSearchTool(browser),
        TextInspectorTool(model, text_limit),
    ]
    logger.info("Initialized web tools for ToolCallingAgent.")

    text_webbrowser_agent = ToolCallingAgent(
        model=model,
        tools=WEB_TOOLS,
        max_steps=10,
        verbosity_level=2,
        planning_interval=4,
        name="search_agent",
        description=(
            "A team member that will search the internet to answer your question. "
            "Ask all questions that require browsing the web using complete sentences. "
            "Provide as much context as possible, especially if searching within a specific timeframe."
        ),
        provide_run_summary=True,
    )
    logger.info("Initialized ToolCallingAgent.")

    # Manager agent: writes/executes code and delegates browsing to search_agent.
    manager_agent = CodeAgent(
        model=model,
        tools=[visualizer, TextInspectorTool(model, text_limit)],
        max_steps=12,
        verbosity_level=2,
        additional_authorized_imports=AUTHORIZED_IMPORTS,
        planning_interval=4,
        managed_agents=[text_webbrowser_agent],
    )
    logger.info("Initialized Manager CodeAgent.")
    return manager_agent
def main():
    """CLI entry point: parse arguments, build the agent, print its answer."""
    args = parse_args()
    logger.info(f"Received question: {args.question} with model_id={args.model_id}")
    agent = create_agent(model_name=args.model_id)
    answer = agent.run(args.question)
    # A non-string answer is treated as an iterable of text chunks.
    final_text = answer if isinstance(answer, str) else "".join(answer)
    print(f"Got this answer: {final_text}")
    logger.info("Agent has completed processing the question.")


if __name__ == "__main__":
    main()