# main.py

import re

import streamlit as st
from requests import JSONDecodeError

# LangChain imports (current package layout)
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_openai import ChatOpenAI

# Chain-construction imports
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage

from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage
from stats import get_usage, add_usage

# ─────── supabase + secrets ────────────────────────────────────────────────────
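# Expected .streamlit/secrets.toml layout (illustrative values; the key names
# must match the attribute reads below):
#
#   SUPABASE_URL = "https://<project>.supabase.co"
#   SUPABASE_KEY = "<service-role-or-anon-key>"
#   openai_api_key = "sk-..."
#   anthropic_api_key = "sk-ant-..."
#   hf_api_key = "hf_..."
#   username = "<metadata value used to filter retrieved documents>"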
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key = st.secrets.hf_api_key
username = st.secrets.username

supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)

# ─────── embeddings (Updated to use langchain-huggingface) ─────────────────────
embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True}
)
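# normalize_embeddings=True makes inner-product scores equivalent to cosine
# similarity, which the pgvector match_documents function typically assumes.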

# ─────── vector store ──────────────────────────────────────────────────────────
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)
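# Assumes the standard LangChain/Supabase pgvector setup: a "documents" table
# (id, content, metadata, embedding) plus a "match_documents" SQL function, as
# described in the SupabaseVectorStore docs. The embedding column should be
# vector(1024) to match bge-large-en-v1.5's output dimension.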

# ─────── LLM setup ──────────────────────────────────────────────────────────────
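# SmolLM3 is a hybrid reasoning model that can emit <think>...</think> traces;
# clean_response() below strips them from the final answer.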
model = "HuggingFaceTB/SmolLM3-3B"
temperature = 0.1
max_tokens = 500

def clean_response(answer: str) -> str:
    """Clean up AI response by removing unwanted artifacts and formatting."""
    if not answer:
        return answer
    
    # Remove thinking tags and content
    answer = re.sub(r'<think>.*?</think>', '', answer, flags=re.DOTALL)
    answer = re.sub(r'<thinking>.*?</thinking>', '', answer, flags=re.DOTALL)
    
    # Strip bracketed/braced spans, fenced code blocks, and "---"-delimited
    # sections (note: legitimate markdown of these forms is removed too)
    answer = re.sub(r'\[.*?\]', '', answer, flags=re.DOTALL)
    answer = re.sub(r'\{.*?\}', '', answer, flags=re.DOTALL)
    answer = re.sub(r'```.*?```', '', answer, flags=re.DOTALL)
    answer = re.sub(r'---.*?---', '', answer, flags=re.DOTALL)
    
    # Remove excessive whitespace and newlines
    answer = re.sub(r'\s+', ' ', answer).strip()
    
    # Remove common AI-generated prefixes/suffixes
    answer = re.sub(r'^(Assistant:|AI:|Grok:)\s*', '', answer, flags=re.IGNORECASE)
    answer = re.sub(r'\s*(Sincerely,.*|Best regards,.*|Regards,.*)$', '', answer, flags=re.IGNORECASE)
    
    return answer
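
# Illustrative example (not executed at import time):
#   clean_response("<think>plan steps</think>Assistant: Wear a harness.")
#   -> "Wear a harness."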

def create_conversational_rag_chain():
    """Create a modern conversational RAG chain using LCEL."""
    
    # Create the chat model: HF Inference exposes an OpenAI-compatible API,
    # so ChatOpenAI can drive the model through the Hugging Face router
    llm = ChatOpenAI(
        base_url=f"https://router.huggingface.co/hf-inference/models/{model}/v1",
        api_key=hf_api_key,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=30,
        max_retries=3,
    )
    
    # Create retriever; score_threshold is only honoured with the
    # "similarity_score_threshold" search type (plain "similarity" ignores it)
    retriever = vector_store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}},
    )
    
    # System prompt for RAG: only {context} is formatted here; the chat
    # history and the user question are injected by the message placeholders
    # below, so they must not be duplicated inside this string.
    system_prompt = (
        "You are a helpful safety assistant. Use the following pieces of "
        "retrieved context to answer the question. If you don't know the "
        "answer based on the context, just say that you don't have enough "
        "information to answer that question.\n\n"
        "Context: {context}"
    )
    
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    
    # Create document processing chain
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    
    # Create retrieval chain
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)
    
    return rag_chain

def response_generator(query: str, chat_history: list) -> str:
    """Ask the RAG chain to answer `query`, with JSON‑error fallback."""
    # log usage
    add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    logger.info("Using HF model %s", model)

    # Build the RAG chain (rebuilt on every request; st.cache_resource could
    # cache it if construction latency becomes an issue)
    rag_chain = create_conversational_rag_chain()
    
    # Format chat history for the chain
    formatted_history = []
    for msg in chat_history:
        if msg["role"] == "user":
            formatted_history.append(HumanMessage(content=msg["content"]))
        elif msg["role"] == "assistant":
            formatted_history.append(AIMessage(content=msg["content"]))

    try:
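        # create_retrieval_chain returns the input keys plus "context" (the
        # retrieved Documents) and "answer" (the generated text).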
        result = rag_chain.invoke({
            "input": query,
            "chat_history": formatted_history
        })
        
        answer = result.get("answer", "")
        context = result.get("context", [])
        
        if not context:
            return (
                "I'm sorry, I don't have enough information to answer that. "
                "If you have a public data source to add, please email copilot@securade.ai."
            )

        answer = clean_response(answer)
        return answer
        
    except JSONDecodeError as e:
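        # Unlikely with the httpx-based OpenAI client, but kept as a
        # defensive fallback for malformed JSON payloads.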
        logger.error("JSONDecodeError: %s", e)
        return "Sorry, I had trouble processing your request. Please try again."
    except Exception as e:
        logger.error("Unexpected error: %s", e)
        return "Sorry, I encountered an error while processing your request. Please try again."

# ─────── Streamlit UI ──────────────────────────────────────────────────────────
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n[https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:hello@securade.ai",
    },
)

st.title("πŸ‘·β€β™‚οΈ Safety Copilot 🦺")
stats = get_usage(supabase)
st.markdown(f"_{stats} queries answered!_")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)"
    "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)

# Initialize chat history
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Display chat history
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

# Handle new user input
if prompt := st.chat_input("Ask a question"):
    # Add user message to history
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    
    # Display user message
    with st.chat_message("user"):
        st.markdown(prompt)

    # Generate and display response
    with st.spinner("Safety briefing in progress..."):
        answer = response_generator(prompt, st.session_state.chat_history[:-1])  # Exclude current message

    with st.chat_message("assistant"):
        st.markdown(answer)
    
    # Add assistant response to history
    st.session_state.chat_history.append({"role": "assistant", "content": answer})
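
# Run locally with:
#   streamlit run main.py
# (requires the secrets shown in the illustrative secrets.toml layout above)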