gpaasch commited on
Commit
16559c8
·
1 Parent(s): 1f22857

added ability to develop on my local rtx2060

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. utils/llama_index_utils.py +33 -6
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
- gradio==3.53.1
2
  llama-index==0.6.9
3
  openai==0.27.0
 
 
1
+ gradio[full]
2
  llama-index==0.6.9
3
  openai==0.27.0
4
+ transformers
utils/llama_index_utils.py CHANGED
@@ -1,16 +1,43 @@
1
- # llama_index_utils.py
2
- from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex
 
 
3
 
4
  _index = None
5
 
6
- def build_index(data_path="data/icd10cm_tabular_2025"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  global _index
8
  if _index is None:
 
9
  docs = SimpleDirectoryReader(data_path).load_data()
10
- _index = GPTVectorStoreIndex.from_documents(docs)
 
 
11
  return _index
12
 
 
13
  def query_symptoms(prompt: str, top_k: int = 5):
 
 
 
14
  idx = build_index()
15
- qe = idx.as_query_engine(similarity_top_k=top_k)
16
- return qe.query(prompt)
 
 
 
1
+ import os
2
+
3
+ from transformers import pipeline
4
+ from llama_index import SimpleDirectoryReader, GPTVectorStoreIndex, LLMPredictor, OpenAI
5
 
6
  _index = None
7
 
8
+ def get_llm_predictor():
9
+ """
10
+ Return an LLMPredictor configured for local GPU (transformers) if USE_LOCAL_GPU=1,
11
+ otherwise uses OpenAI.
12
+ """
13
+ if os.getenv("USE_LOCAL_GPU") == "1":
14
+ # Local GPU inference using GPT-2 as an example
15
+ local_pipe = pipeline("text-generation", model="gpt2", device=0)
16
+ return LLMPredictor(llm=local_pipe)
17
+ # Default to OpenAI provider
18
+ return LLMPredictor(llm=OpenAI(temperature=0))
19
+
20
+
21
+ def build_index(data_path="data/icd10cm_tabular_2025"): # noqa: C901
22
+ """
23
+ Build (or retrieve cached) GPTVectorStoreIndex from ICD documents.
24
+ """
25
  global _index
26
  if _index is None:
27
+ # Load documents from the ICD data directory
28
  docs = SimpleDirectoryReader(data_path).load_data()
29
+ # Initialize the index with chosen LLM predictor
30
+ predictor = get_llm_predictor()
31
+ _index = GPTVectorStoreIndex.from_documents(docs, llm_predictor=predictor)
32
  return _index
33
 
34
+
35
  def query_symptoms(prompt: str, top_k: int = 5):
36
+ """
37
+ Query the index for the given symptom prompt and return the result.
38
+ """
39
  idx = build_index()
40
+ # Create a query engine with the same predictor
41
+ predictor = get_llm_predictor()
42
+ query_engine = idx.as_query_engine(similarity_top_k=top_k, llm_predictor=predictor)
43
+ return query_engine.query(prompt)