# TxGpt / app.py
import os

import gradio as gr
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma
from dotenv import load_dotenv

load_dotenv()

from constants import CHROMA_SETTINGS
import openai
#from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from gptcall import generate
#from yy_main import return_qa
# Read the OpenAI API key from the environment
api_key = os.environ.get('OPEN_AI_KEY')
openai.api_key = api_key
'''
def ask_gpt3(question):
    response = openai.Completion.create(
        engine="gpt-3.5-turbo",
        prompt=question,
        max_tokens=50
    )
    return response.choices[0].text.strip()


def generate(prompt):
    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=1000,
            temperature=0.9
        )
        return response['choices'][0]['message']['content']
    except Exception as e:
        return str(e)
'''
hf = os.environ.get("HF_TOKEN")
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))
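
# A minimal example .env for the settings read above (values are illustrative
# assumptions, not the project's actual configuration):
#   OPEN_AI_KEY=sk-...
#   HF_TOKEN=hf_...
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#   PERSIST_DIRECTORY=db
#   MODEL_TYPE=GPT4All
#   MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
#   MODEL_N_CTX=2048
#   TARGET_SOURCE_CHUNKS=4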
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

def clear_history(request: gr.Request):
    state = None
    return ([], state, "")

def post_process_code(code):
    # Un-escape underscores inside fenced code blocks: splitting on the fence
    # marker puts code segments at the odd indices of `blocks`.
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
            code = sep.join(blocks)
    return code
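
# For example (hypothetical input): post_process_code("intro\n```\na\\_b = 1\n```")
# leaves the prose segment untouched and rewrites the fenced segment so that
# the escaped "a\\_b" becomes "a_b".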

def post_process_answer(answer):
    # Convert newlines to HTML line breaks for the chatbot display.
    answer += "<br><br>"
    answer = answer.replace("\n", "<br>")
    return answer
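
# For example, post_process_answer("line1\nline2") returns "line1<br>line2<br><br>",
# i.e. newlines become HTML breaks for the chat widget.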

def predict(
    question: str,
    system_content: str,
    use_api: bool,
    chatbot: list = None,
    history: list = None,
):
    # Avoid mutable default arguments; Gradio supplies these on every call.
    chatbot = chatbot if chatbot is not None else []
    history = history if history is not None else []
    try:
        if use_api:  # Route the question to the OpenAI API instead of the local chain
            history.append(question)
            answer = generate(question)
            history.append(answer)
        else:
            model_n_ctx = 2048
            print("state in order:", system_content, persist_directory, model_type, model_path, model_n_ctx, chatbot, history)
            print("loading embeddings model:", embeddings_model_name)
            embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
            # embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf, model_name="sentence-transformers/all-MiniLM-l6-v2")
            db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
            retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
            # Prepare the LLM
            # callbacks = [StreamingStdOutCallbackHandler()]
            if model_type == "LlamaCpp":
                llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, verbose=False)
            elif model_type == "GPT4All":
                llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', verbose=False)
            else:
                # Raising (rather than calling exit()) keeps the server alive and
                # surfaces the error through the except branch below.
                raise ValueError(f"Model {model_type} not supported!")
            qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False)
            # Get the answer from the chain
            prompt = system_content + f"\n Question: {question}"
            res = qa(prompt)
            print(res)
            answer = post_process_answer(res['result'])
            history.append(question)
            history.append(answer)
        # Ensure history has an even number of elements
        if len(history) % 2 != 0:
            history.append("")
        chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
        return chatbot, history
    except Exception as e:
        print(f"predict failed: {e}")
        history.append("")
        history.append(server_error_msg + " (error_code: 503)")
        # Ensure history has an even number of elements
        if len(history) % 2 != 0:
            history.append("")
        chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
        return chatbot, history
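
# Pairing scheme used above (hypothetical values): a flat history
# ["Hi", "Hello!", "Bye", ""] becomes chatbot [("Hi", "Hello!"), ("Bye", "")],
# i.e. alternating (user, bot) tuples, the format gr.Chatbot expects.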

def reset_textbox():
    return gr.update(value="")

llama_embeddings_model = "models/ggml-model-q4_0.bin"

def main():
    title = """
    <h1 align="center">Chat with TxGpt 🤖</h1>"""
    css = """
    @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;700&display=swap');

    /* Hide the footer */
    footer .svelte-1lyswbr {
        display: none !important;
    }

    /* Center the column container */
    #prompt_container {
        margin-left: auto;
        margin-right: auto;
        background: linear-gradient(to right, #48c6ef, #6f86d6); /* Gradient background */
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        color: black;
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-weight: 600; /* Semi-bold font */
        resize: none;
        font-size: 18px;
    }

    /* Chatbot container styling */
    #chatbot_container {
        margin: 0 auto; /* Remove left and right margins */
        max-width: 80%; /* Adjust the maximum width as needed */
        background: linear-gradient(to right, #ff7e5f, #feb47b); /* Gradient background */
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }

    /* Chatbot message area styling */
    #chatbot .wrap.svelte-13f7djk {
        height: 60vh;
        max-height: 60vh;
        border: 2px solid #007bff;
        border-radius: 10px;
        overflow-y: auto;
        padding: 20px;
        background-color: #e9f5ff;
    }

    /* User message styling */
    #chatbot .message.user.svelte-13f7djk.svelte-13f7djk {
        width: fit-content;
        background: #007bff;
        color: white;
        border-bottom-right-radius: 0;
        border-top-left-radius: 10px;
        border-top-right-radius: 10px;
        border-bottom-left-radius: 10px;
        margin-bottom: 10px;
        padding: 10px 15px;
        font-size: 14px;
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-weight: 700; /* Bold font */
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }

    /* Bot message styling */
    #chatbot .message.bot.svelte-13f7djk.svelte-13f7djk {
        width: fit-content;
        background: #e1e1e1;
        color: black;
        border-bottom-left-radius: 0;
        border-top-right-radius: 10px;
        border-top-left-radius: 10px;
        border-bottom-right-radius: 10px;
        margin-bottom: 10px;
        padding: 10px 15px;
        font-size: 14px;
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-weight: 700; /* Bold font */
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }

    /* Preformatted text styling */
    #chatbot .pre {
        border: 2px solid #f1f1f1;
        padding: 10px;
        border-radius: 5px;
        background-color: #ffffff;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.05);
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-size: 14px;
        font-weight: 400; /* Regular font */
    }

    /* General preformatted text styling */
    pre {
        white-space: pre-wrap; /* Since CSS 2.1 */
        white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
        white-space: -pre-wrap; /* Opera 4-6 */
        white-space: -o-pre-wrap; /* Opera 7 */
        word-wrap: break-word; /* Internet Explorer 5.5+ */
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-size: 14px;
        font-weight: 400; /* Regular font */
        line-height: 1.5;
        color: #333;
        background-color: #f8f9fa;
        padding: 10px;
        border-radius: 5px;
    }

    /* Styling for accordion sections */
    .accordion.svelte-1lyswbr {
        background-color: #e9f5ff; /* Light blue background for accordions */
        border: 1px solid #007bff;
        border-radius: 10px;
        padding: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        resize: both;
    }

    /* Prompt title styling */
    #prompt_title {
        font-size: 24px;
        margin-bottom: 10px;
        resize: none;
    }

    /* Styling for the Copy button */
    .copy_button {
        display: inline-block;
        padding: 5px 10px;
        margin: 5px 0;
        font-size: 14px;
        cursor: pointer;
        color: #007bff;
        border: 1px solid #007bff;
        border-radius: 5px;
        background-color: #ffffff;
        transition: background-color 0.3s;
    }

    .copy_button:hover {
        background-color: #007bff;
        color: #ffffff;
    }
    """
    with gr.Blocks(css=css) as demo:
        gr.HTML(title)
        with gr.Row():
            with gr.Column(elem_id="prompt_container", scale=0.3):  # Left column for the prompt
                with gr.Accordion("Description", open=True):
                    system_content = gr.Textbox(
                        value="TxGpt talks to your local documents without internet. If you need information on public data, please enable the ChatGpt checkbox and start querying!",
                        show_label=False,
                        lines=5,
                    )
            with gr.Column(elem_id="chatbot_container", scale=0.7):  # Right column for the chatbot interface
                chatbot = gr.Chatbot(elem_id="chatbot", label="TxGpt")
                question = gr.Textbox(placeholder="Ask something", show_label=False, value="")
                state = gr.State([])
                use_api_toggle = gr.Checkbox(label="Enable ChatGpt", value=False)
                with gr.Row():
                    with gr.Column():
                        submit_btn = gr.Button(value="🚀 Send")
                    with gr.Column():
                        clear_btn = gr.Button(value="🗑️ Clear history")

        # Both triggers route through predict with the same inputs, including
        # the use_api_toggle flag that predict expects.
        question.submit(
            predict,
            [question, system_content, use_api_toggle, chatbot, state],
            [chatbot, state],
        )
        submit_btn.click(
            predict,
            [question, system_content, use_api_toggle, chatbot, state],
            [chatbot, state],
        )
        submit_btn.click(reset_textbox, [], [question])
        clear_btn.click(clear_history, None, [chatbot, state, question])
        question.submit(reset_textbox, [], [question])

    demo.queue(concurrency_count=10, status_update_rate="auto")
    # demo.launch(server_name=args.server_name, server_port=args.server_port, share=args.share, debug=args.debug)
    demo.launch(share=True, server_name='192.168.6.78')  # hardcoded LAN address; '0.0.0.0' would listen on all interfaces

if __name__ == '__main__':
    """ import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--server-name", default="0.0.0.0")
    parser.add_argument("--server-port", default=8071)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args() """
    main()