# Uploaded to Hugging Face by philtoms — commit 380225c ("Upload app.py", verified).
import gradio as gr
import time
import os
import json
import torch
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer, util
# --- Path Configuration ---
# Absolute directory containing this file; the local-run paths below are
# built relative to it.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Hugging Face Spaces export SPACE_ID, so its presence identifies that runtime.
is_hf_space = "SPACE_ID" in os.environ

if not is_hf_space:
    # Local run: the fine-tuned checkpoint and the triplet data live in
    # "models/" and "data/" directories next to this script's directory.
    data_path = os.path.join(script_dir, "..", "data", "training_triplets.jsonl")
    model_path = os.path.join(script_dir, "..", "models", "minilm-alice-base-rsft-v2", "final")
    print(f"Running locally. Using local model at: {model_path}")
else:
    # Space run: the model is fetched from the Hub (MODEL_REPO_ID may override
    # the default repo id); the data file is resolved relative to the current
    # working directory, expected to be the repo root.
    data_path = "training_triplets.jsonl"
    model_path = os.environ.get("MODEL_REPO_ID", "philtoms/minilm-alice-base-rsft-v2")
    print(f"Running on HF Spaces. Using model from Hub: {model_path}")
# --- Model and Tokenizer Loading ---
# NOTE(review): by default the checkpoint selected under "Path Configuration"
# is replaced by a baseline model, so the MODEL_REPO_ID / local fine-tuned
# path is NOT what gets loaded. The override is explicit and logged here:
# set the BASELINE_MODEL_ID environment variable to choose another baseline,
# or set it to an empty string to load the configured model instead.
# Other baselines tried: "sentence-transformers/all-MiniLM-L6-v2",
# "Qwen/Qwen3-Embedding-0.6B".
baseline_model_id = os.environ.get(
    "BASELINE_MODEL_ID", "sentence-transformers/multi-qa-mpnet-base-cos-v1"
)
if baseline_model_id:
    print(f"Overriding configured model '{model_path}' with baseline: {baseline_model_id}")
    model_path = baseline_model_id

try:
    model = SentenceTransformer(model_path)
except Exception as e:
    # Chain the original exception so the underlying loading traceback is kept.
    raise gr.Error(f"Failed to load model from '{model_path}'. Error: {e}") from e
# --- Dataset Loading ---
# The data file is JSON Lines: one JSON object per non-blank line.
try:
    # JSON text is UTF-8; be explicit so decoding does not depend on the
    # platform's default locale encoding.
    with open(data_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. an extra newline at end of file), which
        # json.loads would otherwise reject.
        dataset = [json.loads(line) for line in f if line.strip()]
except FileNotFoundError as e:
    raise gr.Error(f"Data file not found at '{data_path}'. Please ensure the file exists.") from e
# Pre-compute corpus embeddings
import re
# def split_into_sentences(text):
# """Splits a paragraph into sentences based on capitalization and punctuation."""
# # This regex looks for a capital letter, followed by anything that's not a period,
# # exclamation mark, or question mark, and then ends with one of those punctuation marks.
# sentences = re.findall(r'([A-Z][^.!?]*[.!?])', text)
# return sentences
# def create_overlapped_chunks(corpus_documents, chunk_size=2, overlap=1):
# chunked_corpus = []
# for doc_idx, doc_text in enumerate(corpus_documents):
# sentences = split_into_sentences(doc_text)
# if not sentences:
# continue
# # If there are fewer sentences than chunk_size, just use the whole document as one chunk
# if len(sentences) < chunk_size:
# chunked_corpus.append({
# "text": doc_text,
# "original_doc_idx": doc_idx,
# "start_sentence_idx": 0,
# "end_sentence_idx": len(sentences) - 1
# })
# continue
# for i in range(0, len(sentences) - chunk_size + 1, chunk_size - overlap):
# chunk_sentences = sentences[i : i + chunk_size]
# chunk_text = " ".join(chunk_sentences)
# chunked_corpus.append({
# "text": chunk_text,
# "original_doc_idx": doc_idx,
# "start_sentence_idx": i,
# "end_sentence_idx": i + chunk_size - 1
# })
# return chunked_corpus
# def process_documents_for_chunking(documents):
# chunked_corpus_data = create_overlapped_chunks(documents)
# flat_corpus_chunks = [item["text"] for item in chunked_corpus_data]
# return chunked_corpus_data, flat_corpus_chunks
# Pre-compute corpus embeddings
# The searchable corpus is the "positive" text of every loaded record. It is
# embedded once at start-up so each query only has to encode the prompt.
original_corpus = []
for record in dataset:
    original_corpus.append(record["positive"])
