zamal committed
Commit 82895ea · verified · 1 Parent(s): cd8c42c

Update app.py

Files changed (1)
  1. app.py +57 -48
app.py CHANGED
@@ -51,11 +51,19 @@ vision_model = LlavaNextForConditionalGeneration.from_pretrained(
 ).to("cuda")


+# Add at the top of your module, alongside your other globals
+CURRENT_VDB = None
+
+
 @spaces.GPU()
 def get_image_description(image: Image.Image) -> str:
+    """
+    Lazy-loads the Llava processor + model into the GPU worker,
+    runs captioning, and returns a one-sentence description.
+    """
     global processor, vision_model

-    # on first call, load & move to cuda
+    # First-call: instantiate + move to CUDA
     if processor is None or vision_model is None:
         processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
         vision_model = LlavaNextForConditionalGeneration.from_pretrained(
@@ -64,9 +72,9 @@ def get_image_description(image: Image.Image) -> str:
             low_cpu_mem_usage=True
         ).to("cuda")

+    # clear and run
     torch.cuda.empty_cache()
     gc.collect()
-
     prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
     inputs = processor(prompt, image, return_tensors="pt").to("cuda")
     output = vision_model.generate(**inputs, max_new_tokens=100)
@@ -166,23 +174,22 @@ def extract_data_from_pdfs(
     progress=gr.Progress()
 ):
     """
-    1) Dynamically instantiate the chosen OCR pipeline (if any)
-    2) Dynamically instantiate the chosen vision-language model
-    3) Monkey-patch get_image_description to use that VL model
-    4) Extract text & images, index into ChromaDB
+    1) (Optional) OCR setup
+    2) V+L model setup & monkey-patch get_image_description
+    3) Extract text and images
+    4) Build and store vector DB in global CURRENT_VDB
     """
     if not docs:
         raise gr.Error("No documents to process")

-    # ——— 1) OCR setup (if requested) —————————————————————
+    # 1) OCR instantiation if requested
     if do_ocr == "Get Text With OCR":
         db_m, crnn_m = OCR_CHOICES[ocr_choice]
         local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
     else:
         local_ocr = None

-    # ——— 2) Vision-language model setup ——————————————————
-    # Load processor + model *inside* the GPU worker
+    # 2) Vision-language model instantiation
     proc = LlavaNextProcessor.from_pretrained(vlm_choice)
     vis = (
         LlavaNextForConditionalGeneration
@@ -190,25 +197,24 @@ def extract_data_from_pdfs(
         .to("cuda")
     )

-    # ——— 3) Monkey-patch get_image_description —————————————————
+    # Monkey-patch global captioning fn
     def describe(img: Image.Image) -> str:
-        torch.cuda.empty_cache()
-        gc.collect()
+        torch.cuda.empty_cache(); gc.collect()
         prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
         inputs = proc(prompt, img, return_tensors="pt").to("cuda")
         output = vis.generate(**inputs, max_new_tokens=100)
         return proc.decode(output[0], skip_special_tokens=True)

-    global get_image_description
+    global get_image_description, CURRENT_VDB
     get_image_description = describe

-    # ——— 4) Extract text & images —————————————————————
+    # 3) Extract text & images
     progress(0.2, "Extracting text and images…")
     all_text = ""
     images, names = [], []

     for path in docs:
-        # text extraction
+        # text
         if local_ocr:
             pdf = DocumentFile.from_pdf(path)
             res = local_ocr(pdf)
@@ -217,43 +223,48 @@ def extract_data_from_pdfs(
             txt = PdfReader(path).pages[0].extract_text() or ""
         all_text += txt + "\n\n"

-        # image extraction
+        # images
         if include_images == "Include Images":
             imgs = extract_images([path])
             images.extend(imgs)
             names.extend([os.path.basename(path)] * len(imgs))

-    # ——— 5) Index into ChromaDB —————————————————————
+    # 4) Build and stash the vector DB
     progress(0.6, "Indexing in vector DB…")
-    vdb = get_vectordb(all_text, images, names)
+    CURRENT_VDB = get_vectordb(all_text, images, names)

-    # mark session done & prepare outputs
+    # mark done & return only picklable outputs
     session["processed"] = True
     sample_imgs = images[:4] if include_images == "Include Images" else []

     return (
-        vdb,
-        session,
-        gr.Row(visible=True),
+        session,               # gr.State for "processed"
+        gr.Row(visible=True),  # to un-hide your chat UI
         all_text[:2000] + "...",
         sample_imgs,
         "<h3>Done!</h3>"
     )

+
 # Chat function
 def conversation(
-    vdb, question: str, num_ctx, img_ctx,
-    history: list, temp: float, max_tok: int, model_id: str
+    session: dict,
+    question: str,
+    num_ctx: int,
+    img_ctx: int,
+    history: list,
+    temp: float,
+    max_tok: int,
+    model_id: str
 ):
-    # 0) Cast the context sliders to ints
-    num_ctx = int(num_ctx)
-    img_ctx = int(img_ctx)
-
-    # 1) Guard: must have extracted first
-    if vdb is None:
+    """
+    Pulls CURRENT_VDB from module global, runs text+image retrieval,
+    calls the HF endpoint, and returns updated chat history.
+    """
+    global CURRENT_VDB
+    if not session.get("processed") or CURRENT_VDB is None:
         raise gr.Error("Please extract data first")

-    # 2) Instantiate the chosen HF endpoint
     llm = HuggingFaceEndpoint(
         repo_id=model_id,
         temperature=temp,
@@ -261,23 +272,22 @@ def conversation(
         huggingfacehub_api_token=HF_TOKEN
     )

-    # 3) Query text collection
-    text_col = vdb.get_collection("text_db")
+    # Retrieve top-k text & images
+    text_col = CURRENT_VDB.get_collection("text_db")
     docs = text_col.query(
         query_texts=[question],
-        n_results=num_ctx,   # now an int
+        n_results=int(num_ctx),
         include=["documents"]
     )["documents"][0]

-    # 4) Query image collection
-    img_col = vdb.get_collection("image_db")
+    img_col = CURRENT_VDB.get_collection("image_db")
     img_q = img_col.query(
         query_texts=[question],
-        n_results=img_ctx,   # now an int
+        n_results=int(img_ctx),
         include=["metadatas", "documents"]
     )
-    # rest unchanged
-    images, img_descs = [], img_q["documents"][0] or ["No images found"]
+    img_descs = img_q["documents"][0] or ["No images found"]
+    images = []
     for meta in img_q["metadatas"][0]:
         b64 = meta.get("image", "")
         try:
@@ -286,7 +296,7 @@
             pass
     img_desc = "\n".join(img_descs)

-    # 5) Build prompt
+    # Build and call prompt
     prompt = PromptTemplate(
         template="""
 Context:
@@ -302,10 +312,12 @@ Answer:
 """,
         input_variables=["text", "img_desc", "q"],
     )
-    context = "\n\n".join(docs)
-    user_input = prompt.format(text=context, img_desc=img_desc, q=question)
+    user_input = prompt.format(
+        text="\n\n".join(docs),
+        img_desc=img_desc,
+        q=question
+    )

-    # 6) Call the model with error handling
     try:
         answer = llm.invoke(user_input)
     except HfHubHTTPError as e:
@@ -316,13 +328,10 @@ Answer:
     except Exception as e:
         answer = f"⚠️ Unexpected error: {e}"

-    # 7) Append to history
     new_history = history + [
-        {"role":"user", "content": question},
-        {"role":"assistant","content": answer}
+        {"role": "user", "content": question},
+        {"role": "assistant", "content": answer}
     ]
-
-    # 8) Return updated history, docs, images
     return new_history, docs, images

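The first hunk's pattern is worth calling out: the LLaVA processor and model start as empty module globals and are only materialized on the first call inside the @spaces.GPU worker, so the weights never load in the main Gradio process. Below is a minimal, self-contained sketch of that lazy-load pattern, assuming a ZeroGPU Space with spaces, torch and transformers installed; the function name caption and the float16 dtype are illustrative, since the shown hunks hide the exact from_pretrained kwargs used in app.py.

import gc

import spaces
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# Globals stay empty at import time; the GPU worker fills them on first use.
processor = None
vision_model = None


@spaces.GPU()
def caption(image: Image.Image) -> str:
    """Lazy-load the captioner inside the GPU worker and return a one-sentence description."""
    global processor, vision_model
    if processor is None or vision_model is None:
        processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        vision_model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            torch_dtype=torch.float16,   # assumption: the diff does not show the dtype kwarg
            low_cpu_mem_usage=True,
        ).to("cuda")

    torch.cuda.empty_cache()
    gc.collect()

    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")
    output = vision_model.generate(**inputs, max_new_tokens=100)
    return processor.decode(output[0], skip_special_tokens=True)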
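The larger change is how the vector store reaches the chat handler: extract_data_from_pdfs no longer returns the Chroma handle through gr.State; it stashes it in the module-level CURRENT_VDB and, per the diff's own comment, returns "only picklable outputs", while conversation guards on session["processed"] and re-reads the global. A minimal sketch of that pattern with chromadb, using illustrative names (build_db, answer) rather than the ones in app.py:

import chromadb

CURRENT_VDB = None  # module-level handle, set once per extraction run


def build_db(chunks: list[str]) -> dict:
    """Index text chunks, stash the client in the global, and return only picklable state."""
    global CURRENT_VDB
    CURRENT_VDB = chromadb.Client()
    col = CURRENT_VDB.get_or_create_collection("text_db")
    col.add(documents=chunks, ids=[f"doc_{i}" for i in range(len(chunks))])
    return {"processed": True}


def answer(session: dict, question: str, n_ctx: int = 3) -> list[str]:
    """Guard on the session flag, then query the global client for the top-n_ctx chunks."""
    if not session.get("processed") or CURRENT_VDB is None:
        raise RuntimeError("Please extract data first")
    col = CURRENT_VDB.get_collection("text_db")
    hits = col.query(query_texts=[question], n_results=int(n_ctx), include=["documents"])
    return hits["documents"][0]


# Usage: the dict travels through gr.State; the Chroma client never leaves the module.
state = build_db(["First PDF page text", "Second PDF page text"])
print(answer(state, "What does the first page say?", n_ctx=1))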