import gradio as gr
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import LlamaCppEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma
from dotenv import load_dotenv
import os
from langchain.embeddings import HuggingFaceEmbeddings
load_dotenv()
from constants import CHROMA_SETTINGS
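# CHROMA_SETTINGS carries the Chroma client configuration defined in constants.py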
import openai
#from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from gptcall import generate
#from yy_main import return_qa
# Read the OpenAI API key from the environment (loaded from .env above)
api_key = os.environ.get('OPEN_AI_KEY')
openai.api_key = api_key
'''
def ask_gpt3(question):
    response = openai.Completion.create(
        engine="gpt-3.5-turbo",
        prompt=question,
        max_tokens=50
    )
    return response.choices[0].text.strip()

def generate(prompt):
    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=1000,
            temperature=0.9
        )
        return response['choices'][0]['message']['content']
    except Exception as e:
        return str(e)
'''
hf = os.environ.get("HF_TOKEN")
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = int(os.environ.get('MODEL_N_CTX', 2048))  # default matches the value hard-coded in predict()
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
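# Illustrative .env entries (example values only; adjust to your setup):
#   OPEN_AI_KEY=sk-...
#   HF_TOKEN=hf_...
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#   PERSIST_DIRECTORY=db
#   MODEL_TYPE=GPT4All            (or LlamaCpp)
#   MODEL_PATH=models/ggml-gpt4all-j.bin
#   MODEL_N_CTX=2048
#   TARGET_SOURCE_CHUNKS=4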
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
def clear_history(request: gr.Request):
    state = None
    return ([], state, "")
def post_process_code(code):
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:  # only rewrite when the fences are balanced
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code
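# Example: in "intro\n```\nmy\\_var = 1\n```" the escaped "\\_" inside the
# fenced block is restored to "_", while text outside the fences is untouched.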
def post_process_answer(answer):
    # Render newlines as HTML breaks for the Gradio chatbot ("<br>" is an
    # assumption; the original string literals were garbled in the source).
    answer += "<br><br>"
    answer = answer.replace("\n", "<br>")
    return answer
def predict(
    question: str,
    system_content: str,
    use_api: bool,
    chatbot: list = [],
    history: list = [],
):
    try:
        if use_api:  # Answer directly through the OpenAI API
            history.append(question)
            answer = generate(question)
            history.append(answer)
        else:
            model_n_ctx = 2048  # Override the env value for the local model
            print("state:", system_content, persist_directory, model_type, model_path, model_n_ctx, chatbot, history)
            print("loading embeddings model:", embeddings_model_name)
            embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
            #embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf, model_name="sentence-transformers/all-MiniLM-l6-v2")
            db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
            retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
            # Prepare the LLM
            # callbacks = [StreamingStdOutCallbackHandler()]
            if model_type == "LlamaCpp":
                llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, verbose=False)
            elif model_type == "GPT4All":
                llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', verbose=False)
            else:
                # Raise instead of exit() so the except block below returns a
                # graceful error message to the UI
                raise ValueError(f"Model {model_type} not supported!")
            qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False)
            # Get the answer from the chain
            prompt = system_content + f"\n Question: {question}"
            res = qa(prompt)
            print(res)
            answer = res['result']
            answer = post_process_answer(answer)
            history.append(question)
            history.append(answer)
        # Ensure history has an even number of elements
        if len(history) % 2 != 0:
            history.append("")
        chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
        return chatbot, history
    except Exception as e:
        print(e)
        history.append("")
        answer = server_error_msg + " (error_code: 503)"
        history.append(answer)
        # Ensure history has an even number of elements
        if len(history) % 2 != 0:
            history.append("")
        chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
        return chatbot, history
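# Illustrative Gradio wiring (component names are assumptions, not taken from main()):
#   submit_btn.click(predict,
#                    inputs=[question_box, system_box, use_api_checkbox, chatbot, state],
#                    outputs=[chatbot, state])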
def reset_textbox():
    return gr.update(value="")
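# Quantized GGML model path, presumably for LlamaCppEmbeddings (imported above);
# unused in this snippet.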
llama_embeddings_model = "models/ggml-model-q4_0.bin"
def main():
    title = """