Spaces:
Runtime error
Runtime error
import os | |
from langchain_groq import ChatGroq | |
from langchain.prompts import PromptTemplate | |
from langgraph.graph import START, StateGraph, MessagesState | |
from langgraph.prebuilt import ToolNode, tools_condition | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
from langchain_community.document_loaders import WikipediaLoader | |
from langchain_core.messages import HumanMessage | |
from langchain.tools import tool | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables import Runnable | |
from dotenv import load_dotenv | |
# Read variables such as GROQ_API_KEY (used by initialize_llm) from a local
# .env file. With override=False, values already present in the process
# environment are kept rather than replaced by the file's values.
load_dotenv(override=False)
# Initialize LLM
def initialize_llm(model_name: str = "qwen-qwq-32b", temperature: float = 0):
    """Initialize the ChatGroq chat model used by the agent.

    Args:
        model_name (str): Groq model identifier. Defaults to "qwen-qwq-32b",
            the model this module was written against.
        temperature (float): Sampling temperature passed to ChatGroq.
            Defaults to 0.

    Returns:
        ChatGroq: The configured chat model. The API key is taken from the
        GROQ_API_KEY environment variable (None when it is unset).
    """
    # NOTE(review): hosted model ids can be retired by the provider; if calls
    # start failing with a model-not-found error, pass a current id here.
    return ChatGroq(
        temperature=temperature,
        model_name=model_name,
        groq_api_key=os.getenv("GROQ_API_KEY"),
    )
# Initialize Tavily Search Tool
def initialize_search_tool():
    """Build a TavilySearchResults tool using its default configuration."""
    tavily = TavilySearchResults()
    return tavily
# Weather tool
def get_weather(location: str, search_tool: TavilySearchResults = None) -> str:
    """
    Look up the current weather for a place by running a Tavily web search.

    Args:
        location (str): Place name inserted into the search query.
        search_tool (TavilySearchResults, optional): Tool used to run the
            query. When omitted, one is created with initialize_search_tool().

    Returns:
        str: Whatever search_tool.run returns for the query
        "current weather in <location>". NOTE(review): for TavilySearchResults
        this is presumably a list of result dicts rather than a plain string.
        Confirm before relying on the annotation.
    """
    tool_to_use = search_tool if search_tool is not None else initialize_search_tool()
    return tool_to_use.run(f"current weather in {location}")
# Recommendation chain
def initialize_recommendation_chain(llm: ChatGroq) -> Runnable:
    """
    Build the prompt-to-LLM pipeline that turns a weather description into advice.

    Args:
        llm (ChatGroq): Chat model that receives the formatted prompt.

    Returns:
        Runnable: The sequence `prompt | llm`. Invoke it with a mapping of the
        form {"weather_condition": ...}. Its output is whatever `llm` returns
        (presumably a chat message object, not a plain string).
    """
    template_text = """
You are a helpful assistant that gives weather-based advice.
Given the current weather condition: "{weather_condition}", provide:
1. Clothing or activity recommendations suited for this weather.
2. At least one health tip to stay safe or comfortable in this condition.
Be concise and clear.
"""
    prompt = ChatPromptTemplate.from_template(template_text)
    chain = prompt | llm
    return chain
def get_recommendation(weather_condition: str, recommendation_chain: Runnable = None) -> str:
    """
    Give activity/clothing recommendations and a health tip for a weather condition.

    Args:
        weather_condition (str): Description of the current weather.
        recommendation_chain (Runnable, optional): Chain to invoke. When omitted,
            one is built from a fresh LLM via initialize_llm() and
            initialize_recommendation_chain().

    Returns:
        str: The recommendation text. If the chain returns an object with a
        `content` attribute (as a chain ending in a chat model does), that
        content is returned. A chain that already yields a string is passed
        through unchanged.
    """
    if recommendation_chain is None:
        recommendation_chain = initialize_recommendation_chain(initialize_llm())
    response = recommendation_chain.invoke({"weather_condition": weather_condition})
    # `prompt | llm` yields a message object rather than a str; unwrap it so the
    # declared return type (and the tool output the agent receives) is text.
    content = getattr(response, "content", response)
    return content if isinstance(content, str) else str(content)
# Math tools
def add(x: int, y: int) -> int:
    """
    Return the sum of two integers.

    Args:
        x (int): The left operand.
        y (int): The right operand.

    Returns:
        int: x plus y.
    """
    total = x + y
    return total
def subtract(x: int, y: int) -> int:
    """
    Subtract the second integer from the first.

    Args:
        x (int): The value to subtract from.
        y (int): The value to subtract.

    Returns:
        int: x minus y.
    """
    difference = x - y
    return difference
def multiply(x: int, y: int) -> int:
    """
    Multiply two integers together.

    Args:
        x (int): The first factor.
        y (int): The second factor.

    Returns:
        int: x times y.
    """
    product = x * y
    return product
def divide(x: int, y: int) -> float:
    """
    Divide one number by another using true division.

    Args:
        x (int): The dividend (numerator).
        y (int): The divisor (denominator), which must be non-zero.

    Returns:
        float: x divided by y.

    Raises:
        ValueError: When y is zero.
    """
    if y != 0:
        return x / y
    raise ValueError("Cannot divide by zero.")
def square(x: int) -> int:
    """
    Return a number multiplied by itself.

    Args:
        x (int): The value to square.

    Returns:
        int: x times x.
    """
    squared = x * x
    return squared
