# Combined application script for a Hugging Face Space: a Gemma-2-9B-IT chat app
# with RAG over a Google Sheet, plus DuckDuckGo search and date-calculation tools.

# Combined imports
import spaces
import os
import gradio as gr
from huggingface_hub import InferenceClient
import torch
import re
import warnings
import time
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer, util, CrossEncoder
import gspread
# from google.colab import auth  # Colab-only; not needed on Spaces
from google.auth import default
from tqdm import tqdm
from duckduckgo_search import DDGS
import spacy
from datetime import date, timedelta
from dateutil.relativedelta import relativedelta
import traceback
import base64
@spaces.GPU
def startup():
    # The @spaces.GPU decorator is assumed here: it is what actually registers a
    # GPU function on ZeroGPU Spaces, and the `spaces` import is otherwise unused.
    print("GPU function registered for Hugging Face Spaces startup.")
    return "Ready"

startup()
# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Define global variables and load secrets
HF_TOKEN = os.getenv("HF_TOKEN")
SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw"
GOOGLE_BASE64_CREDENTIALS = os.getenv("GOOGLE_BASE64_CREDENTIALS")
# Initialize the InferenceClient
client = InferenceClient("google/gemma-2-9b-it", token=HF_TOKEN)
# Load the spaCy model for sentence splitting
nlp = None
try:
    nlp = spacy.load("en_core_web_sm")
    print("SpaCy model 'en_core_web_sm' loaded.")
except OSError:
    print("SpaCy model 'en_core_web_sm' not found. Downloading...")
    try:
        os.system("python -m spacy download en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
        print("SpaCy model 'en_core_web_sm' downloaded and loaded.")
    except Exception as e:
        print(f"Failed to download or load SpaCy model: {e}")
# Load the SentenceTransformer used for RAG / business-info retrieval
embedder = None
try:
    print("Attempting to load Sentence Transformer (sentence-transformers/paraphrase-MiniLM-L6-v2)...")
    embedder = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
    print("Sentence Transformer loaded.")
except Exception as e:
    print(f"Error loading Sentence Transformer: {e}")
# Load a Cross-Encoder model for re-ranking retrieved documents
reranker = None
try:
    print("Attempting to load Cross-Encoder Reranker (cross-encoder/ms-marco-MiniLM-L6-v2)...")
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')
    print("Cross-Encoder Reranker loaded.")
except Exception as e:
    print(f"Error loading Cross-Encoder Reranker: {e}")
    print("Please ensure the model identifier 'cross-encoder/ms-marco-MiniLM-L6-v2' is correct and accessible on Hugging Face Hub.")
    print(traceback.format_exc())
    reranker = None
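
# Illustrative reranker call (CrossEncoder.predict scores (query, passage)
# pairs; higher scores mean more relevant; the passage text is hypothetical):
#
#     scores = reranker.predict([("opening hours", "Service: Support. Description: Open 9am-5pm.")])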
# Google Sheets Authentication
gc = None  # Global gspread client


def authenticate_google_sheets():
    """Authenticates with Google Sheets using base64-encoded service account credentials."""
    global gc
    print("Authenticating Google Account...")
    if not GOOGLE_BASE64_CREDENTIALS:
        print("Error: GOOGLE_BASE64_CREDENTIALS secret not found.")
        return False
    try:
        # Decode the base64 secret into a service-account credentials dict
        credentials_json = base64.b64decode(GOOGLE_BASE64_CREDENTIALS).decode('utf-8')
        credentials = json.loads(credentials_json)
        gc = gspread.service_account_from_dict(credentials)
        print("Google Sheets authentication successful via service account.")
        return True
    except Exception as e:
        print(f"Google Sheets authentication failed: {e}")
        print("Please ensure your GOOGLE_BASE64_CREDENTIALS secret is correctly set and contains valid service account credentials.")
        print(traceback.format_exc())
        return False
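
# A minimal sketch (run locally, not in the Space) of how to produce the
# GOOGLE_BASE64_CREDENTIALS secret from a service-account key file; the
# filename "service_account.json" is an assumption for illustration:
#
#     import base64
#     with open("service_account.json", "rb") as f:
#         print(base64.b64encode(f.read()).decode("utf-8"))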
# Google Sheets Data Loading and Embedding
data = []  # Global list of rows loaded from the Google Sheet
descriptions_for_embedding = []
embeddings = torch.tensor([])
business_info_available = False  # True once sheet data is loaded and embedded
def load_business_info():
    """Loads business information from the Google Sheet and creates embeddings."""
    global data, descriptions_for_embedding, embeddings, business_info_available
    business_info_available = False
    if gc is None:
        print("Skipping Google Sheet loading: Google Sheets client not authenticated.")
        return
    if not SHEET_ID:
        print("Error: SHEET_ID not set.")
        return
    try:
        sheet = gc.open_by_key(SHEET_ID).sheet1
        print(f"Successfully opened Google Sheet with ID: {SHEET_ID}")
        data_records = sheet.get_all_records()
        if not data_records:
            print(f"Warning: No data records found in Google Sheet with ID: {SHEET_ID}")
            data = []
            descriptions_for_embedding = []
        else:
            # Keep only rows that have both 'Service' and 'Description'
            filtered_data = [row for row in data_records if row.get('Service') and row.get('Description')]
            if not filtered_data:
                print("Warning: Filtered data is empty after checking for 'Service' and 'Description'.")
                data = []
                descriptions_for_embedding = []
            else:
                data = filtered_data
                # Embed BOTH Service and Description for better retrieval
                descriptions_for_embedding = [f"Service: {row['Service']}. Description: {row['Description']}" for row in data]
        # Encode only if descriptions were found and the embedder is available
        if descriptions_for_embedding and embedder is not None:
            print("Encoding descriptions...")
            try:
                embeddings = embedder.encode(descriptions_for_embedding, convert_to_tensor=True)
                print("Encoding complete.")
                business_info_available = True
            except Exception as e:
                print(f"Error during description encoding: {e}")
                embeddings = torch.tensor([])  # Reset to an empty tensor on error
                business_info_available = False
        else:
            print("Skipping encoding descriptions: No descriptions found or embedder not available.")
            embeddings = torch.tensor([])
            business_info_available = False  # RAG needs both descriptions and an embedder
        print(f"Loaded {len(descriptions_for_embedding)} entries from Google Sheet for embedding/RAG.")
        if not business_info_available:
            print("Business information retrieval (RAG) is NOT available.")
    except gspread.exceptions.SpreadsheetNotFound:
        print(f"Error: Google Sheet with ID '{SHEET_ID}' not found.")
        print("Please check the SHEET_ID and ensure your authenticated Google Account has access to this sheet.")
        business_info_available = False
    except Exception as e:
        print(f"An error occurred while accessing the Google Sheet: {e}")
        print(traceback.format_exc())
        business_info_available = False
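
# The sheet is assumed to have at least 'Service' and 'Description' columns;
# an illustrative row (values are hypothetical):
#
#     Service: Web Design | Description: Custom responsive websites for small businesses.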
# Business Info Retrieval (RAG)
def retrieve_business_info(query: str, top_n: int = 3) -> list:
    """
    Retrieves relevant business information from the loaded data based on a query.

    Args:
        query: The user's query string.
        top_n: The number of top relevant entries to retrieve.

    Returns:
        A list of dictionaries, where each dictionary is a relevant row from the
        Google Sheet data. Returns an empty list if RAG is not available or
        no relevant information is found.
    """
    global data
    if not business_info_available or embedder is None or not descriptions_for_embedding or not data:
        print("Business information retrieval is not available or data is empty.")
        return []
    try:
        # Embed the query
        query_embedding = embedder.encode(query, convert_to_tensor=True)
        # Cosine similarity between the query and all description embeddings
        cosine_scores = util.cos_sim(query_embedding, embeddings)[0]
        # Indices of the top N most similar descriptions
        top_results_indices = torch.topk(cosine_scores, k=min(top_n, len(data)))[1].tolist()
        top_results = [data[i] for i in top_results_indices]
        # Optionally re-rank the top results with the Cross-Encoder
        if reranker is not None and top_results:
            print("Re-ranking top results...")
            # The Cross-Encoder scores (query, description) pairs
            rerank_pairs = [(query, descriptions_for_embedding[i]) for i in top_results_indices]
            rerank_scores = reranker.predict(rerank_pairs)
            # Sort the retrieved rows by re-ranker score, highest first
            reranked_indices = sorted(range(len(rerank_scores)), key=lambda i: rerank_scores[i], reverse=True)
            reranked_results = [top_results[i] for i in reranked_indices]
            print("Re-ranking complete.")
            return reranked_results
        else:
            return top_results
    except Exception as e:
        print(f"Error during business information retrieval: {e}")
        print(traceback.format_exc())
        return []
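
# Illustrative usage of the retriever (hypothetical query; each returned item
# is a row dict from the sheet with its original columns):
#
#     for row in retrieve_business_info("Do you offer web design?", top_n=3):
#         print(row.get("Service"), "->", row.get("Description"))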
# Function to perform a DuckDuckGo search and return results with URLs
def perform_duckduckgo_search(query: str, max_results: int = 5):
    """
    Performs a search using DuckDuckGo and returns a list of result dictionaries.
    Includes a delay to avoid rate limits.
    Returns an empty list and prints an error if the search fails.
    """
    print(f"Executing Tool: perform_duckduckgo_search with query='{query}'")
    search_results_list = []
    try:
        # Short delay before each search to stay under rate limits
        time.sleep(1)
        if not query or len(query.split()) < 2:
            print(f"Skipping search for short query: '{query}'")
            return []
        with DDGS() as ddgs:
            # text() performs a general web text search
            results_generator = ddgs.text(query, max_results=max_results)
            results_found = False
            for r in results_generator:
                search_results_list.append(r)
                results_found = True
            if not results_found and max_results > 0:
                print(f"DuckDuckGo search for '{query}' returned no results.")
    except Exception as e:
        print(f"Error during DuckDuckGo search for '{query}': {e}")
        return []
    return search_results_list
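
# Illustrative usage (each result is a dict yielded by DDGS.text(), which
# typically carries 'title', 'href', and 'body' keys):
#
#     for r in perform_duckduckgo_search("gradio chat interface", max_results=3):
#         print(r.get("title"), r.get("href"))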
# Function to perform a date calculation if the query asks for one
def perform_date_calculation(query: str):
    """
    Analyzes the query for date calculation requests and performs the calculation.
    Returns a dict describing the calculation and result.
    Handles formats like 'X days ago', 'X weeks from now', and "what is today's date".
    Uses dateutil for month/year arithmetic.
    """
    print(f"Executing Tool: perform_date_calculation with query='{query}'")
    query_lower = query.lower()
    today = date.today()
    result_date = None
    calculation_description = None
    if re.search(r"\btoday'?s date\b|what is today'?s date\b|what day is it\b", query_lower):
        result_date = today
        calculation_description = f"The current date is: {today.strftime('%Y-%m-%d')}"
        print("Identified query for today's date.")
        return {"query": query, "description": calculation_description, "result": result_date.strftime('%Y-%m-%d'), "success": True}
    match = re.search(r"(\d+)\s+(day|week|month|year)s?\s+(ago|from now)", query_lower)
    if match:
        value = int(match.group(1))
        unit = match.group(2)
        direction = match.group(3)
        try:
            if unit == 'day':
                delta = timedelta(days=value)
            elif unit == 'week':
                delta = timedelta(weeks=value)
            elif unit == 'month':
                delta = relativedelta(months=value)
            elif unit == 'year':
                delta = relativedelta(years=value)
            else:
                desc = f"Could not understand the time unit '{unit}' in '{query}'."
                print(desc)
                return {"query": query, "description": desc, "result": None, "success": False, "error": desc}
            if direction == 'ago':
                result_date = today - delta
                calculation_description = f"Calculating date {value} {unit}s ago from {today.strftime('%Y-%m-%d')}: {result_date.strftime('%Y-%m-%d')}"
            elif direction == 'from now':
                result_date = today + delta
                calculation_description = f"Calculating date {value} {unit}s from now from {today.strftime('%Y-%m-%d')}: {result_date.strftime('%Y-%m-%d')}"
            print(f"Performed date calculation: {calculation_description}")
            return {"query": query, "description": calculation_description, "result": result_date.strftime('%Y-%m-%d'), "success": True}
        except OverflowError:
            desc = f"Date calculation overflow for query: {query}"
            print(desc)
            return {"query": query, "description": desc, "result": None, "success": False, "error": desc}
        except Exception as e:
            desc = f"An error occurred during date calculation for query '{query}': {e}"
            print(desc)
            return {"query": query, "description": desc, "result": None, "success": False, "error": str(e)}
    desc = "No specific date calculation pattern recognized."
    print(f"No specific date calculation pattern found in query: '{query}'")
    return {"query": query, "description": desc, "result": None, "success": False}
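
# Illustrative calls (results depend on the current date):
#
#     perform_date_calculation("what is today's date")
#     # -> {"query": ..., "result": "<today as YYYY-MM-DD>", "success": True}
#     perform_date_calculation("3 weeks ago")
#     # -> {"query": ..., "result": "<date 21 days before today>", "success": True}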
# ──────────────────────────
# 2. Chat handler
# ──────────────────────────
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    # Retrieve relevant business information based on the user's message
    retrieved_info = retrieve_business_info(message)
    # Build the ChatML conversation
    messages = [{"role": "system", "content": system_message}]
    # Include retrieved information as context if available
    if retrieved_info:
        context_message = "Use the following business information to help answer the user's question if relevant:\n"
        for i, info in enumerate(retrieved_info):
            # Clear delimiter between entries
            context_message += f"--- Business Info Entry {i+1} ---\n"
            # Include all key-value pairs from the row
            for key, value in info.items():
                context_message += f"{key}: {str(value)}\n"
            context_message += "---\n"
        # Inject the context as a user message right after the system message, so the
        # model sees it as explicit information provided for the current turn
        messages.append({"role": "user", "content": context_message})
        print("Added retrieved business info to messages.")
    # Add the conversation history
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    # Add the current user message
    messages.append({"role": "user", "content": message})
    # Stream tokens
    response = ""
    try:
        for chunk in client.chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content or ""
            response += token
            yield response
    except Exception as e:
        print(f"Error during chat completion: {e}")
        print(traceback.format_exc())
        yield f"An error occurred: {e}"
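
# Illustrative manual invocation (outside Gradio; respond is a generator that
# yields the accumulated response after each streamed token):
#
#     last = ""
#     for partial in respond("What services do you offer?", [], "You are a friendly Chatbot.", 256, 0.7, 0.95):
#         last = partial
#     print(last)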
# ──────────────────────────
# 3. Gradio interface
# ──────────────────────────
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot. Use the provided business information to answer questions when relevant.", label="System message"),
        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="Gemma-2-9B-IT Chat with RAG",
    description="Chat with Google Gemma-2-9B-IT via the Hugging Face Inference API, with business info retrieved from Google Sheets.",
)

# Enable request queueing (concurrency is handled automatically on Gradio >= 4)
demo.queue()
if __name__ == "__main__":
    # Authenticate and load data before launching the demo
    if authenticate_google_sheets():
        load_business_info()
    else:
        print("Google Sheets authentication failed. RAG functionality will not be available.")
    # Report RAG status once data loading has been attempted, before launch
    print(f"RAG functionality available: {business_info_available}")
    demo.launch()