from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolMessage
from dotenv import load_dotenv
from datetime import datetime
import logging
import torch
import glob
import ast
import os

# Imports for local and remote chat models
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
from langchain_openai import ChatOpenAI

# Local modules
from pipeline import MyTextGenerationPipeline
from retriever import BuildRetriever, db_dir
from prompts import answer_prompt
from index import ProcessFile
from graph import BuildGraph
# -----------
# R-help-chat
# -----------
# First version by Jeffrey Dick on 2025-06-29

# Setup environment variables
load_dotenv(dotenv_path=".env", override=True)

# Define the remote (OpenAI) model
openai_model = "gpt-4o-mini"

# Get the local model ID
model_id = os.getenv("MODEL_ID")
if model_id is None:
    # model_id = "HuggingFaceTB/SmolLM3-3B"
    model_id = "google/gemma-3-12b-it"
    # model_id = "Qwen/Qwen3-14B"

# Suppress these messages:
# INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
# INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
# https://community.openai.com/t/suppress-http-request-post-message/583334/8
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)


def ProcessDirectory(path, compute_mode):
    """
    Update vector store and sparse index for files in a directory, only adding new or updated files

    Args:
        path: Directory to process
        compute_mode: Compute mode for embeddings (remote or local)

    Usage example:
        ProcessDirectory("R-help", "remote")
    """
    # TODO: use UUID to process only changed documents
    # https://stackoverflow.com/questions/76265631/chromadb-add-single-document-only-if-it-doesnt-exist

    # Get a dense retriever instance
    retriever = BuildRetriever(compute_mode, "dense")
    # List all text files in the target directory
    file_paths = glob.glob(f"{path}/*.txt")

    for file_path in file_paths:
        # Process file for sparse search (BM25S)
        ProcessFile(file_path, "sparse", compute_mode)

        # Logic for dense search: skip file if already indexed
        # Look for existing embeddings for this file
        results = retriever.vectorstore.get(
            # Metadata key-value pair
            where={"source": file_path}
        )
        # Flags to add or update the file
        add_file = False
        update_file = False
        # If the file path doesn't exist in the vector store, then add it
        if len(results["ids"]) == 0:
            add_file = True
        else:
            # Check file timestamp to decide whether to update embeddings
            mod_time = os.path.getmtime(file_path)
            timestamp = datetime.fromtimestamp(mod_time).isoformat()
            # Loop over metadata and compare to the actual file timestamp
            for metadata in results["metadatas"]:
                # Process file if any of the embeddings has a different timestamp
                if not metadata["timestamp"] == timestamp:
                    add_file = True
                    break
            # Delete the old embeddings
            if add_file:
                retriever.vectorstore.delete(results["ids"])
                update_file = True

        if add_file:
            ProcessFile(file_path, "dense", compute_mode)
        if update_file:
            print(f"Chroma: updated embeddings for {file_path}")
            # Clear out the unused parent files
            # The used doc_ids are the files to keep
            used_doc_ids = [
                d["doc_id"] for d in retriever.vectorstore.get()["metadatas"]
            ]
            files_to_keep = list(set(used_doc_ids))
            # Get all files in the file store
            file_store = f"{db_dir}/file_store_{compute_mode}"
            all_files = os.listdir(file_store)
            # Iterate through the files and delete those not in the keep list
            for file in all_files:
                if file not in files_to_keep:
                    file_path = os.path.join(file_store, file)
                    os.remove(file_path)
        elif add_file:
            print(f"Chroma: added embeddings for {file_path}")
        else:
            print(f"Chroma: no change for {file_path}")


def GetChatModel(compute_mode, ckpt_dir=None):
    """
    Get a chat model.

    Args:
        compute_mode: Compute mode for chat model (remote or local)
        ckpt_dir: Checkpoint directory for model weights (optional)
    """
    if compute_mode == "remote":
        chat_model = ChatOpenAI(model=openai_model, temperature=0)
    if compute_mode == "local":
        # Don't try to use local models without a GPU
        if not torch.cuda.is_available():
            raise Exception("Local chat model selected without GPU")
        # Define the pipeline to pass to the HuggingFacePipeline class
        # https://huggingface.co/blog/langchain
        id_or_dir = ckpt_dir if ckpt_dir else model_id
        tokenizer = AutoTokenizer.from_pretrained(id_or_dir)
        model = AutoModelForCausalLM.from_pretrained(
            id_or_dir,
            # We need this to load the model in BF16 instead of fp32 (torch.float)
            torch_dtype=torch.bfloat16,
            # Enable FlashAttention (requires pip install flash-attn)
            # https://huggingface.co/docs/transformers/en/attention_interface
            # https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention
            attn_implementation="flash_attention_2",
        )
        # For the Flash Attention version of Qwen3
        tokenizer.padding_side = "left"
        # Use MyTextGenerationPipeline with custom preprocess() method
        pipe = MyTextGenerationPipeline(
            model=model,
            tokenizer=tokenizer,
            # ToolCallingLLM needs return_full_text=False in order to parse just the assistant response
            return_full_text=False,
            # It seems that max_new_tokens has to be specified here, not in .invoke()
            max_new_tokens=2000,
            # Use padding for proper alignment for FlashAttention
            # Part of fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
            # https://github.com/google-deepmind/gemma/issues/169
            padding="longest",
        )
        # We need the task so HuggingFacePipeline can deal with our class
        pipe.task = "text-generation"
        llm = HuggingFacePipeline(pipeline=pipe)
        chat_model = ChatHuggingFace(llm=llm)
    return chat_model
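

# Illustrative usage sketch for GetChatModel() (not called elsewhere in this
# module; the example question is made up). The returned object is a standard
# LangChain chat model, so it can be invoked directly with a list of messages:
#
#   from langchain_core.messages import HumanMessage
#   chat_model = GetChatModel("remote")
#   reply = chat_model.invoke([HumanMessage(content="What does lapply() do?")])
#   print(reply.content)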


def RunChain(
    query,
    compute_mode: str = "remote",
    search_type: str = "hybrid",
    think: bool = False,
):
    """
    Run chain to retrieve documents and send to chat

    Args:
        query: User's query
        compute_mode: Compute mode for embedding and chat models (remote or local)
        search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
        think: Control thinking mode for SmolLM3

    Example:
        RunChain("What R functions are discussed?")
    """
    # Get retriever instance
    retriever = BuildRetriever(compute_mode, search_type)
    if retriever is None:
        return "No retriever available. Please process some documents first."
    # Get chat model (LLM)
    chat_model = GetChatModel(compute_mode)
    # Get prompt with /no_think for SmolLM3/Qwen
    system_prompt = answer_prompt(chat_model)
    # Create a prompt template
    system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
    # NOTE: Each new email starts with \n\n\nFrom, so we don't need newlines after Retrieved Emails:
    human_template = ChatPromptTemplate.from_template(
        """
        ### Question:
        {question}
        ### Retrieved Emails:{context}
        """
    )
    prompt_template = system_template + human_template
    # Build an LCEL retrieval chain
    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | chat_model
        | StrOutputParser()
    )
    # Invoke the retrieval chain
    result = chain.invoke(query)
    return result


def RunGraph(
    query: str,
    compute_mode: str = "remote",
    search_type: str = "hybrid",
    top_k: int = 6,
    think_query=False,
    think_answer=False,
    thread_id=None,
):
    """Run graph for conversational RAG app

    Args:
        query: User query to start the chat
        compute_mode: Compute mode for embedding and chat models (remote or local)
        search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
        top_k: Number of documents to retrieve
        think_query: Whether to use thinking mode for the query
        think_answer: Whether to use thinking mode for the answer
        thread_id: Thread ID for memory (optional)

    Example:
        RunGraph("Help with parsing REST API response.")
    """
    # Get chat model used in both query and generate steps
    chat_model = GetChatModel(compute_mode)
    # Build the graph
    graph_builder = BuildGraph(
        chat_model,
        compute_mode,
        search_type,
        top_k,
        think_query,
        think_answer,
    )
    # Compile the graph with an in-memory checkpointer
    memory = MemorySaver()
    graph = graph_builder.compile(checkpointer=memory)
    # Specify an ID for the thread
    config = {"configurable": {"thread_id": thread_id}}

    # Stream the steps to observe the query generation, retrieval, and answer generation:
    # - User input as a HumanMessage
    # - Vector store query as an AIMessage with tool calls
    # - Retrieved documents as a ToolMessage
    # - Final response as an AIMessage
    for state in graph.stream(
        {"messages": [{"role": "user", "content": query}]},
        stream_mode="values",
        config=config,
    ):
        if not state["messages"][-1].type == "tool":
            state["messages"][-1].pretty_print()

    # Parse the messages for the answer and citations
    try:
        answer, citations = ast.literal_eval(state["messages"][-1].content)
    except Exception:
        # In case we got an answer without citations
        answer = state["messages"][-1].content
        citations = None
    result = {"answer": answer}
    if citations:
        result["citations"] = citations

    # Parse tool messages to get the retrieved emails
    tool_messages = [msg for msg in state["messages"] if isinstance(msg, ToolMessage)]
    # Get content from the most recent retrieve_emails response
    content = None
    for msg in tool_messages:
        if msg.name == "retrieve_emails":
            content = msg.content
    # Parse it into a list of emails
    if content:
        retrieved_emails = content.replace("### Retrieved Emails:\n\n\n\n", "").split(
            "--- --- --- --- Next Email --- --- --- ---\n\n"
        )
        result["retrieved_emails"] = retrieved_emails

    return result
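

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original app flow. Assumptions:
    # an "R-help" directory of .txt archives exists next to this script and the
    # .env file provides the credentials needed for remote mode.
    ProcessDirectory("R-help", "remote")
    output = RunGraph(
        "Help with parsing REST API response.",
        compute_mode="remote",
        thread_id="demo",
    )
    print(output["answer"])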