|
import gradio as gr |
|
import re |
|
from sentence_transformers import SentenceTransformer, util |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
import torch |
|
|
|
|
|
|
|
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') |
|
|
|
|
|
qa_model_name = 'google/flan-t5-base' |
|
qa_tokenizer = T5Tokenizer.from_pretrained(qa_model_name) |
|
qa_model = T5ForConditionalGeneration.from_pretrained(qa_model_name) |
|
|
|
|
|
chat_history_embeddings = None |
|
chat_lines = [] |
|
|
|
|
|
def process_chat_file(file): |
|
""" |
|
Reads and parses the uploaded WhatsApp chat file. |
|
""" |
|
global chat_history_embeddings, chat_lines |
|
if file is None: |
|
return "Please upload a file first.", [] |
|
|
|
try: |
|
|
|
with open(file.name, 'r', encoding='utf-8') as f: |
|
content = f.read() |
|
|
|
|
|
|
|
|
|
lines = re.split(r'\n(?=\[\d{1,2}/\d{1,2}/\d{2,4}, \d{1,2}:\d{1,2}:\d{1,2}\])', content) |
|
|
|
|
|
chat_lines = [line.strip() for line in lines if line.strip() and ":" in line] |
|
|
|
if not chat_lines: |
|
return "Could not find any chat messages in the file. Please check the file format.", [] |
|
|
|
|
|
chat_history_embeddings = embedding_model.encode(chat_lines, convert_to_tensor=True) |
|
|
|
return "File processed successfully! You can now ask questions.", [] |
|
except Exception as e: |
|
return f"An error occurred: {e}", [] |
|
|
|
def get_bot_response(user_message, history, temperature): |
|
""" |
|
Generates a response from the chatbot. |
|
""" |
|
global chat_history_embeddings, chat_lines |
|
|
|
if chat_history_embeddings is None: |
|
return "Please upload and process a chat file first." |
|
|
|
|
|
question_embedding = embedding_model.encode(user_message, convert_to_tensor=True) |
|
cos_scores = util.pytorch_cos_sim(question_embedding, chat_history_embeddings)[0] |
|
|
|
|
|
top_k = min(5, len(chat_lines)) |
|
top_results = torch.topk(cos_scores, k=top_k) |
|
|
|
context = "" |
|
for score, idx in zip(top_results[0], top_results[1]): |
|
context += chat_lines[idx] + "\n" |
|
|
|
|
|
prompt = f""" |
|
Answer the following question based on the provided chat history. |
|
If the answer is not in the context, say "I couldn't find an answer to that in the chat history." |
|
|
|
Chat History: |
|
{context} |
|
|
|
Question: {user_message} |
|
|
|
Answer: |
|
""" |
|
|
|
input_ids = qa_tokenizer.encode(prompt, return_tensors='pt') |
|
|
|
|
|
output_ids = qa_model.generate( |
|
input_ids, |
|
max_length=150, |
|
num_beams=4, |
|
temperature=temperature, |
|
early_stopping=True |
|
) |
|
|
|
answer = qa_tokenizer.decode(output_ids[0], skip_special_tokens=True) |
|
|
|
return answer |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(primary_hue="teal", secondary_hue="orange")) as demo: |
|
gr.Markdown("# 💬 Chat with your WhatsApp History") |
|
gr.Markdown("Upload your WhatsApp chat `.txt` file and ask questions about it!") |
|
|
|
|
|
gr.HTML(""" |
|
<div style="text-align: center;"> |
|
<img src="https://media.giphy.com/media/v1.Y2lkPTc5MGI3NjExaDB2d2k5eXNoc2FqZzNqZzZqenp2cDIzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZzZ-/media/k-pop/images/bts-oppas-and-hyungs-and-dongsaengs-and-no.gif" alt="Chatbot GIF" style="width:300px; height:auto; border-radius: 15px;"> |
|
</div> |
|
""") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
file_upload = gr.File(label="Upload WhatsApp Chat (.txt)") |
|
process_button = gr.Button("Process File") |
|
upload_status = gr.Textbox(label="Status", interactive=False) |
|
|
|
temperature_slider = gr.Slider( |
|
minimum=0.1, |
|
maximum=1.0, |
|
value=0.1, |
|
step=0.1, |
|
label="Temperature", |
|
info="Lower values are more accurate, higher values are more creative." |
|
) |
|
|
|
with gr.Column(scale=2): |
|
chatbot = gr.Chatbot(label="Chat") |
|
msg = gr.Textbox(label="Your Question") |
|
clear = gr.ClearButton([msg, chatbot]) |
|
|
|
|
|
file_upload.upload(process_chat_file, inputs=[file_upload], outputs=[upload_status, chatbot]) |
|
process_button.click(process_chat_file, inputs=[file_upload], outputs=[upload_status, chatbot]) |
|
msg.submit(get_bot_response, [msg, chatbot, temperature_slider], [msg, chatbot]) |
|
|
|
if __name__ == "__main__": |
|
demo.launch(debug=True) |
|
|