def cube(x: int) -> int:
    """
    Return a number raised to the third power.

    Args:
        x (int): The value to cube.

    Returns:
        int: x times x times x.
    """
    squared = x * x
    return squared * x
def power(x: int, y: int) -> float:
    """
    Raises a number to the power of another number.

    Args:
        x (int): The base number.
        y (int): The exponent.

    Returns:
        float: x raised to the power of y. For non-negative integer y this is
        an exact int; a negative y produces a float (e.g. 2 and -1 give 0.5).

    Raises:
        ZeroDivisionError: If x is 0 and y is negative.
    """
    # The annotation is float (which also admits int per the numeric tower)
    # because `**` with a negative integer exponent returns a float.
    return x ** y
def factorial(n: int) -> int:
    """
    Calculates the factorial of a non-negative integer.

    Args:
        n (int): The non-negative integer.

    Returns:
        int: The factorial of n (0 and 1 both give 1).

    Raises:
        ValueError: If n is negative.
    """
    import math  # local import: math is not among this module's top-level imports

    if n < 0:
        raise ValueError("Factorial is not defined for negative numbers.")
    # Use the standard-library implementation instead of a hand-rolled loop;
    # it already handles the 0 and 1 base cases.
    return math.factorial(n)
def mean(numbers: list) -> float:
    """
    Compute the arithmetic mean (average) of a list of numbers.

    Args:
        numbers (list): The values to average, which must be non-empty.

    Returns:
        float: The sum of the values divided by how many there are.

    Raises:
        ValueError: When numbers is empty.
    """
    if not numbers:
        raise ValueError("The list is empty.")
    total = sum(numbers)
    return total / len(numbers)
def standard_deviation(numbers: list) -> float:
    """
    Compute the population standard deviation of a list of numbers.

    Args:
        numbers (list): The sample values, which must be non-empty.

    Returns:
        float: The square root of the mean squared deviation from the mean
        (the variance divides by len(numbers), not len(numbers) - 1).

    Raises:
        ValueError: When numbers is empty.
    """
    if not numbers:
        raise ValueError("The list is empty.")
    count = len(numbers)
    centre = sum(numbers) / count
    squared_deviations = sum((value - centre) ** 2 for value in numbers)
    return (squared_deviations / count) ** 0.5
# Build the LangGraph
def build_graph():
    """
    Build and compile the tool-calling agent graph.

    The graph has two nodes. "assistant" runs the tool-bound LLM over the
    conversation. "tools" executes any tool calls the assistant requested.
    `tools_condition` decides after each assistant turn whether to route to
    "tools"; tool results always flow back to "assistant".

    Returns:
        The compiled graph produced by StateGraph.compile(). Run it with
        .invoke({"messages": [...]}).
    """
    llm = initialize_llm()
    search_tool = initialize_search_tool()
    recommendation_chain = initialize_recommendation_chain(llm)

    def weather_tool(location: str) -> str:
        """
        Fetches the weather for a location.

        Args:
            location (str): The location to fetch weather for.

        Returns:
            str: The weather information.
        """
        return get_weather(location, search_tool)

    def web_search(query: str) -> str:
        """Search the web for a given query and return the top result's content.

        Args:
            query (str): The search query.
        """
        # Reuse the graph's shared search tool instead of building a new
        # client on every call.
        results = search_tool.run(query)
        # NOTE(review): TavilySearchResults is believed to return an error
        # string (rather than raise) on API failure; an empty list means no
        # hits. Indexing blindly would crash the tool call in either case.
        if isinstance(results, str):
            return results
        if not results:
            return "No web search results found."
        top_hit = results[0]
        if isinstance(top_hit, dict) and "content" in top_hit:
            return top_hit["content"]
        return str(top_hit)

    def wiki_search(query: str) -> str:
        """Search Wikipedia for a given query and return the summary.

        Args:
            query (str): The search query.
        """
        search_docs = WikipediaLoader(query=query, load_max_docs=1).load()
        if not search_docs:
            # Give the model an explicit signal instead of an empty tool result.
            return "No Wikipedia results found."
        return "\n\n----\n\n".join(
            f'<Document Source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
            for doc in search_docs
        )

    def recommendation_tool(weather_condition: str) -> str:
        """
        Provides recommendations based on weather conditions.

        Args:
            weather_condition (str): The weather condition.

        Returns:
            str: The recommendations.
        """
        return get_recommendation(weather_condition, recommendation_chain)

    tools = [
        weather_tool, recommendation_tool, wiki_search, web_search,
        add, subtract, multiply, divide, square, cube, power, factorial,
        mean, standard_deviation,
    ]
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """
        Assistant node in the LangGraph.

        Args:
            state (MessagesState): The current state of the conversation.

        Returns:
            dict: A partial state update holding the LLM's reply message.
        """
        print("Entering assistant node...")
        response = llm_with_tools.invoke(state["messages"])
        print(f"Assistant says: {response.content}")
        return {"messages": [response]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
if __name__ == "__main__":
    # Demo run: ask one question and print every message in the final state.
    graph = build_graph()
    question = "How many albums were published by Mercedes Sosa?"
    messages = [HumanMessage(content=question)]
    result = graph.invoke({"messages": messages})
    for msg in result["messages"]:
        msg.pretty_print()