Spaces:
Sleeping
Sleeping
import gradio as gr | |
import time | |
import os | |
import json | |
import torch | |
from transformers import AutoTokenizer, AutoModel | |
from sentence_transformers import SentenceTransformer, util | |
# --- Path Configuration ---
# Directory that holds this script; local model/data paths are resolved from it
# so a local run does not depend on the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

# The presence of a SPACE_ID environment variable is treated as the signal
# that the app is running inside a Hugging Face Space.
is_hf_space = "SPACE_ID" in os.environ

if not is_hf_space:
    # Local run: the fine-tuned model and the training data live in sibling
    # "models/" and "data/" directories one level above this script.
    model_path = os.path.join(
        script_dir, os.pardir, "models", "minilm-alice-base-rsft-v2", "final"
    )
    data_path = os.path.join(script_dir, os.pardir, "data", "training_triplets.jsonl")
    print(f"Running locally. Using local model at: {model_path}")
else:
    # Space run: the model is a Hub id (MODEL_REPO_ID wins when set) and the
    # data file is opened relative to the working directory (expected to be
    # the repo root).
    model_path = os.environ.get("MODEL_REPO_ID", "philtoms/minilm-alice-base-rsft-v2")
    data_path = "training_triplets.jsonl"
    print(f"Running on HF Spaces. Using model from Hub: {model_path}")
# --- Model and Tokenizer Loading ---
# By default the app serves a public baseline embedding model INSTEAD of the
# model_path chosen in the configuration section above. Set the
# EMBEDDING_MODEL_OVERRIDE environment variable to pick another model id/path
# (e.g. "sentence-transformers/all-MiniLM-L6-v2" or "Qwen/Qwen3-Embedding-0.6B"),
# or set it to an empty string to use the configured model_path.
model_path = (
    os.environ.get(
        "EMBEDDING_MODEL_OVERRIDE", "sentence-transformers/multi-qa-mpnet-base-cos-v1"
    )
    or model_path
)
# Log the model that is actually loaded; the configuration log above may name
# a different one when the override is in effect.
print(f"Loading embedding model: {model_path}")
try:
    model = SentenceTransformer(model_path)
except Exception as e:
    # Chain the original error so its traceback is preserved for debugging.
    raise gr.Error(f"Failed to load model from '{model_path}'. Error: {e}") from e
# --- Dataset Loading ---
# Load one JSON object per non-blank line of the JSONL file at data_path.
if not os.path.exists(data_path):
    raise gr.Error(f"Data file not found at '{data_path}'. Please ensure the file exists.")
dataset = []
# Read as UTF-8 explicitly: open() otherwise uses a platform-dependent default
# encoding, which can corrupt or reject non-ASCII text on some systems.
with open(data_path, "r", encoding="utf-8") as f:
    for line in f:
        # Skip blank lines (e.g. an extra newline at the end of the file);
        # json.loads would raise on them.
        if not line.strip():
            continue
        dataset.append(json.loads(line))
