"""Gradio app for asking questions about an exported WhatsApp chat.

The uploaded ``.txt`` export is split into individual messages, each message
is embedded with a sentence-transformer model, and every question is answered
by a FLAN-T5 model that is prompted with the most similar messages.
"""

import re

import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import T5ForConditionalGeneration, T5Tokenizer

# --- Model Loading ---
# Sentence-transformer model used to embed chat messages and questions.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# T5 model and tokenizer used to generate the answer text.
qa_model_name = 'google/flan-t5-base'
qa_tokenizer = T5Tokenizer.from_pretrained(qa_model_name)
qa_model = T5ForConditionalGeneration.from_pretrained(qa_model_name)

# --- Global Variables ---
# Embeddings of `chat_lines` (same order); None until a file has been processed.
chat_history_embeddings = None
# Parsed chat messages, one string per message.
chat_lines = []

# --- Chat Parsing ---
# Date and time as they appear at the start of an exported message, e.g.
# "12/31/21, 10:15:30", "31.12.2021, 22:15" or "12/31/21, 10:15 PM".
_TIMESTAMP = (
    r"\d{1,2}[/.\-]\d{1,2}[/.\-]\d{2,4},?\s\d{1,2}:\d{1,2}(?::\d{1,2})?"
    r"(?:\s?[AaPp]\.?\s?[Mm]\.?)?"
)
# A message starts at the beginning of a line with either "[timestamp] " or
# "timestamp - ", optionally preceded by a left-to-right/right-to-left mark.
_MESSAGE_HEADER = re.compile(
    r"^[\u200e\u200f]?(?:\[" + _TIMESTAMP + r"\]\s?|" + _TIMESTAMP + r"\s-\s)",
    re.MULTILINE,
)


# --- Helper Functions ---
def _parse_chat_messages(content):
    """Split exported chat text into a list of message strings.

    Each returned string is one message including its timestamp header;
    continuation lines of a multi-line message stay attached to it. Entries
    whose first line has no ``:`` after the header (no "sender: text" part)
    are treated as system notices and dropped. This is a heuristic.

    If no timestamp header is recognised at all, every non-empty line that
    contains a ``:`` is returned instead, as a best-effort fallback.
    """
    headers = list(_MESSAGE_HEADER.finditer(content))
    if not headers:
        return [line.strip() for line in content.splitlines() if ":" in line.strip()]

    # Each message runs from its own header up to the next header (or EOF).
    ends = [header.start() for header in headers[1:]] + [len(content)]
    messages = []
    for header, end in zip(headers, ends):
        # Look for the sender separator only AFTER the header: the timestamp
        # itself contains colons, so testing the whole entry filters nothing.
        first_line = content[header.end():end].split("\n", 1)[0]
        if ":" not in first_line:
            continue
        messages.append(content[header.start():end].strip())
    return messages


def process_chat_file(file):
    """Read, parse and embed the uploaded WhatsApp chat file.

    Args:
        file: The value delivered by the ``gr.File`` component. Either a path
            string or an object exposing the path as ``.name``; ``None`` when
            nothing has been uploaded.

    Returns:
        A ``(status_message, chat_history)`` tuple. The chat history is always
        an empty list, which resets the chatbot display.

    Side effects:
        On success, replaces the module-level ``chat_lines`` and
        ``chat_history_embeddings`` together. If the file contains no
        messages, both are cleared. If an exception occurs, both are left
        untouched.
    """
    global chat_history_embeddings, chat_lines

    if file is None:
        return "Please upload a file first.", []

    path = getattr(file, "name", file)

    try:
        # utf-8-sig also reads plain UTF-8 and drops a leading BOM, which
        # would otherwise stop the first message header from matching.
        with open(path, 'r', encoding='utf-8-sig') as f:
            content = f.read()

        parsed_lines = _parse_chat_messages(content)
        if not parsed_lines:
            chat_lines = []
            chat_history_embeddings = None
            return "Could not find any chat messages in the file. Please check the file format.", []

        embeddings = embedding_model.encode(parsed_lines, convert_to_tensor=True)
    except Exception as e:
        # UI boundary: report the problem in the status box instead of raising.
        return f"An error occurred: {e}", []

    # Publish both globals only after everything succeeded so the messages
    # and their embeddings can never get out of sync.
    chat_lines = parsed_lines
    chat_history_embeddings = embeddings
    return "File processed successfully! You can now ask questions.", []


