"""Text Embedding Playground: dense, sparse, and late-interaction embeddings.

Three embedding backends exposed behind a Gradio UI (built further down in
this file): a SentenceTransformers bi-encoder (dense), BM25 term weights
(sparse), and ColBERT per-token vectors (late interaction).
"""
import gradio as gr
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel
import torch

# 1. Dense embedding model (HF bi-encoder).
# NOTE: alternatives previously tried here were
# 'sentence-transformers/all-MiniLM-L6-v2' and
# 'distiluse-base-multilingual-cased-v2'.
dense_model = SentenceTransformer('multi-qa-mpnet-base-cos-v1')


def embed_dense(text: str) -> dict:
    """Encode *text* into a single dense sentence vector.

    Returns ``{"dense_embedding": [float, ...]}`` on success, or
    ``{"error": ...}`` when the input is empty/whitespace.
    """
    if not text.strip():
        return {"error": "Input text is empty."}
    emb = dense_model.encode([text])[0]
    return {"dense_embedding": emb.tolist()}


# 2. Sparse embedding model (BM25).
# Uses rank_bm25 over a one-document "corpus" consisting of the input itself.
def embed_sparse(text: str) -> dict:
    """Compute a BM25 weight for each unique whitespace token in *text*.

    Returns ``{"indices": [...], "values": [...], "terms": [...]}`` where
    ``values[i]`` is the BM25 score of ``terms[i]`` against the input
    document, or ``{"error": ...}`` when the input is empty/whitespace.
    """
    if not text.strip():
        return {"error": "Input text is empty."}
    tokens = text.split()
    bm25 = BM25Okapi([tokens])
    unique_terms = sorted(set(tokens))
    # BUG FIX: BM25Okapi.get_scores(query) returns one score PER DOCUMENT,
    # not per term. The original code zipped a length-1 score array against
    # all unique terms, so only the first term ever received a weight and
    # every other term fell back to 0.0. Score each term as its own
    # single-term query instead; each call yields the lone document's score.
    term_weights = {
        term: float(bm25.get_scores([term])[0]) for term in unique_terms
    }
    indices = list(range(len(unique_terms)))
    values = [term_weights[term] for term in unique_terms]
    return {"indices": indices, "values": values, "terms": unique_terms}
# 3. Late-interaction embedding model (ColBERT).
colbert_tokenizer = AutoTokenizer.from_pretrained('colbert-ir/colbertv2.0', use_fast=True)
colbert_model = AutoModel.from_pretrained('colbert-ir/colbertv2.0')
# BUG FIX: the model was left in training mode, so dropout stayed active and
# inference outputs were nondeterministic. Switch to eval mode explicitly.
colbert_model.eval()
# Freeze model parameters for inference speed
for param in colbert_model.parameters():
    param.requires_grad = False


def embed_colbert(text: str) -> dict:
    """Return one contextual embedding per input token (ColBERT-style).

    Input is truncated to 64 tokens. Returns
    ``{"colbert_embeddings": [[float, ...], ...]}`` (seq_len x hidden_size),
    or ``{"error": ...}`` when the input is empty/whitespace.
    """
    if not text.strip():
        return {"error": "Input text is empty."}
    inputs = colbert_tokenizer(text, return_tensors='pt', truncation=True, max_length=64)
    with torch.no_grad():
        outputs = colbert_model(**inputs)
    # last_hidden_state: (1, seq_len, hidden_size) -> list of per-token vectors
    embeddings = outputs.last_hidden_state.squeeze(0).tolist()
    return {"colbert_embeddings": embeddings}


# Build Gradio interface with tabs for each model.
with gr.Blocks(title="Text Embedding Playground") as demo:
    gr.Markdown("# Text Embedding Playground\nChoose a model and input text to get embeddings.")
    # FIX: the tab previously claimed "MiniLM-L6-v2", but the dense model
    # actually loaded above is multi-qa-mpnet-base-cos-v1; label now matches.
    with gr.Tab("Dense (multi-qa-mpnet-base-cos-v1)"):
        txt1 = gr.Textbox(lines=3, label="Input Text")
        out1 = gr.JSON(label="Embedding")
        txt1.submit(embed_dense, txt1, out1)
        gr.Button("Embed").click(embed_dense, txt1, out1)
    with gr.Tab("Sparse (BM25)"):
        txt2 = gr.Textbox(lines=3, label="Input Text")
        out2 = gr.JSON(label="Term Weights")
        txt2.submit(embed_sparse, txt2, out2)
        gr.Button("Embed").click(embed_sparse, txt2, out2)
    with gr.Tab("Late-Interaction (ColBERT)"):
        txt3 = gr.Textbox(lines=3, label="Input Text")
        out3 = gr.JSON(label="Embeddings per Token")
        txt3.submit(embed_colbert, txt3, out3)
        gr.Button("Embed").click(embed_colbert, txt3, out3)

if __name__ == "__main__":
    # mcp_server=True additionally exposes the embed_* functions as MCP tools.
    demo.launch(mcp_server=True)