# mrl / src/streamlit_app.py
# Last change by dejanseo: "Update src/streamlit_app.py" (commit 8e633b2, verified)
import streamlit as st
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import numpy as np
# Page chrome: browser-tab title, centered layout, and the visible heading.
st.set_page_config(page_title="MXBAI Embed Demo", layout="centered")
st.title("MRL + Binary Embedding Demo by DEJAN")

# Usage instructions rendered as Markdown directly under the heading.
intro_markdown = """
Enter any sentence (or multiple sentences separated by newline).
Use the checkboxes on the left sidebar to toggle:
- **MRL** (truncate to 256 dimensions)
- **Binary** (apply binary quantization)
"""
st.markdown(intro_markdown)

# Sidebar controls: two independent toggles, both unchecked by default.
with st.sidebar:
    st.header("Options")
    use_mrl = st.checkbox("Enable MRL", value=False)
    use_binary = st.checkbox("Enable Binary", value=False)
# Cache the loaded model (one entry per distinct `mrl` argument) so that
# Streamlit's script reruns reuse it instead of reloading it on every interaction.
@st.cache_resource(show_spinner=False)
def load_model(mrl: bool) -> SentenceTransformer:
    """Load the mxbai-embed-large-v1 sentence-transformers model.

    Args:
        mrl: If True, the model is created with ``truncate_dim=256`` so that
            encoded embeddings are cut down to their first 256 dimensions
            (Matryoshka Representation Learning). If False, no truncation is
            applied and the model returns its full-size embeddings
            (1024 dimensions for mxbai-embed-large-v1).

    Returns:
        The configured ``SentenceTransformer`` instance.
    """
    # truncate_dim=None is the library default and means "no truncation".
    return SentenceTransformer(
        "mixedbread-ai/mxbai-embed-large-v1",
        truncate_dim=256 if mrl else None,
    )
# Fetch the (cached) model matching the current MRL toggle.
model = load_model(use_mrl)

# Text input: one sentence per line.
input_text = st.text_area(
    "Enter text here:",
    placeholder="Type one or more sentences...",
)

# One sentence per non-blank line, with surrounding whitespace removed.
sentences = [
    stripped
    for stripped in (line.strip() for line in input_text.split("\n"))
    if stripped
]

# Nothing to embed yet: show a hint and end this script run here.
# (An input with no non-whitespace characters yields an empty list, so this
# single guard covers both "empty box" and "only blank lines".)
if not sentences:
    st.info("⏳ Please enter some text above to see its embedding.")
    st.stop()

if st.button("Compute Embedding"):
    # encode() on a list returns a 2D array of shape (num_sentences, dim).
    embeddings = model.encode(sentences)

    if use_binary:
        # "ubinary" thresholds each value and packs 8 bits into one unsigned
        # byte, so the result is uint8 with dim / 8 columns (not a bool array).
        embeddings = quantize_embeddings(embeddings, precision="ubinary")
        st.subheader(
            "Binary-Quantized Embeddings "
            f"(dtype={embeddings.dtype}, shape = {embeddings.shape})"
        )
        st.write(embeddings)
    else:
        st.subheader(f"Floating‐Point Embeddings (shape = {embeddings.shape})")
        # Show each sentence's vector separately as a plain Python list.
        for number, emb in enumerate(embeddings, start=1):
            st.markdown(f"**Sentence #{number}:**")
            st.write(emb.tolist())

    st.success("✅ Done!")