safety-copilot / main.py
codelion's picture
Update main.py
d4f6a15 verified
# main.py
import os
import streamlit as st
import anthropic
from requests import JSONDecodeError
# Updated imports for latest LangChain
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_openai import ChatOpenAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
# Updated memory and chain imports
from langchain.memory import ConversationBufferMemory
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage
# ─────── supabase + secrets ────────────────────────────────────────────────────
# All credentials and settings come from Streamlit's secrets store (st.secrets);
# a missing key fails here, at import time, rather than mid-conversation.
supabase_url, supabase_key = st.secrets.SUPABASE_URL, st.secrets.SUPABASE_KEY
openai_api_key, anthropic_api_key, hf_api_key = (
    st.secrets.openai_api_key,
    st.secrets.anthropic_api_key,
    st.secrets.hf_api_key,
)
# Used as the {"user": username} filter when retrieving documents.
username = st.secrets.username

# Shared Supabase client: backs the vector store and the usage statistics.
supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)
# ─────── embeddings (Updated to use langchain-huggingface) ─────────────────────
# no spinner: nothing may be drawn before st.set_page_config() further down
@st.cache_resource(show_spinner=False)
def _load_embeddings() -> HuggingFaceEmbeddings:
    """Load the BGE sentence-embedding model once per server process.

    Streamlit re-executes this whole script on every user interaction; without
    ``st.cache_resource`` the model would be re-instantiated on every chat turn.
    """
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={"device": "cpu"},
        # unit-length vectors, so a dot product equals cosine similarity
        encode_kwargs={"normalize_embeddings": True},
    )


embeddings = _load_embeddings()
# ─────── vector store ──────────────────────────────────────────────────────────
# Supabase-backed store over the `documents` table; `query_name` names the
# Postgres function used for similarity matching, and queries are embedded
# with the BGE `embeddings` model above.
vector_store = SupabaseVectorStore(
    table_name="documents",
    query_name="match_documents",
    embedding=embeddings,
    client=supabase,
)
# ─────── LLM setup ──────────────────────────────────────────────────────────────
# Generation settings, shared by the chat model and the usage log:
#   model       - Hugging Face Hub model id served through the HF router
#   temperature - low, for mostly deterministic answers
#   max_tokens  - cap on the length of each completion
model, temperature, max_tokens = "HuggingFaceTB/SmolLM3-3B", 0.1, 500
import re
def clean_response(answer: str) -> str:
    """Strip model artifacts from a generated answer before it is rendered.

    Removes reasoning blocks (``<think>``/``<thinking>``, including an
    unterminated one or an orphan closing tag), bracketed/braced spans,
    fenced code blocks, ``---``-delimited spans, a leading speaker label
    (``Assistant:``, ``AI:``, ``Grok:``) and a trailing sign-off.

    Line structure is kept (at most one blank line in a row) because the
    result is rendered with ``st.markdown``, where newlines carry meaning
    for lists and paragraphs.

    Args:
        answer: Raw model output. Falsy values are returned unchanged.

    Returns:
        The cleaned text. It may be empty, e.g. when the output consisted
        only of an unfinished reasoning block.
    """
    if not answer:
        return answer
    answer = answer.replace("\r\n", "\n").replace("\r", "\n")

    # Complete reasoning blocks.
    answer = re.sub(r'<think>.*?</think>', '', answer, flags=re.DOTALL)
    answer = re.sub(r'<thinking>.*?</thinking>', '', answer, flags=re.DOTALL)
    # A closing tag with no opener (likely because the opener came from the
    # chat template rather than the model): everything before it is reasoning.
    answer = re.sub(r'^.*</think(?:ing)?>', '', answer, flags=re.DOTALL)
    # An opener with no closer: generation stopped mid-reasoning (e.g. at the
    # max_tokens limit), so nothing after it is part of an answer.
    answer = re.sub(r'<think(?:ing)?>.*', '', answer, flags=re.DOTALL)

    # Other common AI response artifacts. Note: this also drops any [..]/{..}
    # text, including markdown link labels.
    answer = re.sub(r'\[.*?\]', '', answer, flags=re.DOTALL)
    answer = re.sub(r'\{.*?\}', '', answer, flags=re.DOTALL)
    answer = re.sub(r'```.*?```', '', answer, flags=re.DOTALL)
    answer = re.sub(r'---.*?---', '', answer, flags=re.DOTALL)

    # Normalise whitespace without flattening the markdown line structure.
    answer = re.sub(r'(?<=\S)[^\S\n]+', ' ', answer)  # collapse runs inside a line; keeps indentation
    answer = re.sub(r'[^\S\n]+\n', '\n', answer)      # drop trailing whitespace on each line
    answer = re.sub(r'\n{3,}', '\n\n', answer)        # at most one blank line in a row
    answer = answer.strip()

    # Common AI-generated prefixes/suffixes. DOTALL lets a multi-line
    # sign-off ("Best regards,\nName") be removed as a whole.
    answer = re.sub(r'^(Assistant:|AI:|Grok:)\s*', '', answer, flags=re.IGNORECASE)
    answer = re.sub(
        r'\s*(Sincerely,.*|Best regards,.*|Regards,.*)$',
        '',
        answer,
        flags=re.IGNORECASE | re.DOTALL,
    )
    return answer
