"""Streamlit demo: Matryoshka (MRL) truncation and binary quantization.

The page lets the user type one sentence per line, embeds them with
``mixedbread-ai/mxbai-embed-large-v1`` and shows the resulting vectors.
Two sidebar toggles control post-processing:

* MRL    - load the model with ``truncate_dim=256`` so only the leading
           256 dimensions of each embedding are kept.
* Binary - bit-pack the embeddings with ``quantize_embeddings`` using the
           ``"ubinary"`` precision.
"""

import numpy as np
import streamlit as st
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings

# Hugging Face model id used for every embedding computed on this page.
MODEL_NAME = "mixedbread-ai/mxbai-embed-large-v1"
# Number of leading dimensions kept when the MRL toggle is enabled.
MRL_DIM = 256

# set_page_config has to be the first Streamlit command executed.
st.set_page_config(page_title="MXBAI Embed Demo", layout="centered")
st.title("MRL + Binary Embedding Demo by DEJAN")
st.markdown(
    """
Enter any sentence (or multiple sentences separated by newline).
Use the checkboxes on the left sidebar to toggle:

- **MRL** (truncate to 256 dimensions)
- **Binary** (apply binary quantization)
"""
)

# Sidebar controls
st.sidebar.header("Options")
use_mrl = st.sidebar.checkbox("Enable MRL", value=False)
use_binary = st.sidebar.checkbox("Enable Binary", value=False)


# Cache the loaded model so Streamlit reruns (triggered by every widget
# interaction) reuse it instead of loading the weights again. The cache is
# keyed on the argument, so each value of ``mrl`` gets its own instance.
@st.cache_resource(show_spinner=False)
def load_model(mrl: bool) -> SentenceTransformer:
    """Load the sentence-transformers embedding model.

    Args:
        mrl: When True, load with ``truncate_dim=MRL_DIM`` (256) so encoded
            vectors are cut down to their leading 256 dimensions
            (Matryoshka Representation Learning). When False, no truncation
            is applied and the model's native size is used (1024 dimensions
            for mxbai-embed-large-v1).

    Returns:
        The ``SentenceTransformer`` instance for the requested mode.
    """
    # truncate_dim=None is the library default and means "no truncation".
    return SentenceTransformer(MODEL_NAME, truncate_dim=MRL_DIM if mrl else None)


model = load_model(use_mrl)

# Text input
input_text = st.text_area(
    "Enter text here:",
    placeholder="Type one or more sentences...",
)

if not input_text.strip():
    st.info("⏳ Please enter some text above to see its embedding.")
    st.stop()

if st.button("Compute Embedding"):
    # One sentence per non-blank line. The guard above already ensured the
    # input has at least one non-whitespace character, so this list cannot
    # be empty and needs no separate emptiness check.
    sentences = [line.strip() for line in input_text.split("\n") if line.strip()]

    embeddings = model.encode(sentences)  # shape: (num_sentences, dim)

    if use_binary:
        # "ubinary" thresholds each value at 0 and packs 8 bits per byte,
        # so the result is uint8 with dim / 8 columns, not a bool array.
        # Report the dtype and shape from the array itself.
        packed = quantize_embeddings(np.asarray(embeddings), precision="ubinary")
        st.subheader(
            f"Binary-Quantized Embeddings (dtype={packed.dtype}, shape={packed.shape})"
        )
        st.write(packed)
    else:
        st.subheader(f"Floating‐Point Embeddings (shape = {embeddings.shape})")
        for idx, emb in enumerate(embeddings):
            st.markdown(f"**Sentence #{idx+1}:**")
            st.write(emb.tolist())

    st.success("✅ Done!")