from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolMessage
from dotenv import load_dotenv
from datetime import datetime
import logging
import torch
import glob
import ast
import os
# Imports for local and remote chat models
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
from langchain_openai import ChatOpenAI
# Local modules
from pipeline import MyTextGenerationPipeline
from retriever import BuildRetriever, db_dir
from prompts import answer_prompt
from index import ProcessFile
from graph import BuildGraph
# -----------
# R-help-chat
# -----------
# First version by Jeffrey Dick on 2025-06-29
# Setup environment variables
load_dotenv(dotenv_path=".env", override=True)
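# Expected variables include MODEL_ID (read below) and, for remote mode, the
# standard OPENAI_API_KEY used by langchain_openai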
# Define the remote (OpenAI) model
openai_model = "gpt-4o-mini"
# Get the local model ID
model_id = os.getenv("MODEL_ID")
if model_id is None:
# model_id = "HuggingFaceTB/SmolLM3-3B"
model_id = "google/gemma-3-12b-it"
# model_id = "Qwen/Qwen3-14B"
# Suppress these messages:
# INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
# INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
# https://community.openai.com/t/suppress-http-request-post-message/583334/8
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)
def ProcessDirectory(path, compute_mode):
"""
Update vector store and sparse index for files in a directory, only adding new or updated files
Args:
path: Directory to process
compute_mode: Compute mode for embeddings (remote or local)
Usage example:
ProcessDirectory("R-help", "remote")
"""
# TODO: use UUID to process only changed documents
# https://stackoverflow.com/questions/76265631/chromadb-add-single-document-only-if-it-doesnt-exist
# Get a dense retriever instance
retriever = BuildRetriever(compute_mode, "dense")
# List all text files in target directory
file_paths = glob.glob(f"{path}/*.txt")
for file_path in file_paths:
# Process file for sparse search (BM25S)
ProcessFile(file_path, "sparse", compute_mode)
# Logic for dense search: skip file if already indexed
# Look for existing embeddings for this file
results = retriever.vectorstore.get(
# Metadata key-value pair
where={"source": file_path}
)
# Flag to add or update file
add_file = False
update_file = False
# If file path doesn't exist in vector store, then add it
if len(results["ids"]) == 0:
add_file = True
else:
# Check file timestamp to decide whether to update embeddings
mod_time = os.path.getmtime(file_path)
timestamp = datetime.fromtimestamp(mod_time).isoformat()
# Loop over metadata and compare to actual file timestamp
for metadata in results["metadatas"]:
                # Process the file if any of its embeddings has a different timestamp
                if metadata["timestamp"] != timestamp:
add_file = True
break
# Delete the old embeddings
if add_file:
retriever.vectorstore.delete(results["ids"])
update_file = True
if add_file:
ProcessFile(file_path, "dense", compute_mode)
if update_file:
print(f"Chroma: updated embeddings for {file_path}")
# Clear out the unused parent files
# The used doc_ids are the files to keep
used_doc_ids = [
d["doc_id"] for d in retriever.vectorstore.get()["metadatas"]
]
files_to_keep = list(set(used_doc_ids))
# Get all files in the file store
file_store = f"{db_dir}/file_store_{compute_mode}"
all_files = os.listdir(file_store)
# Iterate through the files and delete those not in the list
            for file in all_files:
                if file not in files_to_keep:
                    # Use a separate name to avoid clobbering the outer loop's file_path
                    unused_path = os.path.join(file_store, file)
                    os.remove(unused_path)
elif add_file:
print(f"Chroma: added embeddings for {file_path}")
else:
print(f"Chroma: no change for {file_path}")
def GetChatModel(compute_mode, ckpt_dir=None):
"""
Get a chat model.
Args:
compute_mode: Compute mode for chat model (remote or local)
ckpt_dir: Checkpoint directory for model weights (optional)
"""
if compute_mode == "remote":
chat_model = ChatOpenAI(model=openai_model, temperature=0)
if compute_mode == "local":
# Don't try to use local models without a GPU
        if not torch.cuda.is_available():
raise Exception("Local chat model selected without GPU")
# Define the pipeline to pass to the HuggingFacePipeline class
# https://huggingface.co/blog/langchain
id_or_dir = ckpt_dir if ckpt_dir else model_id
tokenizer = AutoTokenizer.from_pretrained(id_or_dir)
model = AutoModelForCausalLM.from_pretrained(
id_or_dir,
# We need this to load the model in BF16 instead of fp32 (torch.float)
torch_dtype=torch.bfloat16,
# Enable FlashAttention (requires pip install flash-attn)
# https://huggingface.co/docs/transformers/en/attention_interface
# https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention
attn_implementation="flash_attention_2",
)
# For Flash Attention version of Qwen3
tokenizer.padding_side = "left"
# Use MyTextGenerationPipeline with custom preprocess() method
pipe = MyTextGenerationPipeline(
model=model,
tokenizer=tokenizer,
# ToolCallingLLM needs return_full_text=False in order to parse just the assistant response
return_full_text=False,
# It seems that max_new_tokens has to be specified here, not in .invoke()
max_new_tokens=2000,
# Use padding for proper alignment for FlashAttention
# Part of fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
# https://github.com/google-deepmind/gemma/issues/169
padding="longest",
)
# We need the task so HuggingFacePipeline can deal with our class
pipe.task = "text-generation"
llm = HuggingFacePipeline(pipeline=pipe)
chat_model = ChatHuggingFace(llm=llm)
return chat_model
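# Usage sketch for GetChatModel (an illustration, not called by the app itself;
# remote mode assumes OPENAI_API_KEY is set in .env):
#   chat_model = GetChatModel("remote")
#   print(chat_model.invoke("What is R?").content)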
def RunChain(
    query: str,
compute_mode: str = "remote",
search_type: str = "hybrid",
think: bool = False,
):
"""
Run chain to retrieve documents and send to chat
Args:
query: User's query
compute_mode: Compute mode for embedding and chat models (remote or local)
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
think: Control thinking mode for SmolLM3
Example:
RunChain("What R functions are discussed?")
"""
# Get retriever instance
retriever = BuildRetriever(compute_mode, search_type)
if retriever is None:
return "No retriever available. Please process some documents first."
# Get chat model (LLM)
chat_model = GetChatModel(compute_mode)
# Get prompt with /no_think for SmolLM3/Qwen
system_prompt = answer_prompt(chat_model)
# Create a prompt template
system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
# NOTE: Each new email starts with \n\n\nFrom, so we don't need newlines after Retrieved Emails:
human_template = ChatPromptTemplate.from_template(
""""
### Question:
{question}
### Retrieved Emails:{context}
"""
)
prompt_template = system_template + human_template
# Build an LCEL retrieval chain
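    # The dict step below runs as a RunnableParallel: the retriever fills {context}
    # from the query, while RunnablePassthrough() forwards the query to {question}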
chain = (
{"context": retriever, "question": RunnablePassthrough()}
| prompt_template
| chat_model
| StrOutputParser()
)
# Invoke the retrieval chain
result = chain.invoke(query)
return result
def RunGraph(
query: str,
compute_mode: str = "remote",
search_type: str = "hybrid",
top_k: int = 6,
    think_query: bool = False,
    think_answer: bool = False,
    thread_id=None,
):
"""Run graph for conversational RAG app
Args:
query: User query to start the chat
compute_mode: Compute mode for embedding and chat models (remote or local)
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
top_k: Number of documents to retrieve
think_query: Whether to use thinking mode for the query
think_answer: Whether to use thinking mode for the answer
thread_id: Thread ID for memory (optional)
Example:
RunGraph("Help with parsing REST API response.")
"""
# Get chat model used in both query and generate steps
chat_model = GetChatModel(compute_mode)
# Build the graph
graph_builder = BuildGraph(
chat_model,
compute_mode,
search_type,
top_k,
think_query,
think_answer,
)
# Compile the graph with an in-memory checkpointer
memory = MemorySaver()
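    # NOTE: MemorySaver keeps checkpoints in RAM, so conversation history
    # persists only for the lifetime of this process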
graph = graph_builder.compile(checkpointer=memory)
# Specify an ID for the thread
config = {"configurable": {"thread_id": thread_id}}
# Stream the steps to observe the query generation, retrieval, and answer generation:
# - User input as a HumanMessage
# - Vector store query as an AIMessage with tool calls
    # - Retrieved documents as a ToolMessage
    # - Final response as an AIMessage
for state in graph.stream(
{"messages": [{"role": "user", "content": query}]},
stream_mode="values",
config=config,
):
        if state["messages"][-1].type != "tool":
state["messages"][-1].pretty_print()
# Parse the messages for the answer and citations
try:
answer, citations = ast.literal_eval(state["messages"][-1].content)
    except (ValueError, SyntaxError):
# In case we got an answer without citations
answer = state["messages"][-1].content
citations = None
result = {"answer": answer}
if citations:
result["citations"] = citations
# Parse tool messages to get retrieved emails
    tool_messages = [msg for msg in state["messages"] if isinstance(msg, ToolMessage)]
# Get content from the most recent retrieve_emails response
content = None
for msg in tool_messages:
if msg.name == "retrieve_emails":
content = msg.content
# Parse it into a list of emails
if content:
retrieved_emails = content.replace("### Retrieved Emails:\n\n\n\n", "").split(
"--- --- --- --- Next Email --- --- --- ---\n\n"
)
result["retrieved_emails"] = retrieved_emails
return result
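# Minimal smoke test for local development: a sketch that assumes an already
# indexed vector store and a populated .env; the query is taken from the
# RunChain docstring and the thread ID is a hypothetical placeholder
if __name__ == "__main__":
    result = RunGraph("What R functions are discussed?", thread_id="test")
    print(result["answer"])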