# NOTE: "Spaces: Sleeping" is page-status text captured alongside this file; it is not part of the program.
import gradio as gr | |
import logging | |
import os | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
import faiss | |
from simple_salesforce import Salesforce | |
from dotenv import load_dotenv | |
import zipfile | |
from pathlib import Path | |
# Configure logging once and obtain a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Merge variables from a local .env file into the process environment.
load_dotenv()

# Salesforce credentials, all read from environment variables.
sf_username = os.getenv("SF_USERNAME")
sf_password = os.getenv("SF_PASSWORD")
sf_security_token = os.getenv("SF_SECURITY_TOKEN")
sf_instance_url = os.getenv("SF_INSTANCE_URL")

# Fail fast when any of the four values is unset or empty.
if not all((sf_username, sf_password, sf_security_token, sf_instance_url)):
    logger.error("β Salesforce credentials are missing from environment variables!")
    raise ValueError("Salesforce credentials are not properly set.")

# Open the Salesforce session used by create_salesforce_record(); a failure
# is logged and re-raised so the module does not finish loading.
try:
    sf = Salesforce(
        username=sf_username,
        password=sf_password,
        security_token=sf_security_token,
        instance_url=sf_instance_url,
    )
except Exception as exc:
    logger.error(f"β Salesforce connection failed: {exc}")
    raise
else:
    logger.info("β Connected to Salesforce")
# --- Extract zip files and read documents --- | |
def extract_zip(zip_path, extract_to):
    """Unpack the archive at ``zip_path`` into the directory ``extract_to``.

    Any exception raised while opening or extracting the archive is logged
    and then re-raised unchanged.
    """
    try:
        with zipfile.ZipFile(zip_path, mode='r') as archive:
            archive.extractall(extract_to)
        logger.info(f"Extracted {zip_path} to {extract_to}")
    except Exception as exc:
        logger.error(f"Failed to extract {zip_path}: {exc}")
        raise
def load_documents(folder_path):
    """Read every ``*.txt`` file found recursively under ``folder_path``.

    Paths are visited in sorted order so the returned lists (and anything
    built from them, such as a vector index) are reproducible regardless of
    the order in which the filesystem happens to list entries. Entries that
    match ``*.txt`` but are not regular files (e.g. a directory with that
    name) are skipped instead of raising.

    Args:
        folder_path: Directory to search (``str`` or ``Path``).

    Returns:
        A ``(documents, sources)`` tuple of parallel lists: the decoded text
        of each file (UTF-8, undecodable bytes dropped) and its bare file
        name (without any directory part).
    """
    documents = []
    sources = []
    for file in sorted(Path(folder_path).rglob("*.txt")):
        if not file.is_file():
            continue
        documents.append(file.read_text(encoding="utf-8", errors="ignore"))
        sources.append(file.name)
    return documents, sources
# --- Chunking ---
# Splitter configured with chunk_size=300 and chunk_overlap=50.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)

# --- Load model ---
# Sentence-embedding model shared by indexing and query time.
model = SentenceTransformer("all-MiniLM-L6-v2")

# --- Preprocessing ---
# All archives are unpacked beneath ./data, one sub-folder per archive.
data_dir = Path("./data")
data_dir.mkdir(exist_ok=True)

# (zip archive in the working directory, folder name used under ./data and in URLs)
doc_folders = [
    ("Company_Policies.zip", "Company_Policies"),
    ("HR_Policies.zip", "Hr_Policies"),
    ("Contract_Clauses.zip", "Contract_Clauses"),
]

# Parallel lists: all_chunks[i] is a text chunk, metadata[i] is its source URL.
all_chunks = []
metadata = []

for zip_name, folder in doc_folders:
    zip_path = Path(zip_name)
    if not zip_path.exists():
        logger.error(f"Zip file {zip_name} not found")
        raise FileNotFoundError(f"Zip file {zip_name} not found")

    extract_path = data_dir / folder
    extract_path.mkdir(exist_ok=True)
    extract_zip(zip_path, extract_path)

    docs, sources = load_documents(extract_path)
    if not docs:
        logger.error(f"No documents found in {extract_path}")
        raise ValueError(f"No documents found in {extract_path}")

    for doc, src in zip(docs, sources):
        src_url = f"https://company.com/{folder}/{src}"
        # Record the same source URL once per chunk so both lists stay aligned.
        for chunk in text_splitter.split_text(doc):
            all_chunks.append(chunk)
            metadata.append(src_url)

# --- Embeddings + FAISS index ---
# Embed every chunk and store the vectors in a flat L2 index whose
# dimensionality is taken from the embedding matrix itself.
embeddings = model.encode(all_chunks)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(np.array(embeddings))
logger.info("FAISS index built successfully")
# --- Create Record in Salesforce --- | |
def create_salesforce_record(query, answer, confidence_percentage, source_link):
    """Store one question/answer exchange as a ``chat_query_log__c`` record.

    Args:
        query: The user's question (written to ``Query__c``).
        answer: The answer text (written to ``Answer__c``).
        confidence_percentage: Confidence score; converted with ``float()``
            before being written to ``Confidence_Percentage__c``.
        source_link: Document link (written to ``Document_link__c``).

    Returns:
        The ``id`` from the create response, or ``None`` when the response
        has no ``id`` or any exception occurs. Failures are logged, never
        raised.
    """
    try:
        payload = {
            "Query__c": query,
            "Answer__c": answer,
            # Built-in float rather than e.g. a numpy float32.
            "Confidence_Percentage__c": float(confidence_percentage),
            "Document_link__c": source_link,
        }
        response = sf.chat_query_log__c.create(payload)

        # A response without an 'id' key is treated as a failed insert.
        if 'id' not in response:
            logger.error(f"β Failed to create Salesforce record. Response: {response}")
            return None

        record_id = response['id']
        logger.info(f"β Record created successfully in Salesforce with ID: {record_id}")
        return record_id
    except Exception as exc:
        logger.error(f"Error creating Salesforce record: {exc}")
        return None
