import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import re

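# --- Models ---
# The retriever embeds chat messages and user questions into the same vector
# space so the most relevant messages can be found by similarity search.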
print("Loading sentence-transformer model for retrieval...") |
|
retriever_model = SentenceTransformer('all-MiniLM-L6-v2') |
|
print("Retriever model loaded.") |
|
|
|
|
|
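# flan-t5-base is a small seq2seq model that writes the final answer from the
# retrieved context; device=-1 keeps it on the CPU.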
print("Loading generative model for answering (google/flan-t5-base)...") |
|
|
|
generator_pipe = pipeline("text2text-generation", model="google/flan-t5-base", device=-1) |
|
print("Generative model loaded.") |
|
|
|
|
|
|
|
|
|
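# --- Vector store ---
# chromadb.Client() is an in-memory (ephemeral) client, so the index is rebuilt
# and re-embedded every time the script starts.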
client = chromadb.Client()

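# create_collection is expected to raise ValueError if the collection already
# exists in this client (e.g. on a re-run in the same process); in that case
# the except branch below re-uses the existing collection.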
try:
    collection = client.create_collection("whatsapp_chat_v2")
    print("ChromaDB collection created.")

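    # --- Load and clean the exported chat ---
    # Assumes my_data.txt is a WhatsApp export whose lines look roughly like
    # "[date, time] Sender: message" or "Sender: message".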
    try:
        print("Loading data from my_data.txt...")
        with open('my_data.txt', 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f if line.strip()]

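        # Optional "[timestamp]" prefix, then the sender up to the first colon,
        # with the message body captured in group(1).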
        message_pattern = re.compile(r'^\[?[^\]]*\]?\s*[^:]*:\s*(.*)')

        cleaned_documents = []
        for line in lines:
            match = message_pattern.match(line)
            if match and match.group(1):
                cleaned_documents.append(match.group(1).strip())

        if not cleaned_documents:
            print("ERROR: Could not extract any valid messages from my_data.txt.")
            cleaned_documents = ["Error: The data file 'my_data.txt' could not be processed."]
        else:
            print(f"Successfully loaded and cleaned {len(cleaned_documents)} messages.")

        documents = cleaned_documents

    except FileNotFoundError:
        print("Error: my_data.txt not found.")
        documents = ["Error: my_data.txt not found. Please make sure the file is uploaded."]

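    # --- Embed and index in batches ---
    # Encoding and adding in chunks of 5000 keeps memory usage bounded for
    # large chat exports.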
    batch_size = 5000
    print("Starting to process and add documents in batches...")
    for i in range(0, len(documents), batch_size):
        end_i = min(i + batch_size, len(documents))
        batch_docs = documents[i:end_i]
        print(f"Processing batch of {len(batch_docs)} documents...")
        batch_embeddings = retriever_model.encode(batch_docs)
        batch_ids = [f"id_{j}" for j in range(i, end_i)]
        collection.add(
            embeddings=batch_embeddings.tolist(),
            documents=batch_docs,
            ids=batch_ids
        )
    print("All documents have been successfully added to ChromaDB.")

except ValueError:
    collection = client.get_collection("whatsapp_chat_v2")
    print("ChromaDB collection loaded.")

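# --- Retrieval-augmented answering ---
# For each user message: embed the question, pull the 5 most similar chat
# messages from ChromaDB, and let flan-t5 answer using them as context.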
def chatbot_response(message, history):
    query_embedding = retriever_model.encode([message]).tolist()
    results = collection.query(
        query_embeddings=query_embedding,
        n_results=5
    )
    retrieved_documents = results['documents'][0]

    if not retrieved_documents or "Error:" in retrieved_documents[0]:
        return "I'm sorry, I couldn't find any relevant information in the chat history. 🤔"

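    # Present the retrieved messages as bullet points above the question.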
context = "\n- ".join(retrieved_documents) |
|
prompt = f""" |
|
Based on the following excerpts from a WhatsApp chat, provide a helpful and accurate answer to the user's question. |
|
|
|
Chat Context: |
|
- {context} |
|
|
|
Question: |
|
{message} |
|
|
|
Answer: |
|
""" |
|
|
|
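    # Beam search with early stopping gives a more coherent answer; max_length
    # caps the generated text at 150 tokens.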
    generated_text = generator_pipe(prompt, max_length=150, num_beams=5, early_stopping=True)
    response = generated_text[0]['generated_text']

    return response

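# --- Gradio UI ---
# gr.ChatInterface wires chatbot_response(message, history) into a ready-made
# chat page; launch() starts a local web server.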
iface = gr.ChatInterface(
    fn=chatbot_response,
    title="WhatsApp Chat Bot ⚡️",
    description="Ask me anything about this WhatsApp chat history. (Powered by flan-t5-base)",
    theme="soft",
    examples=["What was the final decision on the project deadline?", "Summarize the conversation about the event."],
    cache_examples=False
)

iface.launch()