corpus_embeddings = model.encode(original_corpus)
# def find_similar(prompt, top_k):
# start_time = time.time()
# prompt_embedding = model.encode(prompt)
# scores = util.dot_score(prompt_embedding, corpus_embeddings)[0].cpu().tolist()
# # Pair scores with the chunked corpus data
# scored_chunks = []
# for i, score in enumerate(scores):
# scored_chunks.append({
# "score": score,
# "text": chunked_corpus_data[i]["text"],
# "original_doc_idx": chunked_corpus_data[i]["original_doc_idx"]
# })
# # Sort by decreasing score
# scored_chunks = sorted(scored_chunks, key=lambda x: x["score"], reverse=True)
# results = []
# for item in scored_chunks[:top_k]:
# # Return the original document text, not just the chunk
# original_doc_text = original_corpus[item["original_doc_idx"]]
# results.append((item["score"], original_doc_text))
# end_time = time.time()
# return results, f"{(end_time - start_time) * 1000:.2f} ms"
# with torch.no_grad():
# encoded_corpus = tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
# corpus_embeddings = model(**encoded_corpus).last_hidden_state.mean(dim=1)
def find_similar(prompt, top_k):
    """Rank corpus passages by similarity to *prompt*.

    Args:
        prompt: Query text from the dropdown (a preset or a custom value).
        top_k: Number of results to return. Coerced to ``int`` in case the
            slider delivers a float, since list slicing requires an int.

    Returns:
        A tuple ``(results, elapsed)``. ``results`` is a list of
        ``(score, passage)`` pairs sorted by decreasing dot-product score.
        ``elapsed`` is the encode-and-rank time formatted as ``"<ms> ms"``.

    Raises:
        gr.Error: If the prompt is missing or blank, so the UI shows a message
            instead of the encoder failing on ``None``.
    """
    if prompt is None or not str(prompt).strip():
        raise gr.Error("Please select or type a prompt.")
    # Clamp at 0 so an unexpected negative value cannot slice from the end.
    k = max(int(top_k), 0)

    # perf_counter is the monotonic, high-resolution clock meant for timing.
    start_time = time.perf_counter()
    prompt_embedding = model.encode(prompt)
    # NOTE(review): dot product equals cosine similarity only if the model
    # emits normalized embeddings — confirm for whichever model is loaded.
    scores = util.dot_score(prompt_embedding, corpus_embeddings)[0].cpu().tolist()
    # Pair each passage with its score and rank from most to least similar.
    doc_score_pairs = sorted(zip(original_corpus, scores), key=lambda pair: pair[1], reverse=True)
    end_time = time.perf_counter()

    results = [(score, doc) for doc, score in doc_score_pairs[:k]]
    return results, f"{(end_time - start_time) * 1000:.2f} ms"
# --- Gradio UI ---
# Inputs map positionally onto find_similar(prompt, top_k); the two outputs
# receive its (results, elapsed) return value.
iface = gr.Interface(
    fn=find_similar,
    inputs=[
        gr.Dropdown(
            ["Alice sees White rabbit for the first time", "Alice meets caterpillar", "sad turtle story"],
            label="Select a prompt or type your own",
            allow_custom_value=True,
        ),
        gr.Slider(1, 20, value=5, step=1, label="Top K"),
    ],
    outputs=[
        gr.Dataframe(headers=["Score", "Response"]),
        gr.Textbox(label="Time Taken"),
    ],
    title="RSFT Alice Embeddings (Transformers)",
    # Plain string: there is nothing to interpolate (an f-string here was a no-op).
    description="Enter a prompt to find similar sentences from the corpus.",
)

if __name__ == "__main__":
    # Launch the web server only when executed as a script, not when imported.
    iface.launch()