# Pre-compute corpus embeddings | |
import re | |
# def split_into_sentences(text): | |
# """Splits a paragraph into sentences based on capitalization and punctuation.""" | |
# # This regex looks for a capital letter, followed by anything that's not a period, | |
# # exclamation mark, or question mark, and then ends with one of those punctuation marks. | |
# sentences = re.findall(r'([A-Z][^.!?]*[.!?])', text) | |
# return sentences | |
# def create_overlapped_chunks(corpus_documents, chunk_size=2, overlap=1): | |
# chunked_corpus = [] | |
# for doc_idx, doc_text in enumerate(corpus_documents): | |
# sentences = split_into_sentences(doc_text) | |
# if not sentences: | |
# continue | |
# # If there are fewer sentences than chunk_size, just use the whole document as one chunk | |
# if len(sentences) < chunk_size: | |
# chunked_corpus.append({ | |
# "text": doc_text, | |
# "original_doc_idx": doc_idx, | |
# "start_sentence_idx": 0, | |
# "end_sentence_idx": len(sentences) - 1 | |
# }) | |
# continue | |
# for i in range(0, len(sentences) - chunk_size + 1, chunk_size - overlap): | |
# chunk_sentences = sentences[i : i + chunk_size] | |
# chunk_text = " ".join(chunk_sentences) | |
# chunked_corpus.append({ | |
# "text": chunk_text, | |
# "original_doc_idx": doc_idx, | |
# "start_sentence_idx": i, | |
# "end_sentence_idx": i + chunk_size - 1 | |
# }) | |
# return chunked_corpus | |
# def process_documents_for_chunking(documents): | |
# chunked_corpus_data = create_overlapped_chunks(documents) | |
# flat_corpus_chunks = [item["text"] for item in chunked_corpus_data] | |
# return chunked_corpus_data, flat_corpus_chunks | |
# Pre-compute corpus embeddings once at startup: each record's "positive" text
# becomes one searchable passage, so a query only has to embed the prompt.
# (The sentence-chunking variant above is kept commented out as an alternative.)
original_corpus = [record["positive"] for record in dataset]
corpus_embeddings = model.encode(original_corpus)
# def find_similar(prompt, top_k): | |
# start_time = time.time() | |
# prompt_embedding = model.encode(prompt) | |
# scores = util.dot_score(prompt_embedding, corpus_embeddings)[0].cpu().tolist() | |
# # Pair scores with the chunked corpus data | |
# scored_chunks = [] | |
# for i, score in enumerate(scores): | |
# scored_chunks.append({ | |
# "score": score, | |
# "text": chunked_corpus_data[i]["text"], | |
# "original_doc_idx": chunked_corpus_data[i]["original_doc_idx"] | |
# }) | |
# # Sort by decreasing score | |
# scored_chunks = sorted(scored_chunks, key=lambda x: x["score"], reverse=True) | |
# results = [] | |
# for item in scored_chunks[:top_k]: | |
# # Return the original document text, not just the chunk | |
# original_doc_text = original_corpus[item["original_doc_idx"]] | |
# results.append((item["score"], original_doc_text)) | |
# end_time = time.time() | |
# return results, f"{(end_time - start_time) * 1000:.2f} ms" | |
# with torch.no_grad(): | |
# encoded_corpus = tokenizer(corpus, padding=True, truncation=True, return_tensors='pt') | |
# corpus_embeddings = model(**encoded_corpus).last_hidden_state.mean(dim=1) | |
def find_similar(prompt, top_k):
    """Return the corpus passages most similar to ``prompt``.

    Args:
        prompt: Query text from the dropdown. ``None`` or a blank string
            yields no results instead of embedding an empty query.
        top_k: Number of results to return. It is coerced to ``int`` because
            UI sliders may deliver numbers as floats, which cannot be used as
            slice bounds. Values below 1 yield no results.

    Returns:
        A ``(results, elapsed)`` tuple. ``results`` is a list of
        ``(score, passage)`` tuples ordered by decreasing dot-product score.
        ``elapsed`` is the search time formatted as ``"<ms> ms"``.
    """
    start_time = time.perf_counter()
    # Clamp at 0: a negative value would slice as "all but the last few".
    k = max(int(top_k), 0)
    if not prompt or not prompt.strip() or k == 0:
        return [], f"{(time.perf_counter() - start_time) * 1000:.2f} ms"

    prompt_embedding = model.encode(prompt)
    scores = util.dot_score(prompt_embedding, corpus_embeddings)[0].cpu().tolist()
    # sorted() is stable, so passages with equal scores keep corpus order.
    ranked = sorted(zip(scores, original_corpus), key=lambda pair: pair[0], reverse=True)
    results = ranked[:k]
    elapsed_ms = (time.perf_counter() - start_time) * 1000
    return results, f"{elapsed_ms:.2f} ms"
# Sample prompts offered in the dropdown; users may also type their own.
_prompt_choices = [
    "Alice sees White rabbit for the first time",
    "Alice meets caterpillar",
    "sad turtle story",
]

# Query widgets: the prompt picker and the result-count slider, in the order
# find_similar expects its (prompt, top_k) arguments.
_query_inputs = [
    gr.Dropdown(
        _prompt_choices,
        label="Select a prompt or type your own",
        allow_custom_value=True,
    ),
    gr.Slider(1, 20, value=5, step=1, label="Top K"),
]

# Result widgets: a table for the (score, passage) rows and a box for the
# formatted latency string returned by find_similar.
_result_outputs = [
    gr.Dataframe(headers=["Score", "Response"]),
    gr.Textbox(label="Time Taken"),
]

iface = gr.Interface(
    fn=find_similar,
    inputs=_query_inputs,
    outputs=_result_outputs,
    title="RSFT Alice Embeddings (Transformers)",
    description="Enter a prompt to find similar sentences from the corpus.",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    iface.launch()