import os

import gradio as gr
# from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings, LlamaCppEmbeddings
from langchain.llms import GPT4All, LlamaCpp
from langchain.vectorstores import Chroma
from dotenv import load_dotenv

load_dotenv()

from constants import CHROMA_SETTINGS
import openai
# from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from gptcall import generate
# from yy_main import return_qa

# Set your OpenAI API key via the environment.
api_key = os.environ.get('OPEN_AI_KEY')
openai.api_key = api_key

# Retired inline OpenAI helpers, kept for reference; generate() is now imported from gptcall.
'''
def ask_gpt3(question):
    response = openai.Completion.create(
        engine="gpt-3.5-turbo",
        prompt=question,
        max_tokens=50
    )
    return response.choices[0].text.strip()


def generate(prompt):
    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=1000,
            temperature=0.9
        )
        return response['choices'][0]['message']['content']
    except Exception as e:
        return str(e)
'''

hf = os.environ.get("HF_TOKEN")
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
model_type = os.environ.get('MODEL_TYPE')
model_path = os.environ.get('MODEL_PATH')
model_n_ctx = os.environ.get('MODEL_N_CTX')
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"


def clear_history(request: gr.Request):
    state = None
    return ([], state, "")


def post_process_code(code):
    # Un-escape underscores inside fenced code blocks.
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code


def post_process_answer(answer):
    # Convert newlines to HTML line breaks for the Gradio chat display.
    answer += "<br><br>"
    answer = answer.replace("\n", "<br>")
    return answer


def predict(
    question: str,
    system_content: str,
    use_api: bool,
    chatbot: list = [],
    history: list = [],
):
    try:
        if use_api:  # Route the question to the OpenAI API instead of the local chain.
            history.append(question)
            answer = generate(question)
            history.append(answer)
        else:
            model_n_ctx = 2048
            print("State:", system_content, persist_directory, model_type, model_path,
                  model_n_ctx, chatbot, history)
            print("Loading embeddings model:", embeddings_model_name)
            embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
            # embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf, model_name="sentence-transformers/all-MiniLM-l6-v2")
            db = Chroma(persist_directory=persist_directory, embedding_function=embeddings,
                        client_settings=CHROMA_SETTINGS)
            retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})

            # Prepare the LLM.
            # callbacks = [StreamingStdOutCallbackHandler()]
            if model_type == "LlamaCpp":
                llm = LlamaCpp(model_path=model_path, n_ctx=model_n_ctx, verbose=False)
            elif model_type == "GPT4All":
                llm = GPT4All(model=model_path, n_ctx=2048, backend='gptj', verbose=False)
            else:
                raise ValueError(f"Model type {model_type} is not supported!")

            qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever,
                                             return_source_documents=False)

            # Get the answer from the chain.
            prompt = system_content + f"\nQuestion: {question}"
            res = qa(prompt)
            print(res)
            answer = res['result']
            answer = post_process_answer(answer)
            history.append(question)
            history.append(answer)

        # Ensure history has an even number of elements before pairing.
        if len(history) % 2 != 0:
            history.append("")

        chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
        return chatbot, history
    except Exception as e:
        print(e)
        history.append("")
        answer = server_error_msg + " (error_code: 503)"
        history.append(answer)

        # Ensure history has an even number of elements before pairing.
        if len(history) % 2 != 0:
            history.append("")

        chatbot = [(history[i], history[i + 1]) for i in range(0, len(history), 2)]
        return chatbot, history


def reset_textbox():
    return gr.update(value="")


llama_embeddings_model = "models/ggml-model-q4_0.bin"


