# SmolAgents_1 / agent.py
import os
from langchain_groq import ChatGroq
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_core.messages import HumanMessage
from langchain.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from dotenv import load_dotenv
# Load environment variables from .env
load_dotenv()
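# Environment variables this script relies on (the second is an assumption based
# on the Tavily tooling used below):
#   GROQ_API_KEY   - passed explicitly to ChatGroq
#   TAVILY_API_KEY - read by TavilySearchResults / its API wrapper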
# Initialize LLM
def initialize_llm():
"""Initializes the ChatGroq LLM."""
llm = ChatGroq(
temperature=0,
model_name="qwen-qwq-32b",
groq_api_key=os.getenv("GROQ_API_KEY")
)
return llm
# Initialize Tavily Search Tool
def initialize_search_tool():
"""Initializes the TavilySearchResults tool."""
return TavilySearchResults()
# Weather tool
def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
"""
Fetches the current weather information for a given location using Tavily search.
Args:
location (str): The name of the location to search for.
        search_tool (TavilySearchResults, optional): The Tavily search tool to use. A new one is created if None.
Returns:
str: The weather information for the specified location.
"""
if search_tool is None:
search_tool = initialize_search_tool()
query = f"current weather in {location}"
return search_tool.run(query)
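# Illustrative call (not executed): get_weather("Paris") issues the Tavily query
# "current weather in Paris" and returns the stringified search results.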
# Recommendation chain
def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
"""
Initializes the recommendation chain.
Args:
        llm (ChatGroq): The LLM to use.
Returns:
Runnable: A runnable sequence to generate recommendations.
"""
recommendation_prompt = ChatPromptTemplate.from_template("""
You are a helpful assistant that gives weather-based advice.
Given the current weather condition: "{weather_condition}", provide:
1. Clothing or activity recommendations suited for this weather.
2. At least one health tip to stay safe or comfortable in this condition.
Be concise and clear.
""")
return recommendation_prompt | llm
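# Note: the `|` operator composes the prompt and the LLM into an LCEL
# RunnableSequence; invoking it with {"weather_condition": ...} returns the
# model's AIMessage, whose .content is extracted in get_recommendation below.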
def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
"""
Gives activity/clothing recommendations and health tips based on the weather condition.
Args:
weather_condition (str): The current weather condition.
recommendation_chain (Runnable, optional): The recommendation chain to use. Defaults to None.
Returns:
str: Recommendations and health tips for the given weather condition.
"""
if recommendation_chain is None:
llm = initialize_llm()
recommendation_chain = initialize_recommendation_chain(llm)
    # The chain ends at the LLM, so extract the text from the returned AIMessage.
    return recommendation_chain.invoke({"weather_condition": weather_condition}).content
# Math tools
@tool
def add(x: int, y: int) -> int:
"""
Adds two integers.
Args:
x (int): The first integer.
y (int): The second integer.
Returns:
int: The sum of x and y.
"""
return x + y
@tool
def subtract(x: int, y: int) -> int:
"""
Subtracts two integers.
Args:
x (int): The first integer.
y (int): The second integer.
Returns:
int: The difference between x and y.
"""
return x - y
@tool
def multiply(x: int, y: int) -> int:
"""
Multiplies two integers.
Args:
x (int): The first integer.
y (int): The second integer.
Returns:
int: The product of x and y.
"""
return x * y
@tool
def divide(x: int, y: int) -> float:
"""
Divides two numbers.
Args:
x (int): The numerator.
y (int): The denominator.
Returns:
float: The result of the division.
Raises:
ValueError: If y is zero.
"""
if y == 0:
raise ValueError("Cannot divide by zero.")
return x / y
@tool
def square(x: int) -> int:
"""
Calculates the square of a number.
Args:
x (int): The number to square.
Returns:
int: The square of x.
"""
return x * x
@tool
def cube(x: int) -> int:
"""
Calculates the cube of a number.
Args:
x (int): The number to cube.
Returns:
int: The cube of x.
"""
return x * x * x
@tool
def power(x: int, y: int) -> int:
"""
Raises a number to the power of another number.
Args:
x (int): The base number.
y (int): The exponent.
Returns:
int: x raised to the power of y.
"""
return x ** y
@tool
def factorial(n: int) -> int:
"""
Calculates the factorial of a non-negative integer.
Args:
n (int): The non-negative integer.
Returns:
int: The factorial of n.
Raises:
ValueError: If n is negative.
"""
if n < 0:
raise ValueError("Factorial is not defined for negative numbers.")
if n == 0 or n == 1:
return 1
result = 1
for i in range(2, n + 1):
result *= i
return result
@tool
def mean(numbers: list) -> float:
"""
Calculates the mean of a list of numbers.
Args:
numbers (list): A list of numbers.
Returns:
float: The mean of the numbers.
Raises:
ValueError: If the list is empty.
"""
if not numbers:
raise ValueError("The list is empty.")
return sum(numbers) / len(numbers)
@tool
def standard_deviation(numbers: list) -> float:
"""
Calculates the standard deviation of a list of numbers.
Args:
numbers (list): A list of numbers.
Returns:
float: The standard deviation of the numbers.
Raises:
ValueError: If the list is empty.
"""
if not numbers:
raise ValueError("The list is empty.")
    # `mean` above is a @tool-wrapped object, so compute the mean inline instead of calling it directly.
    mean_value = sum(numbers) / len(numbers)
variance = sum((x - mean_value) ** 2 for x in numbers) / len(numbers)
return variance ** 0.5
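# Quick sanity check (hedged sketch, not executed): @tool wraps each function in
# a LangChain tool object, so it is invoked with a dict of arguments rather than
# called directly, e.g.:
#   add.invoke({"x": 2, "y": 3})         # -> 5
#   mean.invoke({"numbers": [1, 2, 3]})  # -> 2.0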
# Build the LangGraph
def build_graph():
"""
Builds the LangGraph with the defined tools and assistant node.
Returns:
        The compiled LangGraph, ready to be invoked with a messages state.
"""
llm = initialize_llm()
search_tool = initialize_search_tool()
recommendation_chain = initialize_recommendation_chain(llm)
@tool
def weather_tool(location: str) -> str:
"""
Fetches the weather for a location.
Args:
location (str): The location to fetch weather for.
Returns:
str: The weather information.
"""
return get_weather(location, search_tool)
@tool
def web_search(query: str) -> str:
"""Search the web for a given query and return the summary.
Args:
query (str): The search query.
"""
        # Reuse the Tavily tool from the enclosing scope and guard against
        # empty or non-list results before indexing into them.
        results = search_tool.run(query)
        if isinstance(results, list) and results:
            return results[0].get("content", str(results[0]))
        return str(results)
@tool
    def wiki_search(query: str) -> str:
"""Search Wikipedia for a given query and return the summary.
Args:
query (str): The search query.
"""
search_docs = WikipediaLoader(query=query, load_max_docs=1).load()
formatted_search_docs = "\n\n----\n\n".join(
[
f'<Document Source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
for doc in search_docs
]
)
return formatted_search_docs
@tool
def recommendation_tool(weather_condition: str) -> str:
"""
Provides recommendations based on weather conditions.
Args:
weather_condition (str): The weather condition.
Returns:
str: The recommendations.
"""
return get_recommendation(weather_condition, recommendation_chain)
tools = [weather_tool, recommendation_tool, wiki_search, web_search,
add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
llm_with_tools = llm.bind_tools(tools)
def assistant(state: MessagesState):
"""
Assistant node in the LangGraph.
Args:
state (MessagesState): The current state of the conversation.
Returns:
dict: The next state of the conversation.
"""
print("Entering assistant node...")
response = llm_with_tools.invoke(state["messages"])
print(f"Assistant says: {response.content}")
return {"messages": [response]}
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
return builder.compile()
if __name__ == "__main__":
graph = build_graph()
question = "How many albums were pulished by Mercedes Sosa?"
messages = [HumanMessage(content=question)]
result = graph.invoke({"messages": messages})
for msg in result["messages"]:
msg.pretty_print()
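# Local run (assuming GROQ_API_KEY and TAVILY_API_KEY are set, e.g. via .env):
#   python agent.py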