This is a sentence-transformers static embedding model trained with Matryoshka Representation Learning (arXiv:2205.13147). It maps sentences and paragraphs to a 512-dimensional dense vector space and can be used for retrieval.
```
SentenceTransformer(
  (0): StaticEmbedding({})
)
```
First, install the Sentence Transformers library:

```bash
pip install -U sentence-transformers
```
Then you can load this model and run inference:

```python
from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("oneryalcin/static-embedding-chess")

# Run inference
queries = [
    '[UNK] deflection discoveredAttack [UNK] queensideAttack short Philidor Defense [UNK] Defense Other variations',
]
documents = [
    'themes crushing deflection discoveredAttack middlegame queensideAttack short opening Philidor Defense Philidor Defense Other variations moves d3c3 d4b3 c1b1 d7d1 d3c3+d4b3 d4b3+c1b1 c1b1+d7d1',
    'themes advantage discoveredAttack middlegame short opening Philidor Defense Philidor Defense Other variations moves e4d4 d3f5 c8b8 d1d4 e4d4+d3f5 d3f5+c8b8 c8b8+d1d4',
    'themes crushing middlegame pin queensideAttack short opening Sicilian Defense Sicilian Defense Najdorf Variation moves c3d5 c5b3 c1b1 b3d2 c3d5+c5b3 c5b3+c1b1 c1b1+b3d2',
]
query_embeddings = model.encode_query(queries)
document_embeddings = model.encode_document(documents)
print(query_embeddings.shape, document_embeddings.shape)
# [1, 512] [3, 512]

# Get the similarity scores for the embeddings
similarities = model.similarity(query_embeddings, document_embeddings)
print(similarities)
# tensor([[0.8405, 0.5061, 0.2136]])
```
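Because the model was trained with Matryoshka loss, the leading dimensions of each embedding form usable smaller embeddings on their own (sentence-transformers also supports a `truncate_dim` argument on the `SentenceTransformer` constructor). A minimal NumPy sketch of that truncation, using random arrays as stand-ins for real embeddings:

```python
import numpy as np

def truncate_and_normalize(emb, dim):
    """Keep only the first `dim` Matryoshka dimensions, then re-normalize."""
    truncated = emb[..., :dim]
    return truncated / np.linalg.norm(truncated, axis=-1, keepdims=True)

rng = np.random.default_rng(0)
q = rng.normal(size=(1, 512))   # stand-in for a query embedding
d = rng.normal(size=(3, 512))   # stand-in for document embeddings

# Cosine similarities at a reduced dimension (64 of 512)
q64 = truncate_and_normalize(q, 64)
d64 = truncate_and_normalize(d, 64)
scores = q64 @ d64.T
print(scores.shape)  # (1, 3)
```

Smaller truncations trade some retrieval quality for cheaper storage and faster scoring.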
Evaluated with `InformationRetrievalEvaluator` on the `chess-ir` and `chess-ir-tokens` datasets:

| Metric | chess-ir | chess-ir-tokens |
|---|---|---|
| cosine_accuracy@1 | 0.005 | 0.0794 |
| cosine_accuracy@10 | 0.07 | 0.2593 |
| cosine_precision@1 | 0.005 | 0.0794 |
| cosine_precision@10 | 0.008 | 0.0603 |
| cosine_recall@1 | 0.0017 | 0.0022 |
| cosine_recall@10 | 0.0267 | 0.024 |
| cosine_ndcg@10 | 0.0168 | 0.0672 |
| cosine_mrr@10 | 0.0207 | 0.1233 |
| cosine_map@100 | 0.0141 | 0.0332 |
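The nDCG@10 values above reward ranking relevant documents near the top of the first ten results. As a rough illustration of the metric (a toy calculation, not taken from the actual evaluation), a minimal sketch with binary relevance:

```python
import numpy as np

def ndcg_at_k(ranked_relevance, k=10):
    """nDCG@k for one query, given binary relevance labels in ranked order."""
    rel = np.asarray(ranked_relevance[:k], dtype=float)
    discounts = 1.0 / np.log2(np.arange(2, rel.size + 2))
    dcg = float(rel @ discounts)                 # gain discounted by rank
    ideal = np.sort(rel)[::-1]                   # best possible ordering
    idcg = float(ideal @ discounts)
    return dcg / idcg if idcg > 0 else 0.0

# The single relevant document is ranked 3rd: nDCG = 1/log2(4) = 0.5
print(ndcg_at_k([0, 0, 1, 0, 0]))  # 0.5
```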
Training pairs consist of an `anchor` and a `positive` text:

| | anchor | positive |
|---|---|---|
| type | string | string |
| modality | text | text |

Example pairs:

| anchor | positive |
|---|---|
| kingsideAttack mate mateIn1 middlegame oneMove Horwitz Defense Horwitz Defense [UNK] variations | themes kingsideAttack mate mateIn1 middlegame oneMove opening Horwitz Defense Horwitz Defense Other variations moves f7h8 g6g2 f7h8+g6g2 |
| backRankMate endgame mate mateIn2 short Kings Knight Opening Kings Knight Opening [UNK] [UNK] | themes backRankMate endgame mate mateIn2 short opening Kings Knight Opening Kings Knight Opening Other variations moves c5d4 c3c8 g5d8 c8d8 c5d4+c3c8 c3c8+g5d8 g5d8+c8d8 |
| kingsideAttack mate mateIn1 middlegame oneMove Sicilian Defense Sicilian Defense Paulsen-Basman Defense | themes kingsideAttack mate mateIn1 middlegame oneMove opening Sicilian Defense Sicilian Defense Paulsen-Basman Defense moves g3f3 c7h2 g3f3+c7h2 |
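Each anchor appears to be the positive's description with some tokens replaced by `[UNK]`. A minimal sketch of one such corruption scheme — the masking rate and seed here are hypothetical, not the values used to build the dataset:

```python
import random

def mask_tokens(text, rate=0.3, seed=0):
    """Replace a random fraction of whitespace tokens with [UNK] (hypothetical scheme)."""
    rng = random.Random(seed)
    tokens = text.split()
    return " ".join(t if rng.random() > rate else "[UNK]" for t in tokens)

positive = ("themes crushing deflection discoveredAttack middlegame "
            "queensideAttack short opening Philidor Defense")
anchor = mask_tokens(positive)
print(anchor)
```

Training on corrupted anchors against clean positives encourages the embedding to be robust to missing puzzle tags.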
Trained with `MatryoshkaLoss` with these parameters:

```json
{
  "loss": "MultipleNegativesRankingLoss",
  "matryoshka_dims": [512, 256, 128, 64, 32],
  "matryoshka_weights": [1, 1, 1, 1, 1],
  "n_dims_per_step": -1
}
```
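`MatryoshkaLoss` applies the wrapped `MultipleNegativesRankingLoss` (in-batch contrastive loss) at each configured truncation and sums the weighted results. A minimal NumPy sketch of the idea — this is an illustration of the math, not the library's implementation, and the `scale` value is an assumption:

```python
import numpy as np

def info_nce(q, d, scale=20.0):
    """MultipleNegativesRankingLoss: cross-entropy over in-batch cosine scores."""
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    d = d / np.linalg.norm(d, axis=1, keepdims=True)
    scores = scale * (q @ d.T)                      # (batch, batch)
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))      # diagonal pairs are positives

def matryoshka_loss(q, d, dims=(512, 256, 128, 64, 32), weights=None):
    """Weighted sum of the contrastive loss over each prefix truncation."""
    weights = weights or [1.0] * len(dims)
    return sum(w * info_nce(q[:, :k], d[:, :k]) for w, k in zip(weights, dims))

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 512))   # stand-in anchor embeddings
d = rng.normal(size=(4, 512))   # stand-in positive embeddings
print(matryoshka_loss(q, d))
```

With unit weights, as configured above, every truncation contributes equally to the gradient, so all prefix sizes stay usable at inference time.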
Non-default hyperparameters:

```
per_device_train_batch_size: 4096
num_train_epochs: 20
learning_rate: 0.01
warmup_steps: 0.1
weight_decay: 0.01
per_device_eval_batch_size: 4096
push_to_hub: True
hub_model_id: oneryalcin/static-embedding-chess
load_best_model_at_end: True
seed: 12
```

All hyperparameters:

```
per_device_train_batch_size: 4096
num_train_epochs: 20
max_steps: -1
learning_rate: 0.01
lr_scheduler_type: linear
lr_scheduler_kwargs: None
warmup_steps: 0.1
optim: adamw_torch_fused
optim_args: None
weight_decay: 0.01
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-08
optim_target_modules: None
gradient_accumulation_steps: 1
average_tokens_across_devices: True
max_grad_norm: 1.0
label_smoothing_factor: 0.0
bf16: False
fp16: False
bf16_full_eval: False
fp16_full_eval: False
tf32: None
gradient_checkpointing: False
gradient_checkpointing_kwargs: None
torch_compile: False
torch_compile_backend: None
torch_compile_mode: None
use_liger_kernel: False
liger_kernel_config: None
use_cache: False
neftune_noise_alpha: None
torch_empty_cache_steps: None
auto_find_batch_size: False
log_on_each_node: True
logging_nan_inf_filter: True
include_num_input_tokens_seen: no
log_level: passive
log_level_replica: warning
disable_tqdm: False
project: huggingface
trackio_space_id: None
trackio_bucket_id: None
trackio_static_space_id: None
per_device_eval_batch_size: 4096
prediction_loss_only: True
eval_on_start: False
eval_do_concat_batches: True
eval_use_gather_object: False
eval_accumulation_steps: None
include_for_metrics: []
batch_eval_metrics: False
save_only_model: False
save_on_each_node: False
enable_jit_checkpoint: False
push_to_hub: True
hub_private_repo: None
hub_model_id: oneryalcin/static-embedding-chess
hub_strategy: every_save
hub_always_push: False
hub_revision: None
load_best_model_at_end: True
ignore_data_skip: False
restore_callback_states_from_checkpoint: False
full_determinism: False
seed: 12
data_seed: None
use_cpu: False
accelerator_config: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
parallelism_config: None
dataloader_drop_last: False
dataloader_num_workers: 0
dataloader_pin_memory: True
dataloader_persistent_workers: False
dataloader_prefetch_factor: None
remove_unused_columns: True
label_names: None
train_sampling_strategy: random
length_column_name: length
ddp_find_unused_parameters: None
ddp_bucket_cap_mb: None
ddp_broadcast_buffers: False
ddp_static_graph: None
ddp_backend: None
ddp_timeout: 1800
fsdp: []
fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
deepspeed: None
debug: []
skip_memory_metrics: True
do_predict: False
resume_from_checkpoint: None
warmup_ratio: None
local_rank: -1
prompts: None
batch_sampler: batch_sampler
multi_dataset_batch_sampler: proportional
router_mapping: {}
learning_rate_mapping: {}
```

Training logs:

| Epoch | Step | Training Loss | chess-ir_cosine_ndcg@10 | chess-ir-tokens_cosine_ndcg@10 |
|---|---|---|---|---|
| -1 | -1 | - | 0.0123 | 0.0561 |
| 0.0025 | 1 | 27.3123 | - | - |
| 0.2020 | 80 | 26.3304 | - | - |
| 0.4040 | 160 | 22.2114 | - | - |
| 0.6061 | 240 | 17.4522 | - | - |
| 0.8081 | 320 | 12.8864 | - | - |
| 1.0 | 396 | - | 0.0800 | 0.1181 |
| 1.0101 | 400 | 9.1439 | - | - |
| 1.2121 | 480 | 6.5434 | - | - |
| 1.4141 | 560 | 4.9138 | - | - |
| 1.6162 | 640 | 3.9819 | - | - |
| 1.8182 | 720 | 3.4584 | - | - |
| 2.0 | 792 | - | 0.0505 | 0.0938 |
| 2.0202 | 800 | 3.1303 | - | - |
| 2.2222 | 880 | 2.9652 | - | - |
| 2.4242 | 960 | 2.8584 | - | - |
| 2.6263 | 1040 | 2.7907 | - | - |
| 2.8283 | 1120 | 2.7475 | - | - |
| 3.0 | 1188 | - | 0.0251 | 0.0830 |
| 3.0303 | 1200 | 2.7031 | - | - |
| 3.2323 | 1280 | 2.6927 | - | - |
| 3.4343 | 1360 | 2.6516 | - | - |
| 3.6364 | 1440 | 2.6441 | - | - |
| 3.8384 | 1520 | 2.6202 | - | - |
| 4.0 | 1584 | - | 0.0168 | 0.0672 |
```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}

@misc{kusupati2024matryoshka,
    title = {Matryoshka Representation Learning},
    author = {Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
    year = {2024},
    eprint = {2205.13147},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
}

@misc{oord2019representationlearningcontrastivepredictive,
    title = {Representation Learning with Contrastive Predictive Coding},
    author = {Aaron van den Oord and Yazhe Li and Oriol Vinyals},
    year = {2019},
    eprint = {1807.03748},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url = {https://arxiv.org/abs/1807.03748},
}
```