boryasbora commited on
Commit
ec75084
·
verified ·
1 Parent(s): 8b01f29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -17
app.py CHANGED
@@ -14,9 +14,10 @@ from langchain_core.output_parsers import StrOutputParser
14
  from langchain_core.runnables import RunnableLambda
15
  from datetime import date
16
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
17
-
18
  import threading
19
  import time
 
 
20
  # Environment variables
21
  os.environ['LANGCHAIN_TRACING_V2'] = 'true'
22
  os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
@@ -65,7 +66,7 @@ def retrieve_normal_context(retriever, question):
65
  # Your OLMOLLM class implementation here (adapted for the Hugging Face model)
66
 
67
  @st.cache_resource
68
- def get_chain(temperature):
69
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
70
 
71
  docstore_path = 'ohw_proj_chorma_db.pcl'
@@ -77,20 +78,26 @@ def get_chain(temperature):
77
  child_splitter = RecursiveCharacterTextSplitter(chunk_size=300,
78
  chunk_overlap=50)
79
  retriever = load_retriever(docstore_path,chroma_path,embeddings,child_splitter,parent_splitter)
80
- model, tokenizer = load_model()
81
-
82
- pipe = pipeline(
83
- "text-generation",
84
- model=model,
85
- tokenizer=tokenizer,
86
- max_length=1800,
87
- max_new_tokens = 200,
88
- temperature=temperature,
89
- top_p=0.95,
90
- repetition_penalty=1.15
91
- )
92
-
93
- llm = HuggingFacePipeline(pipeline=pipe)
 
 
 
 
 
 
94
 
95
 
96
  today = date.today()
@@ -147,8 +154,20 @@ def generate_response(chain, query, context):
147
  # Sidebar
148
  with st.sidebar:
149
  st.title("OHW Assistant")
 
 
 
 
 
150
  temperature = st.slider("Temperature: ", 0.0, 1.0, 0.5, 0.1)
151
- chain = get_chain(temperature)
 
 
 
 
 
 
 
152
  st.button('Clear Chat History', on_click=clear_chat_history)
153
 
154
  # Main app
 
14
  from langchain_core.runnables import RunnableLambda
15
  from datetime import date
16
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
17
  import threading
18
  import time
19
+ llm_list = ['Mistral-7B-Instruct-v0.2','Mixtral-8x7B-Instruct-v0.1','LLAMA3']
20
+ blablador_base = "https://helmholtz-blablador.fz-juelich.de:8000/v1"
21
  # Environment variables
22
  os.environ['LANGCHAIN_TRACING_V2'] = 'true'
23
  os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
 
66
  # Your OLMOLLM class implementation here (adapted for the Hugging Face model)
67
 
68
  @st.cache_resource
69
+ def get_chain(temperature,selected_model):
70
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
71
 
72
  docstore_path = 'ohw_proj_chorma_db.pcl'
 
78
  child_splitter = RecursiveCharacterTextSplitter(chunk_size=300,
79
  chunk_overlap=50)
80
  retriever = load_retriever(docstore_path,chroma_path,embeddings,child_splitter,parent_splitter)
81
+ llm_api = 'glpat-AMzMevbqaVjp4HbLcVum'
82
+ llm = ChatOpenAI(model_name=selected_model,
83
+ temperature=temperature,
84
+ openai_api_key=llm_api,
85
+ openai_api_base=blablador_base,
86
+ streaming=True)
87
+ # model, tokenizer = load_model()
88
+
89
+ # pipe = pipeline(
90
+ # "text-generation",
91
+ # model=model,
92
+ # tokenizer=tokenizer,
93
+ # max_length=1800,
94
+ # max_new_tokens = 200,
95
+ # temperature=temperature,
96
+ # top_p=0.95,
97
+ # repetition_penalty=1.15
98
+ # )
99
+
100
+ # llm = HuggingFacePipeline(pipeline=pipe)
101
 
102
 
103
  today = date.today()
 
154
  # Sidebar
155
  with st.sidebar:
156
  st.title("OHW Assistant")
157
+ selected_model = st.sidebar.selectbox('Choose a LLM model',
158
+ llm_list,
159
+ key='selected_model',
160
+ index = None)
161
+
162
  temperature = st.slider("Temperature: ", 0.0, 1.0, 0.5, 0.1)
163
+ if selected_model in ['Mistral-7B-Instruct-v0.2', 'Mixtral-8x7B-Instruct-v0.1','LLAMA3']:
164
+ if selected_model == 'Mistral-7B-Instruct-v0.2':
165
+ selected_model = 'alias-fast'
166
+ elif selected_model == 'Mixtral-8x7B-Instruct-v0.1':
167
+ selected_model = 'alias-large'
168
+ elif selected_model == 'LLAMA3':
169
+ selected_model = 'alias-experimental'
170
+ chain = get_chain(temperature,selected_model)
171
  st.button('Clear Chat History', on_click=clear_chat_history)
172
 
173
  # Main app