Update app.py
app.py CHANGED
@@ -14,9 +14,10 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnableLambda
 from datetime import date
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
 import threading
 import time
+llm_list = ['Mistral-7B-Instruct-v0.2','Mixtral-8x7B-Instruct-v0.1','LLAMA3']
+blablador_base = "https://helmholtz-blablador.fz-juelich.de:8000/v1"
 # Environment variables
 os.environ['LANGCHAIN_TRACING_V2'] = 'true'
 os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
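Note on the new module-level constants: blablador_base points at Helmholtz's Blablador service, which speaks the OpenAI wire protocol, so any OpenAI-compatible client can target it. A minimal connectivity sketch, assuming the openai v1 Python client and a placeholder key (neither is part of this commit):

# Hypothetical sanity check: list the models the endpoint serves.
# The `openai` client and the placeholder key are assumptions, not app code.
from openai import OpenAI

client = OpenAI(
    base_url="https://helmholtz-blablador.fz-juelich.de:8000/v1",  # blablador_base
    api_key="YOUR_API_KEY",  # placeholder; do not hard-code real tokens
)

for model in client.models.list():
    print(model.id)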
@@ -65,7 +66,7 @@ def retrieve_normal_context(retriever, question):
 # Your OLMOLLM class implementation here (adapted for the Hugging Face model)
 
 @st.cache_resource
-def get_chain(temperature):
+def get_chain(temperature,selected_model):
     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
 
     docstore_path = 'ohw_proj_chorma_db.pcl'
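Adding selected_model to the cached function's signature matters: st.cache_resource memoizes one object per distinct argument tuple, so changing the model in the UI rebuilds the chain instead of returning a stale one. A toy sketch of that behaviour, with illustrative names not taken from the app:

import streamlit as st

@st.cache_resource
def build_thing(temperature: float, selected_model: str):
    # Runs once per unique (temperature, selected_model) pair; later
    # reruns of the script with the same widget values reuse the object.
    return {"model": selected_model, "temperature": temperature}

model = st.selectbox("Model", ["alias-fast", "alias-large"])
temp = st.slider("Temperature", 0.0, 1.0, 0.5, 0.1)
thing = build_thing(temp, model)  # rebuilt only when a widget changes the key
st.write(thing)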
@@ -77,20 +78,26 @@ def get_chain(temperature):
     child_splitter = RecursiveCharacterTextSplitter(chunk_size=300,
                                                     chunk_overlap=50)
     retriever = load_retriever(docstore_path,chroma_path,embeddings,child_splitter,parent_splitter)
-    model, tokenizer = load_model()
-
-    pipe = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        max_length=1800,
-        max_new_tokens = 200,
-        temperature=temperature,
-        top_p=0.95,
-        repetition_penalty=1.15
-    )
-
-    llm = HuggingFacePipeline(pipeline=pipe)
+    llm_api = 'glpat-AMzMevbqaVjp4HbLcVum'
+    llm = ChatOpenAI(model_name=selected_model,
+                     temperature=temperature,
+                     openai_api_key=llm_api,
+                     openai_api_base=blablador_base,
+                     streaming=True)
+    # model, tokenizer = load_model()
+
+    # pipe = pipeline(
+    #     "text-generation",
+    #     model=model,
+    #     tokenizer=tokenizer,
+    #     max_length=1800,
+    #     max_new_tokens = 200,
+    #     temperature=temperature,
+    #     top_p=0.95,
+    #     repetition_penalty=1.15
+    # )
+
+    # llm = HuggingFacePipeline(pipeline=pipe)
 
 
     today = date.today()
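This hunk swaps the local transformers pipeline for a chat model served remotely, reached through LangChain's OpenAI wrapper with streaming enabled. In isolation the construction looks roughly like the sketch below; the langchain_openai import path and the stream() loop are assumptions about usage, not lines from this commit:

# Sketch: token streaming from an OpenAI-compatible endpoint via LangChain.
from langchain_openai import ChatOpenAI  # the app may import this elsewhere

llm = ChatOpenAI(
    model_name="alias-fast",  # a Blablador serving alias
    temperature=0.5,
    openai_api_key="YOUR_API_KEY",  # placeholder; avoid committing real tokens
    openai_api_base="https://helmholtz-blablador.fz-juelich.de:8000/v1",
    streaming=True,
)

for chunk in llm.stream("Briefly, what is OceanHackWeek?"):
    print(chunk.content, end="", flush=True)

Note that streaming=True only changes the transport; the app still has to iterate a stream (or register callbacks) to surface partial tokens in the UI.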
@@ -147,8 +154,20 @@ def generate_response(chain, query, context):
 # Sidebar
 with st.sidebar:
     st.title("OHW Assistant")
+    selected_model = st.sidebar.selectbox('Choose a LLM model',
+                                          llm_list,
+                                          key='selected_model',
+                                          index = None)
+
     temperature = st.slider("Temperature: ", 0.0, 1.0, 0.5, 0.1)
-    chain = get_chain(temperature)
+    if selected_model in ['Mistral-7B-Instruct-v0.2', 'Mixtral-8x7B-Instruct-v0.1','LLAMA3']:
+        if selected_model == 'Mistral-7B-Instruct-v0.2':
+            selected_model = 'alias-fast'
+        elif selected_model == 'Mixtral-8x7B-Instruct-v0.1':
+            selected_model = 'alias-large'
+        elif selected_model == 'LLAMA3':
+            selected_model = 'alias-experimental'
+        chain = get_chain(temperature,selected_model)
     st.button('Clear Chat History', on_click=clear_chat_history)
 
 # Main app
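The if/elif ladder translates display names into Blablador's serving aliases before the chain is (re)built. A dictionary lookup would express the same mapping more compactly; this is a suggested equivalent, not what the commit ships:

# Equivalent to the if/elif chain above.
MODEL_ALIASES = {
    'Mistral-7B-Instruct-v0.2': 'alias-fast',
    'Mixtral-8x7B-Instruct-v0.1': 'alias-large',
    'LLAMA3': 'alias-experimental',
}

if selected_model in MODEL_ALIASES:
    chain = get_chain(temperature, MODEL_ALIASES[selected_model])

Either way, because the selectbox uses index = None (no default selection), chain is never assigned until a model is chosen, so downstream code should guard against it being undefined.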