# --- Search & Answer --- | |
def answer_query(query):
    """Answer ``query`` from the indexed chunks and log the result to Salesforce.

    The query is embedded, the three nearest chunks are fetched from the
    FAISS index, and hits with a distance of 0.8 or more are discarded. The
    closest remaining chunk becomes the answer.

    Args:
        query: The user's question.

    Returns:
        Always a 4-tuple of strings
        ``(answer, confidence, source, record_status)`` so callers can unpack
        it unconditionally:

        * no sufficiently close chunk: a "not found" message, ``0%``
          confidence, ``None`` source and an empty record status (nothing is
          written to Salesforce);
        * success: the answer plus either the Salesforce record ID or a
          failure notice;
        * any exception: ``("Error: ...", "", "", "")`` after logging.
    """
    try:
        logger.info(f"Processing query: {query}")
        query_embedding = model.encode([query])
        distances, indices = index.search(np.array(query_embedding), k=3)

        # FAISS pads the result with index -1 when fewer than k vectors are
        # available; a negative index would silently pick the *last* chunk,
        # so those slots are dropped together with hits that are too far away.
        # NOTE(review): per the FAISS docs IndexFlatL2 reports squared L2
        # distances; the 0.8 cut-off and the confidence formula below operate
        # on that raw value - confirm this is intended.
        hits = [
            (all_chunks[i], metadata[i], dist)
            for i, dist in zip(indices[0], distances[0])
            if i >= 0 and dist < 0.8
        ]

        if not hits:
            # Four values, matching every other return path of this function.
            return (
                "No relevant information found.",
                "Confidence: 0%",
                "Source Link: None",
                "",
            )

        best_chunk, source_link, _ = hits[0]
        answer = best_chunk.strip()
        min_distance = float(min(dist for _, _, dist in hits))
        confidence_percentage = max(0, 100 - (min_distance * 100))

        # Create Salesforce record for the query response.
        record_id = create_salesforce_record(
            query, answer, confidence_percentage, source_link
        )
        if record_id:
            record_status = f"Salesforce Record ID: {record_id}"
        else:
            record_status = "Failed to create record in Salesforce"

        return (
            answer,
            f"Confidence: {confidence_percentage:.2f}%",
            f"Source Link: {source_link}",
            record_status,
        )
    except Exception as e:
        logger.error(f"Error in answer_query: {str(e)}")
        return f"Error: {str(e)}", "", "", ""
# --- Gradio Chatbot UI Design --- | |
def process_question(q, chat_history):
    """Gradio click handler: answer ``q`` and extend the chat history.

    Args:
        q: Text from the question box; ``None`` is treated like an empty
            string.
        chat_history: Current chatbot value (list of 2-tuples); ``None`` is
            treated as an empty history.

    Returns:
        A 4-tuple ``(chat_history, confidence, source, record_status)``, one
        value per output component wired to the submit button. For a blank
        question the history gains a prompt to enter a question and the three
        text outputs are empty strings.
    """
    if chat_history is None:
        chat_history = []

    if not q or not q.strip():
        # Four return values are required: the click event has four outputs.
        return chat_history + [("User", "Please enter a question.")], "", "", ""

    answer, confidence, source, record_id = answer_query(q)
    # NOTE(review): gr.Chatbot's tuple format is documented as
    # (user message, bot message) pairs, so these rows render the label on
    # the user side and the text on the bot side - confirm this is intended.
    chat_history.append(("User", q))
    chat_history.append(("Bot", answer))
    return chat_history, confidence, source, record_id
# --- Chatbot UI with dynamic styling using elem_id --- | |
with gr.Blocks(theme=gr.themes.Soft(), title="Company Documents Q&A Chatbot") as demo:
    # Page heading.
    gr.Markdown("## π **Company Policies Q&A Chatbot**")

    # Input row: a wide question box beside a narrow submit button.
    with gr.Row():
        with gr.Column(scale=3):
            question = gr.Textbox(
                label="Ask a Question",
                placeholder="What are the conditions for permanent employment status?",
                lines=1,
                interactive=True,
                visible=True,
                elem_id="user-question",
            )
        with gr.Column(scale=1):
            submit_btn = gr.Button("Submit", variant="primary", elem_id="submit-btn")

    # Output row: conversation plus the three status read-outs.
    with gr.Row():
        with gr.Column():
            chat_history = gr.Chatbot(
                label="Chat History",
                show_label=False,  # keep the chat area uncluttered
                height=400,  # fixed height
                elem_id="chatbox",
            )
            conf_out = gr.Markdown(label="Confidence", elem_id="confidence")
            source_out = gr.Markdown(label="Source Link", elem_id="source-link")
            record_out = gr.Markdown(label="Salesforce Record ID", elem_id="salesforce-id")

    # process_question returns one value per component listed in outputs.
    submit_btn.click(
        fn=process_question,
        inputs=[question, chat_history],
        outputs=[chat_history, conf_out, source_out, record_out],
    )

# Listen on all interfaces, port 7860, with a public share link requested.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)