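"""LangGraph tool-calling agent.

Binds an LLM (OpenAI o4-mini or a Hugging Face endpoint) to math and search
tools (Wikipedia, Tavily web search, arXiv) in a loop between an assistant
node and a ToolNode.
"""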
import os
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_openai import ChatOpenAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool

# ───────────────────────────────────────────────────────────────────────────────
# 1) Define tools
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide a by b; raises ValueError on division by zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Return a modulo b."""
    return a % b

@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia and return up to 2 matching documents."""
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted = "\n\n---\n\n".join(
        f'<Document source="{d.metadata["source"]}"/>\n{d.page_content}'
        for d in docs
    )
    return {"wiki_results": formatted}

@tool
def web_search(query: str) -> dict:
    """Search the web via Tavily and return up to 3 results."""
    # TavilySearchResults.invoke takes the query as positional input and
    # returns a list of dicts with "url" and "content" keys, not Documents.
    results = TavilySearchResults(max_results=3).invoke(query)
    formatted = "\n\n---\n\n".join(
        f'<Document source="{r["url"]}"/>\n{r["content"]}'
        for r in results
    )
    return {"web_results": formatted}

@tool
def arxiv_search(query: str) -> dict:
    """Search arXiv and return up to 3 papers (first 1,000 chars each)."""
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    # ArxivLoader metadata exposes "Title" rather than a "source" URL;
    # .get() guards against missing keys.
    formatted = "\n\n---\n\n".join(
        f'<Document source="{d.metadata.get("Title", "")}"/>\n{d.page_content[:1000]}'
        for d in docs
    )
    return {"arxiv_results": formatted}
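
# Illustrative smoke test (not used by the agent): each @tool wrapper above is
# a StructuredTool and can be invoked directly with a dict of arguments, e.g.
#   multiply.invoke({"a": 6, "b": 7})          # -> 42
#   wiki_search.invoke({"query": "Paris"})     # -> {"wiki_results": "..."}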


OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
HF_SPACE_TOKEN = os.getenv("HF_SPACE_TOKEN")

# ───────────────────────────────────────────────────────────────────────────────
# 2) Assemble tool list
tools = [
    multiply, add, subtract, divide, modulus,
    wiki_search, web_search, arxiv_search,
]

# ───────────────────────────────────────────────────────────────────────────────
# 3) Load your system prompt
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
sys_msg = SystemMessage(content=system_prompt)
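# sys_msg is prepended to the message history inside the assistant node below.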

# ───────────────────────────────────────────────────────────────────────────────
# 4) Build the agent graph
def build_graph(provider: str = "openai"):
    """Build the LangGraph agent with chosen LLM (default: OpenAI)."""
    if provider == "openai":
        llm = ChatOpenAI(
            model="o4-mini-2025-04-16",
            api_key=OPENAI_API_KEY,
            # no temperature override: reasoning models only accept the default
        )
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="meta-llama/Llama-2-7b-chat-hf",  # canonical Hub id
                huggingfacehub_api_token=HF_SPACE_TOKEN,
                temperature=0.01,  # near-greedy; the HF API rejects 0
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        # Prepend the system prompt so every model call sees it.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()

if __name__ == "__main__":
    graph = build_graph()
    msgs = graph.invoke({"messages": [HumanMessage(content="What’s the capital of France?")]})
    for m in msgs["messages"]:
        m.pretty_print()
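
    # To try the Hugging Face backend instead (assumes HF_SPACE_TOKEN is set):
    # graph = build_graph(provider="huggingface")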