def main():
    title = """<h1 align="center">Chat with TxGpt 🤖</h1>"""

    css = """
    @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;700&display=swap');

    /* Hide the footer */
    footer .svelte-1lyswbr {
        display: none !important;
    }

    /* Center the column container */
    #prompt_container {
        margin-left: auto;
        margin-right: auto;
        background: linear-gradient(to right, #48c6ef, #6f86d6); /* Gradient background */
        padding: 20px; /* Decreased padding */
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        color: black;
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-weight: 600; /* Bold font */
        resize: none;
        font-size: 18px;
    }

    /* Chatbot container styling */
    #chatbot_container {
        margin: 0 auto; /* Remove left and right margins */
        max-width: 80%; /* Adjust the maximum width as needed */
        background: linear-gradient(to right, #ff7e5f, #feb47b); /* Gradient background */
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
    }

    /* Chatbot message area styling */
    #chatbot .wrap.svelte-13f7djk {
        height: 60vh; /* Adjusted height */
        max-height: 60vh; /* Adjusted height */
        border: 2px solid #007bff;
        border-radius: 10px;
        overflow-y: auto;
        padding: 20px;
        background-color: #e9f5ff;
    }

    /* User message styling */
    #chatbot .message.user.svelte-13f7djk.svelte-13f7djk {
        width: fit-content;
        background: #007bff;
        color: white;
        border-bottom-right-radius: 0;
        border-top-left-radius: 10px;
        border-top-right-radius: 10px;
        border-bottom-left-radius: 10px;
        margin-bottom: 10px;
        padding: 10px 15px;
        font-size: 14px;
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-weight: 700; /* Bold font */
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }

    /* Bot message styling */
    #chatbot .message.bot.svelte-13f7djk.svelte-13f7djk {
        width: fit-content;
        background: #e1e1e1;
        color: black;
        border-bottom-left-radius: 0;
        border-top-right-radius: 10px;
        border-top-left-radius: 10px;
        border-bottom-right-radius: 10px;
        margin-bottom: 10px;
        padding: 10px 15px;
        font-size: 14px;
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-weight: 700; /* Bold font */
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }

    /* Preformatted text styling */
    #chatbot .pre {
        border: 2px solid #f1f1f1;
        padding: 10px;
        border-radius: 5px;
        background-color: #ffffff;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.05);
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-size: 14px;
        font-weight: 400; /* Regular font */
    }

    /* General preformatted text styling */
    pre {
        white-space: pre-wrap; /* Since CSS 2.1 */
        white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
        white-space: -pre-wrap; /* Opera 4-6 */
        white-space: -o-pre-wrap; /* Opera 7 */
        word-wrap: break-word; /* Internet Explorer 5.5+ */
        font-family: 'Poppins', sans-serif; /* Poppins font */
        font-size: 14px;
        font-weight: 400; /* Regular font */
        line-height: 1.5;
        color: #333;
        background-color: #f8f9fa;
        padding: 10px;
        border-radius: 5px;
    }

    /* Styling for accordion sections */
    .accordion.svelte-1lyswbr {
        background-color: #e9f5ff; /* Light blue background for accordions */
        border: 1px solid #007bff;
        border-radius: 10px;
        padding: 10px;
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
        resize: both;
    }

    /* Prompt styling */
    #prompt_title {
        font-size: 24px;
        margin-bottom: 10px;
        resize: none;
    }

    /* Styling for Copy button */
    .copy_button {
        display: inline-block;
        padding: 5px 10px;
        margin: 5px 0;
        font-size: 14px;
        cursor: pointer;
        color: #007bff;
        border: 1px solid #007bff;
        border-radius: 5px;
        background-color: #ffffff;
        transition: background-color 0.3s;
    }

    .copy_button:hover {
        background-color: #007bff;
        color: #ffffff;
    }
    """

    with gr.Blocks(css=css) as demo:
        gr.HTML(title)
        with gr.Row():
            with gr.Column(elem_id="prompt_container", scale=0.3):  # Left column for the prompt
                with gr.Accordion("Description", open=True):
                    system_content = gr.Textbox(
                        value="TxGpt talks to your local documents without internet. If you need information on public data, please enable the ChatGpt checkbox and start querying!",
                        show_label=False,
                        lines=5,
                    )
            with gr.Column(elem_id="chatbot_container", scale=0.7):  # Right column for the chatbot interface
                chatbot = gr.Chatbot(elem_id="chatbot", label="TxGpt")
                question = gr.Textbox(placeholder="Ask something", show_label=False, value="")
                state = gr.State([])
                use_api_toggle = gr.Checkbox(label="Enable ChatGpt", value=False)
                with gr.Row():
                    with gr.Column():
                        submit_btn = gr.Button(value="🚀 Send")
                    with gr.Column():
                        clear_btn = gr.Button(value="🗑️ Clear history")

        question.submit(
            predict,
            [question, system_content, use_api_toggle, chatbot, state],
            [chatbot, state],
        )
        submit_btn.click(
            predict,
            [question, system_content, use_api_toggle, chatbot, state],
            [chatbot, state],
        )
        submit_btn.click(reset_textbox, [], [question])
        clear_btn.click(clear_history, None, [chatbot, state, question])
        question.submit(reset_textbox, [], [question])

    demo.queue(concurrency_count=10, status_update_rate="auto")
    # demo.launch(server_name=args.server_name, server_port=args.server_port, share=args.share, debug=args.debug)
    demo.launch(share=True, server_name='192.168.6.78')


if __name__ == '__main__':
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--server-name", default="0.0.0.0")
    parser.add_argument("--server-port", default=8071)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()
    """
    main()