lfoppiano committed
Commit b7b1a78 · 1 Parent(s): 53ab843

update libraries and code to support Modal as OpenAI-based server
document_qa/document_qa_engine.py CHANGED
@@ -5,7 +5,8 @@ from typing import Union, Any, List
 
 import tiktoken
 from langchain.chains import create_extraction_chain
-from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain.chains.question_answering import stuff_prompt, refine_prompts, map_reduce_prompt, \
     map_rerank_prompt
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain.retrievers import MultiQueryRetriever
@@ -14,7 +15,6 @@ from langchain_community.vectorstores.chroma import Chroma
 from langchain_core.vectorstores import VectorStore
 from tqdm import tqdm
 
-# from document_qa.embedding_visualiser import QueryVisualiser
 from document_qa.grobid_processors import GrobidProcessor
 from document_qa.langchain import ChromaAdvancedRetrieval
 
@@ -177,17 +177,19 @@ class DataStorage:
 
     def embed_document(self, doc_id, texts, metadatas):
         if doc_id not in self.embeddings_dict.keys():
-            self.embeddings_dict[doc_id] = self.engine.from_texts(texts,
-                                                                  embedding=self.embedding_function,
-                                                                  metadatas=metadatas,
-                                                                  collection_name=doc_id)
+            self.embeddings_dict[doc_id] = self.engine.from_texts(
+                texts,
+                embedding=self.embedding_function,
+                metadatas=metadatas,
+                collection_name=doc_id)
         else:
             # Workaround Chroma (?) breaking change
             self.embeddings_dict[doc_id].delete_collection()
-            self.embeddings_dict[doc_id] = self.engine.from_texts(texts,
-                                                                  embedding=self.embedding_function,
-                                                                  metadatas=metadatas,
-                                                                  collection_name=doc_id)
+            self.embeddings_dict[doc_id] = self.engine.from_texts(
+                texts,
+                embedding=self.embedding_function,
+                metadatas=metadatas,
+                collection_name=doc_id)
 
         self.embeddings_root_path = None
 
@@ -206,14 +208,13 @@ class DocumentQAEngine:
     def __init__(self,
                  llm,
                  data_storage: DataStorage,
-                 qa_chain_type="stuff",
                  grobid_url=None,
                  memory=None
                  ):
 
         self.llm = llm
         self.memory = memory
-        self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
+        self.chain = create_stuff_documents_chain(llm, self.default_prompts['stuff'].PROMPT)
         self.text_merger = TextMerger()
         self.data_storage = data_storage
 
@@ -271,7 +272,10 @@
         Returns both the context and the embedding information from a given query
         """
         db = self.data_storage.embeddings_dict[doc_id]
-        retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings")
+        retriever = db.as_retriever(
+            search_kwargs={"k": context_size},
+            search_type="similarity_with_embeddings"
+        )
         relevant_documents = retriever.invoke(query)
 
         return relevant_documents
@@ -327,20 +331,18 @@
 
     def _run_query(self, doc_id, query, context_size=4) -> (List[Document], list):
         relevant_documents, relevant_document_coordinates = self._get_context(doc_id, query, context_size)
-        response = self.chain.run(input_documents=relevant_documents,
-                                  question=query)
-
-        if self.memory:
-            self.memory.save_context({"input": query}, {"output": response})
+        response = self.chain.invoke({"context": relevant_documents, "question": query})
        return response, relevant_document_coordinates
 
     def _get_context(self, doc_id, query, context_size=4) -> (List[Document], list):
         db = self.data_storage.embeddings_dict[doc_id]
         retriever = db.as_retriever(search_kwargs={"k": context_size})
         relevant_documents = retriever.invoke(query)
-        relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
-                                         for doc in
-                                         relevant_documents]
+        relevant_document_coordinates = [
+            doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
+            for doc in
+            relevant_documents
+        ]
         if self.memory and len(self.memory.buffer_as_messages) > 0:
             relevant_documents.append(
                 Document(
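A note on the chain migration in this file: load_qa_chain is deprecated in recent LangChain releases, and create_stuff_documents_chain is its replacement for the "stuff" strategy; it concatenates the retrieved documents into one prompt variable and returns the model output as a string. A minimal standalone sketch of the new pattern (the prompt wording and model name below are illustrative, not taken from this repository):

from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

# The prompt needs a {context} slot for the stuffed documents; the extra
# {question} slot matches the invoke() keys used by _run_query above.
prompt = ChatPromptTemplate.from_template(
    "Answer using only the context below.\n\n{context}\n\nQuestion: {question}"
)
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # any chat model works
chain = create_stuff_documents_chain(llm, prompt)

# The chain joins the Documents, fills {context}, and returns a plain string.
answer = chain.invoke({
    "context": [Document(page_content="GROBID converts PDFs into TEI XML.")],
    "question": "What does GROBID produce?",
})
print(answer)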
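The retrieval path (embed_document and _get_context) uses the stock LangChain vector-store API: one Chroma collection per document, wrapped in a retriever with search_kwargs={"k": ...}. The similarity_with_embeddings search type in the hunk above appears to come from the repository's own ChromaAdvancedRetrieval subclass rather than from stock Chroma, so this self-contained sketch sticks to the default similarity search (model, text, metadata, and query are illustrative):

from langchain_community.vectorstores.chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# One collection per document id, as in DataStorage.embed_document.
db = Chroma.from_texts(
    ["GROBID extracts structured TEI XML from scientific PDFs."],
    embedding=embedding_function,
    metadatas=[{"coordinates": "1,100,200;1,210,300"}],
    collection_name="doc-0001",
)

# Top-k similarity retrieval, as in _get_context.
retriever = db.as_retriever(search_kwargs={"k": 1})
for doc in retriever.invoke("What does GROBID extract?"):
    print(doc.metadata.get("coordinates"), doc.page_content)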
document_qa/langchain.py CHANGED
@@ -1,4 +1,3 @@
-from pathlib import Path
 from typing import Any, Optional, List, Dict, Tuple, ClassVar, Collection
 
 from langchain.schema import Document
requirements.txt CHANGED
@@ -1,32 +1,33 @@
 # Grobid
 grobid-quantities-client==0.4.0
 grobid-client-python==0.0.9
-grobid_tei_xml==0.1.3
+grobid-tei-xml==0.1.3
 
 # Utils
-tqdm==4.66.2
+tqdm==4.66.3
 pyyaml==6.0.1
 pytest==8.1.1
-streamlit==1.37.1
-lxml
-Beautifulsoup4
-python-dotenv
-watchdog
-dateparser
+streamlit==1.45.1
+lxml==5.2.1
+beautifulsoup4==4.12.3
+python-dotenv==1.0.1
+watchdog==4.0.0
+dateparser==1.2.0
+requests>=2.31.0
 
 # LLM
 chromadb==0.4.24
-tiktoken==0.7.0
-openai==1.42.0
-langchain==0.2.14
-langchain-core==0.2.34
-langchain-openai==0.1.22
-langchain-huggingface==0.0.3
-langchain-community==0.2.12
+tiktoken==0.9.0
+openai==1.82.0
+langchain==0.3.25
+langchain-core==0.3.61
+langchain-openai==0.3.18
+langchain-huggingface==0.2.0
+langchain-community==0.3.21
 typing-inspect==0.9.0
-typing_extensions==4.11.0
-pydantic==2.6.4
-sentence_transformers==2.6.1
-streamlit-pdf-viewer==0.0.22
-umap-learn
-plotly
+typing_extensions==4.12.2
+pydantic==2.10.6
+sentence-transformers==2.6.1
+streamlit-pdf-viewer==0.0.22rc0
+umap-learn==0.5.6
+plotly==5.20.0
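The LLM-stack bumps above cross a LangChain major line (0.2.x to 0.3.x, which is built on Pydantic 2, consistent with the pydantic==2.10.6 pin). A small post-install sanity check, as a sketch, that the split packages resolved to the pinned versions in the target environment:

from importlib.metadata import version

# The langchain split packages have to move together across the 0.3 boundary.
for pkg in ("langchain", "langchain-core", "langchain-openai",
            "langchain-huggingface", "langchain-community", "openai", "pydantic"):
    print(f"{pkg}=={version(pkg)}")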
streamlit_app.py CHANGED
@@ -5,11 +5,9 @@ from tempfile import NamedTemporaryFile
 
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
-from langchain.memory import ConversationBufferWindowMemory
-from langchain_community.chat_models import ChatOpenAI
-from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+from langchain.memory import ConversationBufferMemory
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_openai import OpenAIEmbeddings
+from langchain_openai import ChatOpenAI
 from streamlit_pdf_viewer import pdf_viewer
 
 from document_qa.ner_client_generic import NERClientGeneric
@@ -20,30 +18,14 @@ import streamlit as st
 from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
 from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
 
-OPENAI_MODELS = ['gpt-3.5-turbo',
-                 "gpt-4",
-                 "gpt-4-1106-preview"]
-
-OPENAI_EMBEDDINGS = [
-    'text-embedding-ada-002',
-    'text-embedding-3-large',
-    'openai-text-embedding-3-small'
-]
-
-OPEN_MODELS = {
-    'Mistral-Nemo-Instruct-2407': 'mistralai/Mistral-Nemo-Instruct-2407',
-    'mistral-7b-instruct-v0.3': 'mistralai/Mistral-7B-Instruct-v0.3',
-    'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct"
+API_MODELS = {
+    "microsoft/Phi-4-mini-instruct": os.environ["MODAL_1_URL"]
 }
 
-DEFAULT_OPEN_EMBEDDING_NAME = 'Default (all-MiniLM-L6-v2)'
-OPEN_EMBEDDINGS = {
-    DEFAULT_OPEN_EMBEDDING_NAME: 'all-MiniLM-L6-v2',
-    'SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral',
-    'SFR-Embedding-2_R': 'Salesforce/SFR-Embedding-2_R',
-    'NV-Embed': 'nvidia/NV-Embed-v1',
-    'e5-mistral-7b-instruct': 'intfloat/e5-mistral-7b-instruct',
-    'gte-large-en-v1.5': 'Alibaba-NLP/gte-large-en-v1.5'
+API_EMBEDDINGS = {
+    'intfloat/e5-large-v2': 'intfloat/e5-large-v2',
+    'intfloat/multilingual-e5-large-instruct': 'intfloat/multilingual-e5-large-instruct',
+    'Salesforce/SFR-Embedding-2_R': 'Salesforce/SFR-Embedding-2_R'
 }
 
 if 'rqa' not in st.session_state:
@@ -141,48 +123,20 @@ def clear_memory():
 
 
 # @st.cache_resource
-def init_qa(model, embeddings_name=None, api_key=None):
-    ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
-    if model in OPENAI_MODELS:
-        if embeddings_name is None:
-            embeddings_name = 'text-embedding-ada-002'
-
-        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
-        if api_key:
-            chat = ChatOpenAI(model_name=model,
-                              temperature=0,
-                              openai_api_key=api_key,
-                              frequency_penalty=0.1)
-            if embeddings_name not in OPENAI_EMBEDDINGS:
-                st.error(f"The embeddings provided {embeddings_name} are not supported by this model {model}.")
-                st.stop()
-                return
-            embeddings = OpenAIEmbeddings(model=embeddings_name, openai_api_key=api_key)
+def init_qa(model_name, embeddings_name):
+    st.session_state['memory'] = ConversationBufferMemory(
+        memory_key="chat_history",
+        return_messages=True
+    )
+    chat = ChatOpenAI(
+        model=model_name,
+        temperature=0.0,
+        base_url=API_MODELS[model_name],
+        api_key=os.environ.get('API_KEY')
+    )
 
-        else:
-            chat = ChatOpenAI(model_name=model,
-                              temperature=0,
-                              frequency_penalty=0.1)
-            embeddings = OpenAIEmbeddings(model=embeddings_name)
-
-    elif model in OPEN_MODELS:
-        if embeddings_name is None:
-            embeddings_name = DEFAULT_OPEN_EMBEDDING_NAME
-
-        chat = HuggingFaceEndpoint(
-            repo_id=OPEN_MODELS[model],
-            temperature=0.01,
-            max_new_tokens=4092,
-            model_kwargs={"max_length": 8192},
-            # callbacks=[PromptLayerCallbackHandler(pl_tags=[model, "document-qa"])]
-        )
-        embeddings = HuggingFaceEmbeddings(
-            model_name=OPEN_EMBEDDINGS[embeddings_name])
-        # st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
-    else:
-        st.error("The model was not loaded properly. Try reloading. ")
-        st.stop()
-        return
+    embeddings = HuggingFaceEmbeddings(
+        model_name=API_EMBEDDINGS[embeddings_name])
 
     storage = DataStorage(embeddings)
     return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
@@ -246,65 +200,31 @@ with st.sidebar:
     st.divider()
     st.session_state['model'] = model = st.selectbox(
         "Model:",
-        options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
-        index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index(
+        options=API_MODELS.keys(),
+        index=(list(API_MODELS.keys())).index(
            os.environ["DEFAULT_MODEL"]) if "DEFAULT_MODEL" in os.environ and os.environ["DEFAULT_MODEL"] else 0,
         placeholder="Select model",
         help="Select the LLM model:",
         disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
     )
-    embedding_choices = OPENAI_EMBEDDINGS if model in OPENAI_MODELS else OPEN_EMBEDDINGS
 
     st.session_state['embeddings'] = embedding_name = st.selectbox(
         "Embeddings:",
-        options=embedding_choices,
-        index=0,
+        options=API_EMBEDDINGS.keys(),
+        index=(list(API_EMBEDDINGS.keys())).index(
+            os.environ["DEFAULT_EMBEDDING"]) if "DEFAULT_EMBEDDING" in os.environ and os.environ[
+            "DEFAULT_EMBEDDING"] else 0,
         placeholder="Select embedding",
         help="Select the Embedding function:",
         disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
     )
 
-    if (model in OPEN_MODELS) and model not in st.session_state['api_keys']:
-        if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
-            api_key = st.text_input('Huggingface API Key', type="password")
-
-            st.markdown("Get it [here](https://huggingface.co/docs/hub/security-tokens)")
-        else:
-            api_key = os.environ['HUGGINGFACEHUB_API_TOKEN']
-
-        if api_key:
-            # st.session_state['api_key'] = is_api_key_provided = True
-            if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']:
-                with st.spinner("Preparing environment"):
-                    st.session_state['api_keys'][model] = api_key
-                    # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
-                    #     os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
-                    st.session_state['rqa'][model] = init_qa(model, embedding_name)
-
-    elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
-        if 'OPENAI_API_KEY' not in os.environ:
-            api_key = st.text_input('OpenAI API Key', type="password")
-            st.markdown("Get it [here](https://platform.openai.com/account/api-keys)")
-        else:
-            api_key = os.environ['OPENAI_API_KEY']
-
-        if api_key:
-            if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']:
-                with st.spinner("Preparing environment"):
-                    st.session_state['api_keys'][model] = api_key
-                    if 'OPENAI_API_KEY' not in os.environ:
-                        st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'], api_key)
-                    else:
-                        st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'])
-    # else:
-    #     is_api_key_provided = st.session_state['api_key']
-
-    # st.button(
-    #     'Reset chat memory.',
-    #     key="reset-memory-button",
-    #     on_click=clear_memory,
-    #     help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
-    #     disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
+    api_key = os.environ['API_KEY']
+
+    if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']:
+        with st.spinner("Preparing environment"):
+            st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'])
+            st.session_state['api_keys'][model] = api_key
 
 left_column, right_column = st.columns([5, 4])
 right_column = right_column.container(border=True)
@@ -390,15 +310,16 @@ if uploaded_file and not st.session_state.loaded_embeddings:
     st.stop()
 
 with left_column:
-    with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
+    with st.spinner('Reading file, calling Grobid, and creating in-memory embeddings...'):
         binary = uploaded_file.getvalue()
         tmp_file = NamedTemporaryFile()
         tmp_file.write(bytearray(binary))
         st.session_state['binary'] = binary
 
-        st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
-                                                                                                    chunk_size=chunk_size,
-                                                                                                    perc_overlap=0.1)
+        st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(
+            tmp_file.name,
+            chunk_size=chunk_size,
+            perc_overlap=0.1)
         st.session_state['loaded_embeddings'] = True
         st.session_state.messages = []
 
@@ -477,7 +398,7 @@ with right_column:
                            annotation_doc]
 
     if not text_response:
-        st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
+        st.error("Something went wrong. Contact info AT sciencialab.com to report the issue through GitHub.")
 
     if mode == "llm":
         if st.session_state['ner_processing']:
@@ -503,5 +424,6 @@ with left_column:
         annotation_outline_size=2,
         annotations=st.session_state['annotations'] if st.session_state['annotations'] else [],
         render_text=True,
-        scroll_to_annotation=1 if (st.session_state['annotations'] and st.session_state['scroll_to_first_annotation']) else None
+        scroll_to_annotation=1 if (st.session_state['annotations'] and st.session_state[
+            'scroll_to_first_annotation']) else None
     )
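The heart of the commit is in init_qa: instead of branching between OpenAI and Hugging Face endpoint backends, the app now points langchain_openai.ChatOpenAI at a Modal deployment exposing an OpenAI-compatible API, selected through base_url. A minimal standalone sketch of that pattern; MODAL_1_URL and API_KEY are the environment variable names used in the diff, while the URL shape in the comment is only an assumption about a typical Modal deployment:

import os

from langchain_openai import ChatOpenAI

# Any server implementing the OpenAI chat-completions API works here; in this
# commit it is a Modal deployment whose URL arrives via MODAL_1_URL
# (for example https://<workspace>--<app>.modal.run/v1).
llm = ChatOpenAI(
    model="microsoft/Phi-4-mini-instruct",  # the model name the server expects
    base_url=os.environ["MODAL_1_URL"],
    api_key=os.environ.get("API_KEY"),      # whatever token the server validates
    temperature=0.0,
)
print(llm.invoke("Reply with one word: ready").content)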
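Because the server speaks the OpenAI wire protocol, the pinned openai client can target it directly as well, which is a convenient way to smoke-test the deployment outside Streamlit (a sketch under the same environment-variable assumptions as above):

import os

from openai import OpenAI

client = OpenAI(
    base_url=os.environ["MODAL_1_URL"],
    api_key=os.environ.get("API_KEY", "unused"),
)
resp = client.chat.completions.create(
    model="microsoft/Phi-4-mini-instruct",
    messages=[{"role": "user", "content": "ping"}],
)
print(resp.choices[0].message.content)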