def get_bot_response(user_message, history, temperature):
    """Answer a question using the most relevant chat messages as context.

    Args:
        user_message: The question typed by the user.
        history: Current chatbot history; accepted for interface
            compatibility but not used when building the prompt.
        temperature: Sampling temperature passed to the T5 model.

    Returns:
        The generated answer, or a hint to upload a file when no chat has
        been processed yet.
    """
    if chat_history_embeddings is None:
        return "Please upload and process a chat file first."

    # 1. Retrieve the chat messages most similar to the question.
    question_embedding = embedding_model.encode(user_message, convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(question_embedding, chat_history_embeddings)[0]

    top_k = min(5, len(chat_lines))
    top_indices = torch.topk(cos_scores, k=top_k).indices.tolist()
    context = "".join(chat_lines[idx] + "\n" for idx in top_indices)

    # 2. Generate an answer with the T5 model.
    prompt = (
        "\n"
        "Answer the following question based on the provided chat history.\n"
        "If the answer is not in the context, say "
        "\"I couldn't find an answer to that in the chat history.\"\n"
        "\n"
        "Chat History:\n"
        f"{context}\n"
        "\n"
        f"Question: {user_message}\n"
        "\n"
        "Answer:\n"
    )

    input_ids = qa_tokenizer.encode(prompt, return_tensors='pt')

    output_ids = qa_model.generate(
        input_ids,
        max_length=150,
        num_beams=4,
        # Temperature only takes effect in sampling mode; without
        # do_sample=True the slider value would be ignored.
        do_sample=True,
        # The slider may deliver a whole number as an int.
        temperature=float(temperature),
        early_stopping=True,
    )

    answer = qa_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return answer


def _respond(user_message, history, temperature):
    """Adapt ``get_bot_response`` to the two outputs of the submit event.

    Returns:
        ``("", updated_history)``: the empty string clears the question box,
        and the history gains one ``[question, answer]`` pair.
        NOTE(review): this assumes the chatbot uses the list-of-pairs history
        format; confirm against the installed Gradio version.
    """
    updated_history = list(history or [])
    if not user_message or not user_message.strip():
        # Nothing to ask: leave the conversation as it is.
        return "", updated_history

    answer = get_bot_response(user_message, updated_history, temperature)
    updated_history.append([user_message, answer])
    return "", updated_history


# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="teal", secondary_hue="orange")) as demo:
    gr.Markdown("# 💬 Chat with your WhatsApp History")
    gr.Markdown("Upload your WhatsApp chat `.txt` file and ask questions about it!")

    # Fun GIF
    # NOTE(review): only plain text is left in this HTML snippet; the image
    # markup appears to be missing and should be restored if it was intended.
    gr.HTML("""
Chatbot GIF
""")

    with gr.Row():
        with gr.Column(scale=1):
            file_upload = gr.File(label="Upload WhatsApp Chat (.txt)")
            process_button = gr.Button("Process File")
            upload_status = gr.Textbox(label="Status", interactive=False)
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.1,
                step=0.1,
                label="Temperature",
                info="Lower values are more accurate, higher values are more creative."
            )

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat")
            msg = gr.Textbox(label="Your Question")
            clear = gr.ClearButton([msg, chatbot])

    # --- Event Handlers ---
    file_upload.upload(process_chat_file, inputs=[file_upload], outputs=[upload_status, chatbot])
    process_button.click(process_chat_file, inputs=[file_upload], outputs=[upload_status, chatbot])

    # The submit event has two outputs (question box, chatbot), so it goes
    # through _respond, which returns one value for each.
    msg.submit(_respond, [msg, chatbot, temperature_slider], [msg, chatbot])

if __name__ == "__main__":
    demo.launch(debug=True)