File size: 3,945 Bytes
182160d 8a158c4 182160d 3d2ef38 182160d 3d2ef38 182160d 3d2ef38 182160d 3d2ef38 182160d 3d2ef38 182160d 3d2ef38 182160d 3d2ef38 182160d 3d2ef38 182160d 3d2ef38 182160d 2345f27 182160d 28d4a17 182160d 3d2ef38 182160d 3d2ef38 7f84964 182160d 3d2ef38 3ea980f 2345f27 3ea980f 5cce4fd 2345f27 247a0af 2345f27 182160d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import os
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_openai import ChatOpenAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.tools import tool
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    # The docstring above is also the tool description the LLM sees,
    # so it is intentionally kept short and stable.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    # Docstring doubles as the tool description exposed to the model.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first."""
    # Order matters: the result is a minus b.
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """Divide first integer by second; error if divisor is zero."""
    # Guard clause: reject a zero divisor with a descriptive ValueError
    # rather than letting Python raise ZeroDivisionError.
    if not b:
        raise ValueError("Cannot divide by zero.")
    quotient = a / b  # true division, so the result is a float
    return quotient
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of dividing first integer by second."""
    # Unlike `divide`, there is no explicit zero check: a zero divisor
    # surfaces as Python's ZeroDivisionError.
    remainder = a % b
    return remainder
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return up to 2 documents."""
    pages = WikipediaLoader(query=query, load_max_docs=2).load()
    # Each page becomes a pseudo-XML header carrying its `source` metadata,
    # followed by the full page text; pages are separated by a rule.
    chunks = []
    for page in pages:
        header = f'<Document source="{page.metadata["source"]}"/>'
        chunks.append(f"{header}\n{page.page_content}")
    return {"wiki_results": "\n\n---\n\n".join(chunks)}
@tool
def web_search(query: str) -> dict:
    """Perform a web search (via Tavily) and return up to 3 results."""
    # The query must be passed as the positional `input` argument of
    # `invoke`; `invoke(query=...)` fails with a missing-argument TypeError.
    results = TavilySearchResults(max_results=3).invoke(query)
    if isinstance(results, str):
        # Some versions of the Tavily tool report failures as a string
        # instead of raising; pass that message through unchanged.
        return {"web_results": results}
    # Tavily results are plain dicts (e.g. "url", "content"), not Document
    # objects, so there is no .metadata / .page_content to read.
    formatted = "\n\n---\n\n".join(
        f'<Document source="{r.get("url", "")}"/>\n{r.get("content", "")}'
        for r in results
    )
    return {"web_results": formatted}
@tool
def arvix_search(query: str) -> dict:
    """Search arXiv for a query and return up to 3 paper excerpts."""
    # NOTE: "arvix" is a misspelling of "arxiv"; the tool name and result
    # key are kept as-is because the LLM and callers refer to them.
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    parts = []
    for d in docs:
        meta = d.metadata
        # By default ArxivLoader metadata has "Published", "Title", "Authors"
        # and "Summary" but no "source" key, so indexing it directly raised
        # KeyError. Prefer "source" if present, then "entry_id", then title.
        source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        # Only the first 1000 characters of each paper are kept.
        parts.append(f'<Document source="{source}"/>\n{d.page_content[:1000]}')
    return {"arvix_results": "\n\n---\n\n".join(parts)}
# Credentials come from the environment; each is None when unset.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
API_KEY = GEMINI_API_KEY  # alias of GEMINI_API_KEY (same environment variable)
HF_SPACE_TOKEN = os.getenv("HF_SPACE_TOKEN")
# Every tool the agent may call: bound to the LLM and executed by the
# ToolNode inside build_graph.
tools = [
    # arithmetic
    multiply,
    add,
    subtract,
    divide,
    modulus,
    # retrieval
    wiki_search,
    web_search,
    arvix_search,
]
# Load the agent's system prompt. "prompt.txt" is looked up in the current
# working directory first (the original behaviour); if it is not there,
# fall back to this module's own directory so the agent still works when
# imported from a different working directory.
_PROMPT_PATH = "prompt.txt"
if not os.path.isfile(_PROMPT_PATH):
    _PROMPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompt.txt")
with open(_PROMPT_PATH, "r", encoding="utf-8") as f:
    system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)
def build_graph(provider: str = "gemini"):
    """Build the LangGraph agent with chosen LLM (default: Gemini).

    Args:
        provider: Chat-model backend, either "gemini" or "huggingface".

    Returns:
        The compiled graph. It has an "assistant" node that calls the
        tool-bound LLM and a "tools" node that runs requested tool calls
        and routes back to the assistant. Routing out of the assistant is
        decided by `tools_condition`.

    Raises:
        ValueError: If `provider` is not "gemini" or "huggingface".
    """
    if provider == "gemini":
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-pro-preview-05-06",
            temperature=1.0,
            max_retries=2,
            api_key=GEMINI_API_KEY,
        )
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # HuggingFaceEndpoint takes the endpoint as `endpoint_url`;
                # `url` is not one of its fields.
                # NOTE(review): confirm this model repo/URL actually exists.
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            ),
            temperature=0,
        )
    else:
        # The message previously offered 'openai', which is not supported here.
        raise ValueError("Invalid provider. Choose 'gemini' or 'huggingface'.")

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Ask the tool-bound LLM for the next message in the conversation."""
        messages = state["messages"]
        # Send the module-level system prompt (sys_msg) with every call,
        # unless the caller already put a system message first. It is only
        # prepended to the request; it is not written into the graph state.
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [sys_msg, *messages]
        return {"messages": [llm_with_tools.invoke(messages)]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # After the assistant, go to "tools" if it requested tool calls,
    # otherwise finish; tool results always flow back to the assistant.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
if __name__ == "__main__":
    # Manual smoke run: build the default (Gemini) agent, ask one question,
    # and pretty-print every message in the resulting conversation.
    graph = build_graph()
    result = graph.invoke({"messages": [HumanMessage(content="What’s the capital of France?")]})
    for message in result["messages"]:
        message.pretty_print()