"""LangGraph agent: retriever -> tool-calling assistant pipeline.

Builds a graph that (1) prepends the system prompt and a similar
reference Q&A fetched from a vector store, then (2) runs a Groq LLM
bound to math / search / file tools, looping through a ToolNode until
no further tool calls are emitted.
"""

import os
from dotenv import load_dotenv
from typing import TypedDict, List, Dict, Any, Optional

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_huggingface.chat_models import ChatHuggingFace
from langchain_groq.chat_models import ChatGroq
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from tools import (
    add, subtract, multiply, div, modulus, power,
    wikipedia_search, search_web, arxiv_search,
    save_and_read_file, download_file_from_url,
    extract_text_from_image, pdf_loader
)
from retriever import get_retriever_tool

load_dotenv(dotenv_path=".env")

# Configuration
SYSTEM_PROMPT_PATH = "system_prompt.txt"
DEFAULT_PROVIDER = "groq"
MODEL_NAME = "llama3-70b-8192"


def load_system_prompt(path: str = SYSTEM_PROMPT_PATH) -> str:
    """Read the agent's system prompt from *path*.

    Args:
        path: Location of the prompt text file.

    Returns:
        The full prompt text.

    Raises:
        ValueError: If the prompt file does not exist.
    """
    if not os.path.exists(path):
        raise ValueError(f"System prompt file not found at: {path}")
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


system_prompt = load_system_prompt()
sys_msg = SystemMessage(content=system_prompt)

# Load tools. The retriever tool wraps a vector store of reference Q&As;
# the raw vector_store is also used directly in build_graph's retriever node.
vector_store, vector_retriever, retriever_tool = get_retriever_tool()

TOOLS = [
    # Math
    add, subtract, multiply, div, modulus, power,
    # Documents Search
    wikipedia_search, search_web, arxiv_search,
    # Process Files
    save_and_read_file, download_file_from_url,
    extract_text_from_image, pdf_loader,
    # Retriever
    retriever_tool,
]


def get_llm(provider: str = DEFAULT_PROVIDER):
    """Return a chat model instance for the given provider.

    Args:
        provider: Either "groq" or "huggingface".

    Returns:
        A configured chat model.

    Raises:
        NotImplementedError: If provider is "huggingface" (not wired up yet).
        ValueError: If provider is not a recognized value.
    """
    if provider == "groq":
        # temperature=0 for deterministic, tool-calling-friendly output.
        return ChatGroq(model=MODEL_NAME, temperature=0)
    elif provider == "huggingface":
        raise NotImplementedError("HuggingFace support not yet implemented.")
    else:
        raise ValueError("Invalid LLM provider. Choose 'groq' or 'huggingface'")


def build_graph(provider: str = DEFAULT_PROVIDER):
    """Build and compile the LangGraph graph.

    Graph shape: START -> retriever -> assistant -> (tools -> assistant)* -> END.
    The assistant loops through the tools node (via tools_condition) until
    the LLM response contains no tool calls.

    Args:
        provider: LLM provider name passed through to get_llm().

    Returns:
        The compiled graph, ready for .invoke().
    """
    llm = get_llm(provider)

    # Expose the tool signatures to the LLM so it can emit tool calls.
    llm_with_tools = llm.bind_tools(TOOLS)

    def assistant(state: MessagesState):
        # add_messages (MessagesState's reducer) appends the returned
        # messages; wrap the single AIMessage in a list for consistency
        # with the retriever node's return shape.
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        # The first message holds the user's question; look up the most
        # similar stored Q&A to give the LLM a worked reference example.
        query = state["messages"][0].content
        similar_qas = vector_store.similarity_search(query)
        if similar_qas:
            reference = similar_qas[0].page_content
            example_qa = HumanMessage(
                content=f"I provide a similar question and answer for reference:\n\n{reference}"
            )
            return {"messages": [sys_msg] + state["messages"] + [example_qa]}
        else:
            # No similar Q&A found: just prepend the system prompt.
            return {"messages": [sys_msg] + state["messages"]}

    # Graph
    builder = StateGraph(MessagesState)

    # Nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(TOOLS))

    # Edges: retrieve context first, then let the assistant/tools loop run.
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    return builder.compile()


if __name__ == "__main__":
    # Smoke test: answer a random question from the local dataset.
    import random
    import json

    with open("metadata.jsonl") as dataset_file:
        QAs = [json.loads(line) for line in dataset_file]

    question = random.choice(QAs)["Question"]
    graph = build_graph()
    messages = [HumanMessage(content=question)]
    messages = graph.invoke({"messages": messages})
    for m in messages["messages"]:
        m.pretty_print()