# agent-gaia/agent.py
import os

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_groq.chat_models import ChatGroq
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from tools import (
    add, subtract, multiply, div, modulus, power,
    wikipedia_search, search_web, arxiv_search,
    save_and_read_file, download_file_from_url, extract_text_from_image,
    pdf_loader,
)
from retriever import get_retriever_tool
load_dotenv(dotenv_path=".env")

# Configuration
SYSTEM_PROMPT_PATH = "system_prompt.txt"
DEFAULT_PROVIDER = "groq"
MODEL_NAME = "llama3-70b-8192"

def load_system_prompt(path: str = SYSTEM_PROMPT_PATH) -> str:
    if not os.path.exists(path):
        raise FileNotFoundError(f"System prompt file not found at: {path}")
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

system_prompt = load_system_prompt()
sys_msg = SystemMessage(content=system_prompt)
# Load tools
vector_store, vector_retriever, retriever_tool = get_retriever_tool()

TOOLS = [
    # Math
    add, subtract, multiply, div, modulus, power,
    # Document search
    wikipedia_search, search_web, arxiv_search,
    # File processing
    save_and_read_file, download_file_from_url, extract_text_from_image,
    pdf_loader,
    # Retriever
    retriever_tool,
]
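# Each entry above must be a LangChain tool object. tools.py is not shown here;
# as a hedged sketch, a definition in it presumably looks something like this
# (the @tool decorator is real LangChain API, the body is illustrative):
#
#     from langchain_core.tools import tool
#
#     @tool
#     def add(a: float, b: float) -> float:
#         """Add two numbers."""
#         return a + b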
def get_llm(provider: str = DEFAULT_PROVIDER):
    if provider == "groq":
        return ChatGroq(model=MODEL_NAME, temperature=0)
    elif provider == "huggingface":
        raise NotImplementedError("HuggingFace support not yet implemented.")
    else:
        raise ValueError("Invalid LLM provider. Choose 'groq' or 'huggingface'.")
def build_graph(provider: str = DEFAULT_PROVIDER):
    """Builds the LangGraph agent graph."""
    llm = get_llm(provider)
    # Bind the tools so the LLM can emit tool calls
    llm_with_tools = llm.bind_tools(TOOLS)

    def assistant(state: MessagesState):
        # Wrap the single response in a list for the add_messages reducer
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        # Look up solved Q&A pairs similar to the original user question
        query = state["messages"][0].content
        similar_qas = vector_store.similarity_search(query)
        if similar_qas:
            # Pass only the closest match along as few-shot context
            reference = similar_qas[0].page_content
            example_qa = HumanMessage(
                content=f"Here is a similar question and answer for reference:\n\n{reference}"
            )
            return {"messages": [sys_msg] + state["messages"] + [example_qa]}
        return {"messages": [sys_msg] + state["messages"]}

    # Graph
    builder = StateGraph(MessagesState)

    # Nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(TOOLS))

    # Edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        # Route to "tools" when the last AI message contains tool calls, else END
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    return builder.compile()
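# Optional sanity check: the compiled graph topology can be printed, e.g.
# (ASCII rendering assumes the optional grandalf dependency is installed):
#
#     print(build_graph().get_graph().draw_ascii())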
if __name__ == "__main__":
    import json
    import random

    # Pick a random question from the local metadata.jsonl dataset
    with open("metadata.jsonl") as dataset_file:
        QAs = [json.loads(line) for line in dataset_file]
    question = random.choice(QAs)["Question"]

    graph = build_graph()
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for m in result["messages"]:
        m.pretty_print()