# NOTE(review): the three lines originally here ("Spaces:", "Sleeping",
# "Sleeping") were Hugging Face Spaces page residue from export, not code;
# kept as a comment so the module parses.
import os | |
from dotenv import load_dotenv | |
from typing import TypedDict, List, Dict, Any, Optional | |
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
from langchain_huggingface.chat_models import ChatHuggingFace | |
from langchain_groq.chat_models import ChatGroq | |
from langgraph.graph.message import add_messages | |
from langgraph.graph import StateGraph, START, END, MessagesState | |
from langgraph.prebuilt import ToolNode, tools_condition | |
from tools import ( | |
add, | |
subtract, multiply, div, modulus, power, | |
wikipedia_search, search_web, arxiv_search, | |
save_and_read_file, download_file_from_url, extract_text_from_image, | |
pdf_loader | |
) | |
from retriever import get_retriever_tool | |
# Pull environment variables (API keys etc.) from the local .env file.
load_dotenv(".env")

# --- Configuration ---------------------------------------------------------
# Path to the plain-text system prompt loaded at import time.
SYSTEM_PROMPT_PATH = "system_prompt.txt"
# LLM backend selection; see get_llm() for the supported providers.
DEFAULT_PROVIDER = "groq"
# Model served through the Groq API.
MODEL_NAME = "llama3-70b-8192"
def load_system_prompt(path: str = "system_prompt.txt") -> str:
    """Read and return the system prompt text from *path*.

    The default is the same literal as the module-level SYSTEM_PROMPT_PATH
    constant (defaults are captured once at definition time, so behavior is
    unchanged).

    Args:
        path: Location of the prompt file.

    Returns:
        The full file contents as a string.

    Raises:
        ValueError: If no file exists at *path*.  (Kept as ValueError — not
            FileNotFoundError — so existing callers' except clauses still match.)
    """
    if not os.path.exists(path):
        # Fixed typo in the original message: "not foud" -> "not found".
        raise ValueError(f"System prompt file not found at: {path}")
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
# Load the system prompt once at import time and wrap it as the SystemMessage
# that the retriever node prepends to every conversation.
system_prompt = load_system_prompt()
sys_msg = SystemMessage(content = system_prompt)
# Load tools
# get_retriever_tool() returns the vector store, its retriever, and a tool
# wrapper; the store is also queried directly in build_graph's retriever node.
# NOTE(review): exact return types come from the project-local retriever
# module — not visible here; verify against retriever.get_retriever_tool.
vector_store, vector_retriever, retriever_tool = get_retriever_tool()
# Flat list of every tool exposed to the LLM via bind_tools and ToolNode.
TOOLS = [
    # Math
    add, subtract, multiply, div, modulus, power,
    # Documents Search
    wikipedia_search, search_web, arxiv_search,
    # Process Files
    save_and_read_file, download_file_from_url, extract_text_from_image,
    pdf_loader,
    # Retriever
    retriever_tool
]
def get_llm(provider: str = DEFAULT_PROVIDER):
    """Return a chat model for the requested *provider*.

    Only "groq" is implemented; "huggingface" is reserved and raises
    NotImplementedError, and anything else raises ValueError.
    """
    if provider == "groq":
        # Deterministic output: temperature pinned to 0.
        return ChatGroq(model=MODEL_NAME, temperature=0)
    if provider == "huggingface":
        raise NotImplementedError("HuggingFace support not yet implemented.")
    raise ValueError("Invalid LLM provider. Choose 'groq' or 'huggingface'")
def build_graph(provider: str = DEFAULT_PROVIDER):
    """
    Builds LangGraph graph
    """
    # Chat model with the full tool set attached.
    model = get_llm(provider).bind_tools(TOOLS)

    def call_model(state: MessagesState):
        # Node "assistant": one LLM step over the accumulated messages.
        return {"messages": model.invoke(state["messages"])}

    def inject_context(state: MessagesState):
        # Node "retriever": prepend the system message and, when the vector
        # store has a similar solved question, append it as a reference.
        user_query = state["messages"][0].content
        matches = vector_store.similarity_search(user_query)
        if not matches:
            return {"messages": [sys_msg] + state["messages"]}
        hint = HumanMessage(
            content = f"I provide a similar question and answer for reference:\n\n{matches[0].page_content}"
        )
        return {"messages": [sys_msg] + state["messages"] + [hint]}

    # Wire the graph: START -> retriever -> assistant, with the assistant
    # looping through the tool node whenever it emits tool calls.
    workflow = StateGraph(MessagesState)
    workflow.add_node("retriever", inject_context)
    workflow.add_node("assistant", call_model)
    workflow.add_node("tools", ToolNode(TOOLS))
    workflow.add_edge(START, "retriever")
    workflow.add_edge("retriever", "assistant")
    workflow.add_conditional_edges(
        "assistant",
        tools_condition
    )
    workflow.add_edge("tools", "assistant")
    return workflow.compile()
if __name__ == "__main__":
    import json
    import random

    # Smoke test: pick one random question from the local eval dataset and
    # run it through the compiled graph, printing the full message trace.
    with open("metadata.jsonl", encoding="utf-8") as dataset_file:
        # Parse each JSONL record directly; no need to materialize the raw
        # lines first (the original built an intermediate list of strings).
        QAs = [json.loads(line) for line in dataset_file]
    question = random.choice(QAs)["Question"]

    graph = build_graph()
    # Keep input and output in separate names (the original reused
    # `messages` for both, which obscured what graph.invoke returns).
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for m in result["messages"]:
        m.pretty_print()