Matryoshka Embedding Model (Merged) for SEA-LION 8B
This repository contains the full, standalone fine-tuned model weights for a Matryoshka-style text embedding model based on aisingapore/Llama-SEA-LION-v3.5-8B-R. The LoRA adapters have been merged into the base model for simpler deployment.
Note: This is the full model, resulting in a large repository size (16GB+). For a much more lightweight version (~50MB) that uses LoRA adapters, please see the adapter-only repository here.
Model Features
- Base Model: aisingapore/Llama-SEA-LION-v3.5-8B-R
- Latent Attention Pooling: A sophisticated pooling mechanism that uses cross-attention to summarize token sequences into a single vector.
- Matryoshka Representation Learning (MRL): Trained to produce nested embeddings. You can use the full 4096-dimension embedding for maximum quality, or slice it to a smaller dimension (e.g., 1024, 512, or 128) to trade some accuracy for speed and storage savings (see the slicing sketch after this list).
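As a quick illustration of the nested property, the snippet below slices a full embedding to a smaller dimension and re-normalizes it. The random tensor is only a stand-in for a real embedding; in practice you would use the embed_texts_mrl helper from the How to Use section.

```python
import torch
import torch.nn.functional as F

# Stand-in for a batch of full 4096-dim embeddings (illustrative only;
# in practice, use embed_texts_mrl(texts) from the "How to Use" section).
full_emb = F.normalize(torch.randn(2, 4096), p=2, dim=1)

# Matryoshka property: the first k dimensions already form a usable embedding.
small_emb = full_emb[:, :512]

# Re-normalize after slicing so cosine similarity behaves as expected.
small_emb = F.normalize(small_emb, p=2, dim=1)
print(small_emb.shape)  # torch.Size([2, 512])
```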
Intended Use
This model is ideal for generating fixed-size embeddings for tasks such as the following (a minimal retrieval sketch follows the list):
- Semantic Search & Information Retrieval
- Retrieval-Augmented Generation (RAG)
- Clustering and Text Similarity
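As a hedged sketch of the semantic-search use case, the snippet below ranks a small placeholder corpus against a query by cosine similarity. It assumes the embed_texts_mrl helper defined in the How to Use section has already been set up; the corpus and query strings are made up for illustration.

```python
import torch

# Assumes the model, heads, and embed_texts_mrl from "How to Use" are already loaded.
corpus = [
    "SEA-LION is a family of language models for Southeast Asian languages.",
    "The capital of Indonesia is Jakarta.",
    "Matryoshka embeddings can be truncated to smaller dimensions.",
]
query = "Which models target Southeast Asian languages?"

# Embeddings are L2-normalized, so the dot product equals cosine similarity.
corpus_emb = embed_texts_mrl(corpus, out_dim=512)
query_emb = embed_texts_mrl([query], out_dim=512)

scores = query_emb @ corpus_emb.T      # shape: [1, len(corpus)]
best = torch.argmax(scores, dim=1).item()
print(corpus[best], scores[0, best].item())
```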
How to Use
Loading this model is simpler than the adapter-only version because the LoRA adapters are already merged into the base weights. You still need the custom code from modeling.py and the weights for the pooling and projection heads (pooler.pt and projection.pt).
```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import importlib.util

# --- 1. Setup and Load Components ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
repo_id = "evoreign/sea-lion-8b-mrl-embedding-merged"

# --- 2. Dynamically Load Custom Classes ---
print("Downloading custom modeling code...")
modeling_path = hf_hub_download(repo_id=repo_id, filename="modeling.py")
spec = importlib.util.spec_from_file_location("modeling", modeling_path)
modeling = importlib.util.module_from_spec(spec)
spec.loader.exec_module(modeling)

LatentAttentionPooling = modeling.LatentAttentionPooling
MatryoshkaProjection = modeling.MatryoshkaProjection
print("Custom classes loaded successfully.")

# --- 3. Load Merged Model ---
print("Loading the full merged model (this may take time and memory)...")
# No PeftModel needed; we load directly from the repo ID.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,  # Use float16 for memory efficiency
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
if tokenizer.pad_token is None:
    # Llama-style tokenizers may not define a pad token; needed for padding=True.
    tokenizer.pad_token = tokenizer.eos_token

# --- 4. Load Custom Pooling and Projection Heads ---
HIDDEN_SIZE = model.config.hidden_size
MAX_DIM = 4096

print("Loading custom pooling and projection heads...")
pooler = LatentAttentionPooling(hidden_size=HIDDEN_SIZE).to(device).to(torch.float16)
projection = MatryoshkaProjection(hidden_size=HIDDEN_SIZE, max_embed_dim=MAX_DIM).to(device).to(torch.float16)

pooler_path = hf_hub_download(repo_id=repo_id, filename="pooler.pt")
projection_path = hf_hub_download(repo_id=repo_id, filename="projection.pt")
pooler.load_state_dict(torch.load(pooler_path, map_location=device))
projection.load_state_dict(torch.load(projection_path, map_location=device))

model.eval()
pooler.eval()
projection.eval()

# --- 5. Create the Inference Function ---
def embed_texts_mrl(texts, out_dim=None):
    with torch.no_grad():
        inputs = tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True, max_length=4096
        ).to(device)
        # Use model() directly as it's not a PeftModel
        out = model(**inputs, output_hidden_states=True)
        hidden = out.hidden_states[-1]
        mask = inputs.attention_mask
        pooled = pooler(hidden, attention_mask=mask)
        z_max = projection(pooled)
        z = z_max[:, :out_dim] if out_dim else z_max
        return F.normalize(z, p=2, dim=1)

# --- 6. Example Usage ---
my_texts = ["Contoh kalimat untuk di-embed.", "Another sentence to embed."]  # first text is Indonesian ("An example sentence to embed.")
emb_256 = embed_texts_mrl(my_texts, out_dim=256)
print("Sliced embedding shape:", emb_256.shape)
# Expected output: torch.Size([2, 256])
```
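Because embed_texts_mrl slices before normalizing, a sliced embedding matches the re-normalized prefix of the full embedding, which you can verify (up to float16 rounding) like this:

```python
emb_full = embed_texts_mrl(my_texts)                  # full 4096-dim embeddings
prefix = F.normalize(emb_full[:, :256], p=2, dim=1)   # re-normalized 256-dim prefix
print(torch.allclose(prefix, emb_256, atol=1e-2))     # expected: True (float16 noise aside)
```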
Training Details
- Loss Function: In-batch contrastive loss with hard negatives.
- MRL Objective: The loss was averaged across the nested dimensions [128, 256, 512, 1024, 2048, 4096] (an illustrative sketch follows this list).
- Dataset: Fine-tuned on a private triplet dataset (query, positive, hard_negative).
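The training code itself is not included in this repository. Purely as an illustration of the objective described above (an in-batch contrastive loss with hard negatives, averaged over the nested dimensions), a loss of roughly this shape could look like the sketch below; the function name, temperature, and candidate construction are assumptions, not the actual implementation.

```python
import torch
import torch.nn.functional as F

MRL_DIMS = [128, 256, 512, 1024, 2048, 4096]

def mrl_contrastive_loss(q, pos, neg, temperature=0.05):
    """Illustrative sketch only, not the actual training code.

    q, pos, neg: [batch, 4096] embeddings from the projection head,
    built from (query, positive, hard_negative) triplets.
    """
    losses = []
    for d in MRL_DIMS:
        qd = F.normalize(q[:, :d], p=2, dim=1)
        pd = F.normalize(pos[:, :d], p=2, dim=1)
        nd = F.normalize(neg[:, :d], p=2, dim=1)
        # Candidates for each query: all in-batch positives plus all hard negatives.
        candidates = torch.cat([pd, nd], dim=0)               # [2 * batch, d]
        logits = qd @ candidates.T / temperature              # [batch, 2 * batch]
        labels = torch.arange(qd.size(0), device=qd.device)   # i-th positive is the target
        losses.append(F.cross_entropy(logits, labels))
    return torch.stack(losses).mean()
```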
Author: [Edbert Khovey]