def create_conversational_rag_chain():
    """Build the retrieval-augmented conversational chain.

    The returned runnable is invoked with
    ``{"input": str, "chat_history": list[BaseMessage]}``. Its result dict
    carries the generated ``"answer"`` and the retrieved ``"context"``
    documents (the keys ``response_generator`` reads).
    """
    # Hugging Face router endpoint speaking the OpenAI chat-completions API.
    llm = ChatOpenAI(
        base_url=f"https://router.huggingface.co/hf-inference/models/{model}/v1",
        api_key=hf_api_key,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=30,
        max_retries=3,
    )

    # Up to k=4 documents, filtered on {"user": username} (presumably document
    # metadata); score_threshold is passed through to the store's search.
    retriever = vector_store.as_retriever(
        search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
    )

    # Only {context} is templated into the system text. Chat history and the
    # question arrive as real messages via the placeholder and the human turn;
    # templating them here as well would duplicate them, and the history would
    # be rendered as the Python repr of a message list.
    system_prompt = """You are a helpful safety assistant. Use the following pieces of retrieved context to answer the question.
If you don't know the answer based on the context, just say that you don't have enough information to answer that question.
Context: {context}"""
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])

    # "Stuff" the retrieved documents into {context}, then wire up retrieval.
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever, question_answer_chain)
def response_generator(query: str, chat_history: list) -> str:
    """Answer ``query`` with the RAG chain and return a user-facing string.

    Args:
        query: The user's latest question.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts
            (the shape stored in ``st.session_state.chat_history``). Roles
            other than "user" and "assistant" are skipped.

    Returns:
        The cleaned model answer. A fixed fallback message is returned
        instead when nothing was retrieved, when cleaning leaves an empty
        answer, or when building or invoking the chain raises.
    """
    # Usage logging is best-effort: a failed insert must not block the answer.
    try:
        add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    except Exception:
        logger.exception("Failed to record usage")
    logger.info("Using HF model %s", model)

    # Convert stored dicts into LangChain message objects for the placeholder.
    role_to_message = {"user": HumanMessage, "assistant": AIMessage}
    formatted_history = [
        role_to_message[msg["role"]](content=msg["content"])
        for msg in chat_history
        if msg["role"] in role_to_message
    ]

    try:
        rag_chain = create_conversational_rag_chain()
        result = rag_chain.invoke({
            "input": query,
            "chat_history": formatted_history,
        })
        answer = result.get("answer", "")
        context = result.get("context", [])
        if not context:
            return (
                "I'm sorry, I don't have enough information to answer that. "
                "If you have a public data source to add, please email copilot@securade.ai."
            )
        answer = clean_response(answer)
    except JSONDecodeError as e:
        logger.exception("JSONDecodeError: %s", e)
        return "Sorry, I had trouble processing your request. Please try again."
    except Exception as e:
        # Top-level boundary for the UI: log with traceback, show a friendly message.
        logger.exception("Unexpected error: %s", e)
        return "Sorry, I encountered an error while processing your request. Please try again."

    # Cleaning can strip everything (e.g. output that was only truncated reasoning).
    if not answer:
        return "Sorry, I couldn't produce an answer to that. Please try rephrasing your question."
    return answer
# ─────── Streamlit UI ──────────────────────────────────────────────────────────
# Called before any element is drawn (older Streamlit versions require it first).
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n[https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:hello@securade.ai",
    },
)
# Emoji written as escapes (construction worker + ZWJ + male sign + VS16, then
# safety vest) so the multi-code-point sequence can't be mangled by a wrong
# source-file encoding.
st.title("\U0001F477\u200D\u2642\uFE0F Safety Copilot \U0001F9BA")
stats = get_usage(supabase)
st.markdown(f"_{stats} queries answered!_")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)"
    "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)

# The conversation lives in session_state so it survives Streamlit reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Re-render the conversation so far (the script reruns on every interaction).
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# Handle new user input.
if prompt := st.chat_input("Ask a question"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.spinner("Safety briefing in progress..."):
        # Pass the history without the message just appended; the query goes separately.
        answer = response_generator(prompt, st.session_state.chat_history[:-1])
    with st.chat_message("assistant"):
        st.markdown(answer)
    st.session_state.chat_history.append({"role": "assistant", "content": answer})