Sbnos committed
Commit d938a07 · verified · 1 Parent(s): 64aa216

modernising the application

Updated the app to call the Together chat-completions API directly, and added streaming of responses.
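For context, the streaming pattern the new app.py relies on boils down to the sketch below (a minimal sketch assuming the official together Python SDK and the same model name as in app.py; the guard on empty chunks is an assumption about deltas that can arrive with no content):

import os
from together import Together

client = Together(api_key=os.environ["TOGETHER_API_KEY"])

answer = ""
for chunk in client.chat.completions.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
):
    # Streamed chunks can arrive without choices, and the final
    # chunk's delta content may be None: skip anything empty.
    if chunk.choices and chunk.choices[0].delta.content:
        answer += chunk.choices[0].delta.content
print(answer)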

Files changed (1): app.py (+106, −137)
app.py CHANGED
@@ -1,145 +1,114 @@
-import streamlit as st
 import os
-import asyncio
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
-from langchain_community.vectorstores import Chroma
-from langchain_together import Together
-from langchain_community.chat_message_histories import StreamlitChatMessageHistory
-from langchain_community.document_loaders import WebBaseLoader
-from langchain_core.chat_history import BaseChatMessageHistory
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain.embeddings import HuggingFaceBgeEmbeddings
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-
-# Initialize the LLMs
-llm = Together(
-    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
-    temperature=0.2,
-    top_k=12,
-    max_tokens=22048,
-    together_api_key=os.environ['pilotikval']
 )

-# Function to store chat history
-store = {}
-
-model_name = "BAAI/bge-base-en"
-encode_kwargs = {'normalize_embeddings': True}  # set True to compute cosine similarity
-
-embedding_function = HuggingFaceBgeEmbeddings(
-    model_name=model_name,
-    encode_kwargs=encode_kwargs
 )

-def get_session_history(session_id: str) -> BaseChatMessageHistory:
-    if session_id not in store:
-        store[session_id] = StreamlitChatMessageHistory(key=session_id)
-    return store[session_id]
-
-# Define the Streamlit app
-def app():
-    with st.sidebar:
-        st.title("dochatter")
-        option = st.selectbox(
-            'Which retriever would you like to use?',
-            ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
-        )
-
-    # Define retrievers based on option
-    persist_directory = {
-        'General Medicine': "./oxfordmedbookdir/",
-        'Respiratory1': "./respfishmandbcud/",
-        'Respiratory2': "./respmurray/",
-        'Med2.2': "./medmrcp2store/",
-        'Med2.1': "./mrcpchromadb/"
-    }.get(option, "./mrcpchromadb/")
-
-    collection_name = {
-        'General Medicine': "oxfordmed",
-        'Respiratory1': "fishmannotescud",
-        'Respiratory2': "respmurraynotes",
-        'Med2.2': "medmrcp2notes",
-        'Med2.1': "mrcppassmednotes"
-    }.get(option, "mrcppassmednotes")
-
-    vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name=collection_name)
-    retriever = vectordb.as_retriever(search_kwargs={"k": 5})
-
-    # Define the prompt templates
-    contextualize_q_system_prompt = (
-        "Given a chat history and the latest user question "
-        "which might reference context in the chat history, "
-        "formulate a standalone question which can be understood "
-        "without the chat history. Do NOT answer the question, "
-        "just reformulate it if needed and otherwise return it as is."
-    )
-    contextualize_q_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", contextualize_q_system_prompt),
-            MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
-    )
-    history_aware_retriever = create_history_aware_retriever(
-        llm, retriever, contextualize_q_prompt
-    )
-
-    system_prompt = (
-        "You are helping a doctor. Be as detailed and thorough as possible "
-        "Use the following pieces of retrieved context to answer "
-        "the question. If you don't know the answer, say that you "
-        "don't know."
-        "\n\n"
-        "{context}"
-    )
-    qa_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", system_prompt),
-            MessagesPlaceholder("chat_history"),
-            ("human", "{input}"),
-        ]
-    )
-    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
-    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
-
-    # Statefully manage chat history
-    conversational_rag_chain = RunnableWithMessageHistory(
-        rag_chain,
-        get_session_history,
-        input_messages_key="input",
-        history_messages_key="chat_history",
-        output_messages_key="answer",
-    )
-
-    # Session state
-    if "messages" not in st.session_state.keys():
-        st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
-
-    st.header("Hello Doc!")
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            st.write(message["content"])
-
-    prompts2 = st.chat_input("Say something")
-
-    if prompts2:
-        st.session_state.messages.append({"role": "user", "content": prompts2})
-        with st.chat_message("user"):
-            st.write(prompts2)

-    if st.session_state.messages[-1]["role"] != "assistant":
-        with st.chat_message("assistant"):
-            with st.spinner("Thinking..."):
-                final_response = conversational_rag_chain.invoke(
-                    {
-                        "input": prompts2,
-                    },
-                    config={"configurable": {"session_id": "current_session"}}
-                )
-                st.write(final_response['answer'])
-                st.session_state.messages.append({"role": "assistant", "content": final_response['answer']})

-if __name__ == '__main__':
-    app()
 import os
+import streamlit as st
+from together import Together
+from langchain.vectorstores import Chroma
 from langchain.embeddings import HuggingFaceBgeEmbeddings
+from langchain.chains import ConversationalRetrievalChain
+
+# --- Configuration ---
+TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")
+if not TOGETHER_API_KEY:
+    st.error("Missing TOGETHER_API_KEY environment variable.")
+    st.stop()
+
+# Initialize TogetherAI client
+client = Together(api_key=TOGETHER_API_KEY)
+
+# Embeddings setup
+EMBED_MODEL_NAME = "BAAI/bge-base-en"
+embeddings = HuggingFaceBgeEmbeddings(
+    model_name=EMBED_MODEL_NAME,
+    encode_kwargs={"normalize_embeddings": True},
 )

+# Sidebar: select collection
+st.sidebar.title("DocChatter RAG")
+collection = st.sidebar.selectbox(
+    "Choose a document collection:",
+    ['General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine']
 )

+dirs = {
+    'General Medicine': './oxfordmedbookdir/',
+    'RespiratoryFishman': './respfishmandbcud/',
+    'RespiratoryMurray': './respmurray/',
+    'MedMRCP2': './medmrcp2store/',
+    'OldMedicine': './mrcpchromadb/'
+}
+cols = {
+    'General Medicine': 'oxfordmed',
+    'RespiratoryFishman': 'fishmannotescud',
+    'RespiratoryMurray': 'respmurraynotes',
+    'MedMRCP2': 'medmrcp2notes',
+    'OldMedicine': 'mrcppassmednotes'
+}
+
+persist_directory = dirs.get(collection)
+collection_name = cols.get(collection)
+
+# Load the persisted Chroma vector store
+vectorstore = Chroma(
+    collection_name=collection_name,
+    persist_directory=persist_directory,
+    embedding_function=embeddings
+)
+retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

+# System prompt template
+SYSTEM_PROMPT = (
+    "You are a helpful assistant for medical professionals. "
+    "Use the following context from medical documents to answer the question. "
+    "If you don't know, say you don't know.\n\nContext:\n{context}\n"
+)

+st.title("🩺 DocChatter RAG (Streaming)")
+
+# Initialize chat history
+if 'chat_history' not in st.session_state:
+    st.session_state.chat_history = []  # list of dicts {role, content}
+
+# Tabs
+chat_tab, clear_tab = st.tabs(["Chat", "Clear History"])
+with chat_tab:
+    # Display history
+    for msg in st.session_state.chat_history:
+        st.chat_message(msg['role']).write(msg['content'])
+
+    # User input
+    if prompt := st.chat_input("Ask anything about your docs..."):
+        # User message
+        st.chat_message("user").write(prompt)
+        st.session_state.chat_history.append({"role": "user", "content": prompt})
+
+        # Retrieve relevant docs
+        docs = retriever.get_relevant_documents(prompt)
+        context = "\n---\n".join([d.page_content for d in docs])
+
+        # Build messages for TogetherAI: system prompt with retrieved context,
+        # then the prior conversation (which already ends with the new prompt)
+        system_msg = {"role": "system", "content": SYSTEM_PROMPT.format(context=context)}
+        messages = [system_msg]
+        for msg in st.session_state.chat_history:
+            if msg['role'] in ('user', 'assistant'):
+                messages.append(msg)
+        # Prepare streaming response
+        response_container = st.chat_message("assistant")
+        placeholder = response_container.empty()
+        answer = ""
+        # Stream tokens, skipping chunks without choices and guarding
+        # against a None delta on the final chunk
+        for token in client.chat.completions.create(
+            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
+            messages=messages,
+            stream=True
+        ):
+            if token.choices and token.choices[0].delta.content:
+                answer += token.choices[0].delta.content
+                placeholder.write(answer)
+        # Save assistant message
+        st.session_state.chat_history.append({"role": "assistant", "content": answer})
+
+with clear_tab:
+    if st.button("🗑️ Clear chat history"):
+        st.session_state.chat_history = []
+        st.rerun()
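As before, the Space launches the file with streamlit run app.py. The new version expects a TOGETHER_API_KEY environment variable (the old build read the key from a variable named pilotikval) and the pre-built Chroma persist directories on disk. One design note: the vector store is re-opened on every Streamlit rerun; wrapping the Chroma setup in st.cache_resource (a hypothetical follow-up, not part of this commit) would avoid reloading it on each interaction.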