File size: 3,839 Bytes
182160d
 
 
 
 
 
 
 
 
 
 
 
3d2ef38
182160d
 
 
 
3d2ef38
182160d
 
 
 
3d2ef38
182160d
 
 
 
3d2ef38
182160d
 
 
 
 
 
3d2ef38
182160d
 
3d2ef38
182160d
 
3d2ef38
182160d
 
 
 
 
 
 
 
 
3d2ef38
182160d
 
 
 
 
 
 
 
 
3d2ef38
182160d
 
 
 
 
 
 
 
 
 
3d2ef38
182160d
 
 
 
 
3d2ef38
7f84964
182160d
 
 
3d2ef38
182160d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_openai import ChatOpenAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    # Name the result before returning so the operation reads explicitly.
    product = a * b
    return product

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    # Name the result before returning so the operation reads explicitly.
    total = a + b
    return total

@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first."""
    # Operand order matters here: the result is a minus b, not b minus a.
    difference = a - b
    return difference

@tool
def divide(a: int, b: int) -> float:
    """Divide first integer by second; error if divisor is zero."""
    # Happy path first: any non-zero divisor goes straight to true division.
    if b != 0:
        return a / b
    # Only a zero divisor reaches this point.
    raise ValueError("Cannot divide by zero.")

@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of dividing first integer by second; error if divisor is zero."""
    # Mirror `divide`: reject a zero divisor with a descriptive ValueError
    # instead of letting a bare ZeroDivisionError escape from `%`.
    if b == 0:
        raise ValueError("Cannot take modulus by zero.")
    return a % b


@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return up to 2 documents."""
    loader = WikipediaLoader(query=query, load_max_docs=2)
    # Render each hit as a source tag line followed by the page text.
    rendered = []
    for doc in loader.load():
        source = doc.metadata["source"]
        rendered.append(f'<Document source="{source}"/>\n{doc.page_content}')
    # Hits are separated by a horizontal-rule style divider.
    return {"wiki_results": "\n\n---\n\n".join(rendered)}

@tool
def web_search(query: str) -> dict:
    """Perform a web search (via Tavily) and return up to 3 results."""
    # `invoke` takes the tool input as its first positional argument; passing
    # `query=` as a keyword leaves that argument unfilled and raises TypeError.
    results = TavilySearchResults(max_results=3).invoke({"query": query})

    # The Tavily tool can hand back a plain string (an error description)
    # instead of a result list; pass it through rather than crash on it.
    if isinstance(results, str):
        return {"web_results": results}

    # Each result is a plain dict (not a Document), so read its "url" and
    # "content" keys; `.get` keeps a partially filled result from raising.
    sections = []
    for result in results:
        url = result.get("url", "")
        content = result.get("content", "")
        sections.append(f'<Document source="{url}"/>\n{content}')
    return {"web_results": "\n\n---\n\n".join(sections)}

@tool
def arvix_search(query: str) -> dict:
    """Search arXiv for a query and return up to 3 paper excerpts."""
    docs = ArxivLoader(query=query, load_max_docs=3).load()

    sections = []
    for doc in docs:
        meta = doc.metadata
        # ArxivLoader's documented metadata keys (Published, Title, Authors,
        # Summary) do not include "source", so indexing it directly raises
        # KeyError. Prefer "source"/"entry_id" when present, else the title.
        label = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        # Only the first 1000 characters of each paper are returned.
        sections.append(f'<Document source="{label}"/>\n{doc.page_content[:1000]}')
    return {"arvix_results": "\n\n---\n\n".join(sections)}

# Credentials are read from the environment once at import time; each value
# is None when its variable is not set.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
HF_SPACE_TOKEN = os.environ.get("HF_SPACE_TOKEN")


# Every tool made available to the agent: arithmetic helpers first, then the
# search tools. The same list is bound to the LLM and given to the ToolNode.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
]


# Load the system prompt text at import time. The path is relative to the
# current working directory, so "prompt.txt" must exist there.
with open("prompt.txt", encoding="utf-8") as f:  # text-read is the default mode
    system_prompt = f.read()

# Wrap the raw prompt text as a system message.
sys_msg = SystemMessage(content=system_prompt)


def build_graph(provider: str = "openai"):
    """Build the LangGraph agent with chosen LLM (default: OpenAI).

    The graph has two nodes: ``assistant`` (the tool-bound LLM) and ``tools``
    (a ``ToolNode`` over the module-level ``tools`` list). Execution starts at
    ``assistant``; ``tools_condition`` routes its output, and every tool
    result is fed back to ``assistant``.

    Args:
        provider: Which chat model backend to use, ``"openai"`` or
            ``"huggingface"``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "openai":
        llm = ChatOpenAI(
            model_name="o4-mini-2025-04-16",
            openai_api_key=OPENAI_API_KEY,
            # no temperature override here
        )
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # The endpoint address field is `endpoint_url`; `url` is not
                # a HuggingFaceEndpoint parameter and left no model configured.
                # NOTE(review): confirm this model path is still served.
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            ),
            temperature=0,
        )
    else:
        raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Run the tool-bound LLM over the conversation so far."""
        # Prepend the system prompt on every call. It is not stored in the
        # graph state, so without this the loaded prompt never reaches the
        # model.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # Conditional routing out of the assistant node is delegated to
    # `tools_condition`.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()

if __name__ == "__main__":
    # Smoke run: build the default agent, ask one question, and print every
    # message in the resulting conversation.
    agent = build_graph()
    question = HumanMessage(content="What’s the capital of France?")
    final_state = agent.invoke({"messages": [question]})
    for message in final_state["messages"]:
        message.pretty_print()