import os

import gradio as gr
import openai
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma

load_dotenv()

from constants import CHROMA_SETTINGS
from gptcall import generate

api_key = os.environ.get('OPEN_AI_KEY')
openai.api_key = api_key

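# Retired direct-OpenAI helpers, kept commented out for reference; the live
# generate() used below is imported from gptcall.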
'''
def ask_gpt3(question):
    response = openai.Completion.create(
        engine="gpt-3.5-turbo",
        prompt=question,
        max_tokens=50
    )
    return response.choices[0].text.strip()


def generate(prompt):
    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=1000,
            temperature=0.9
        )
        return response['choices'][0]['message']['content']
    except Exception as e:
        return str(e)
'''

hf = os.environ.get("HF_TOKEN")
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')

model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"

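# Illustrative .env for the settings read above. The variable names come from
# this file; the example values are assumptions, not part of the original:
#
#   OPEN_AI_KEY=sk-...
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#   PERSIST_DIRECTORY=db
#   MODEL_TYPE=GPT4All
#   MODEL_PATH=models/ggml-gpt4all-j-v1.3-groovy.bin
#   MODEL_N_CTX=2048
#   TARGET_SOURCE_CHUNKS=4
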
def clear_history(request: gr.Request):
    state = None
    return ([], state, "")


def post_process_code(code):
    # Un-escape "\_" inside fenced code blocks only: after splitting on the
    # fence marker, the odd-indexed segments are the code blocks.
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
            code = sep.join(blocks)
    return code


def post_process_answer(answer):
    answer += "<br><br>"
    answer = answer.replace("\n", "<br>")
    return answer

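# Main inference entry point: routes the question either to the hosted
# ChatGPT helper (gptcall.generate) or to a local RetrievalQA chain over the
# persisted Chroma index, then rebuilds the (user, bot) pair list that
# gr.Chatbot expects.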
def predict(
    question: str,
    system_content: str,
    use_api: bool,
    chatbot: list = [],
    history: list = [],
):
    try:
        if use_api:
            history.append(question)
            answer = generate(question)
            history.append(answer)
        else:
            model_n_ctx = 2048
            print("state in order:", system_content, persist_directory, model_type, model_path, model_n_ctx, chatbot, history)
            print("going inside embedding function:", embeddings_model_name)
            embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)

            db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
            retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})

            if model_type == "LlamaCpp":
                llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, verbose=False)
            elif model_type == "GPT4All":
                llm = GPT4All(model=model_path, n_ctx=model_n_ctx, backend='gptj', verbose=False)
            else:
                print(f"Model {model_type} not supported!")
                exit()

            qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False)

            prompt = system_content + f"\n Question: {question}"
            res = qa(prompt)
            print(res)
            answer = res['result']
            answer = post_process_answer(answer)
            history.append(question)
            history.append(answer)

        # Pad history to an even length, then pair it up as (user, bot) tuples.
        if len(history) % 2 != 0:
            history.append("")

        chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
        return chatbot, history

    except Exception as e:
        history.append("")
        answer = server_error_msg + " (error_code: 503)"
        history.append(answer)

        if len(history) % 2 != 0:
            history.append("")

        chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
        return chatbot, history


def reset_textbox():
    return gr.update(value="")


# Defined but not currently referenced anywhere below; kept for reference.
llama_embeddings_model = "models/ggml-model-q4_0.bin"


def main():
    title = """
    <h1 align="center">Chat with TxGpt 🤖</h1>"""

    css = """
    @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;700&display=swap');

    /* Hide the footer */
    footer .svelte-1lyswbr {
        display: none !important;
    }

    /* Center the column container */
    #prompt_container {
        margin-left: auto;
        margin-right: auto;
        background: linear-gradient(to right, #48c6ef, #6f86d6); /* Gradient background */
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        color: black;
        font-family: 'Poppins', sans-serif;
        font-weight: 600;
        resize: none;
        font-size: 18px;
    }

    /* Chatbot container styling */
    #chatbot_container {
        margin: 0 auto;
        max-width: 80%; /* Adjust the maximum width as needed */
        background: linear-gradient(to right, #ff7e5f, #feb47b); /* Gradient background */
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }

    /* Chatbot message area styling */
    #chatbot .wrap.svelte-13f7djk {
        height: 60vh;
        max-height: 60vh;
        border: 2px solid #007bff;
        border-radius: 10px;
        overflow-y: auto;
        padding: 20px;
        background-color: #e9f5ff;
    }

    /* User message styling */
    #chatbot .message.user.svelte-13f7djk.svelte-13f7djk {
        width: fit-content;
        background: #007bff;
        color: white;
        border-bottom-right-radius: 0;
        border-top-left-radius: 10px;
        border-top-right-radius: 10px;
        border-bottom-left-radius: 10px;
        margin-bottom: 10px;
        padding: 10px 15px;
        font-size: 14px;
        font-family: 'Poppins', sans-serif;
        font-weight: 700;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }

    /* Bot message styling */
    #chatbot .message.bot.svelte-13f7djk.svelte-13f7djk {
        width: fit-content;
        background: #e1e1e1;
        color: black;
        border-bottom-left-radius: 0;
        border-top-right-radius: 10px;
        border-top-left-radius: 10px;
        border-bottom-right-radius: 10px;
        margin-bottom: 10px;
        padding: 10px 15px;
        font-size: 14px;
        font-family: 'Poppins', sans-serif;
        font-weight: 700;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }

    /* Preformatted text styling inside the chatbot */
    #chatbot .pre {
        border: 2px solid #f1f1f1;
        padding: 10px;
        border-radius: 5px;
        background-color: #ffffff;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.05);
        font-family: 'Poppins', sans-serif;
        font-size: 14px;
        font-weight: 400;
    }

    /* General preformatted text styling */
    pre {
        white-space: pre-wrap;      /* Since CSS 2.1 */
        white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
        white-space: -pre-wrap;     /* Opera 4-6 */
        white-space: -o-pre-wrap;   /* Opera 7 */
        word-wrap: break-word;      /* Internet Explorer 5.5+ */
        font-family: 'Poppins', sans-serif;
        font-size: 14px;
        font-weight: 400;
        line-height: 1.5;
        color: #333;
        background-color: #f8f9fa;
        padding: 10px;
        border-radius: 5px;
    }

    /* Styling for accordion sections */
    .accordion.svelte-1lyswbr {
        background-color: #e9f5ff; /* Light blue background for accordions */
        border: 1px solid #007bff;
        border-radius: 10px;
        padding: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        resize: both;
    }

    /* Prompt styling */
    #prompt_title {
        font-size: 24px;
        margin-bottom: 10px;
        resize: none;
    }

    /* Styling for Copy button */
    .copy_button {
        display: inline-block;
        padding: 5px 10px;
        margin: 5px 0;
        font-size: 14px;
        cursor: pointer;
        color: #007bff;
        border: 1px solid #007bff;
        border-radius: 5px;
        background-color: #ffffff;
        transition: background-color 0.3s;
    }

    .copy_button:hover {
        background-color: #007bff;
        color: #ffffff;
    }
    """

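    # NOTE: the .svelte-* selectors in the CSS above are tied to the class
    # names generated by this particular Gradio build and may stop matching
    # on other Gradio versions.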
    with gr.Blocks(css=css) as demo:
        gr.HTML(title)
        with gr.Row():
            with gr.Column(elem_id="prompt_container", scale=0.3):
                with gr.Accordion("Description", open=True):
                    system_content = gr.Textbox(
                        value="TxGpt: talk to your local documents without internet. If you need information on public data, please enable the ChatGpt checkbox and start querying!",
                        show_label=False,
                        lines=5,
                    )

            with gr.Column(elem_id="chatbot_container", scale=0.7):
                chatbot = gr.Chatbot(elem_id="chatbot", label="TxGpt")
                question = gr.Textbox(placeholder="Ask something", show_label=False, value="")
                state = gr.State([])
                use_api_toggle = gr.Checkbox(label="Enable ChatGpt", value=False)
                with gr.Row():
                    with gr.Column():
                        submit_btn = gr.Button(value="🚀 Send")
                    with gr.Column():
                        clear_btn = gr.Button(value="🗑️ Clear history")

        question.submit(
            predict,
            [question, system_content, use_api_toggle, chatbot, state],
            [chatbot, state],
        )
        submit_btn.click(
            predict,
            [question, system_content, use_api_toggle, chatbot, state],
            [chatbot, state],
        )
        submit_btn.click(reset_textbox, [], [question])
        clear_btn.click(clear_history, None, [chatbot, state, question])
        question.submit(reset_textbox, [], [question])
        demo.queue(concurrency_count=10, status_update_rate="auto")

    demo.launch(share=True, server_name='192.168.6.78')

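# Assumed invocation: run this file directly with Python; it expects the .env
# settings above and a Chroma index already persisted at PERSIST_DIRECTORY
# (presumably built by a separate ingestion step).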
if __name__ == '__main__':
    """ import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--server-name", default="0.0.0.0")
    parser.add_argument("--server-port", default=8071)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args() """

    main()