Integrate with Sentence Transformers v5.4
Hello!
Pull Request overview
- Integrate `zerank-2` with Sentence Transformers (v5.4+) so the model loads via the standard `CrossEncoder` API with no extra kwargs.
Details
This PR supersedes #2, which requires `trust_remote_code`. Instead, this integration uses the stock `CrossEncoder` + causal-LM pipeline `Transformer` (text-generation) -> `LogitScore` with new (ST v5.4+) built-in classes only, so no `trust_remote_code=True` is required and no custom modeling code is added on top of what already ships in the repo.
`LogitScore` returns the raw "Yes" logit at the last position, so the scores can be used directly for ranking. To recover the 0-1 score range that the original `predict()` produces, callers can apply the temperature-scaled sigmoid `sigmoid(score / 5)` themselves; the README shows the one-liner. Keeping this transformation client-side avoids needing a custom score head and keeps the integration purely config-driven.
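As a quick, dependency-free sanity check of that transformation (a sketch; `zerank_probability` is a hypothetical helper name, and the raw logits are the README example values):

```python
import math

def zerank_probability(raw_logit: float, temperature: float = 5.0) -> float:
    """Temperature-scaled sigmoid mapping a raw "Yes" logit into the 0-1
    range that the original predict() returns."""
    return 1.0 / (1.0 + math.exp(-raw_logit / temperature))

# Raw logits from the README example:
for logit in (5.40625, -4.5):
    print(round(zerank_probability(logit), 4))
```

The second value lands on 0.2891 as in the README; the first comes out around 0.7467 in float64 versus the README's 0.7461, simply because `predict()` runs in bf16.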
`chat_template.jinja` gets a small new branch at the top: when the input messages carry `query` / `document` roles (the convention Sentence Transformers passes to a `CrossEncoder`), it renders them as `<|im_start|>system\n{query}<|im_end|>\n<|im_start|>user\n{document}<|im_end|>\n<|im_start|>assistant\n`, which is byte-for-byte identical to the chat-templated string the original `format_pointwise_datapoints` produces. Any other role configuration falls through to the original Qwen3 logic untouched, so direct `tokenizer.apply_chat_template(...)` usage is unaffected.
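For reference, a minimal Python mirror of what that branch renders (plain string building rather than the actual Jinja; `render_pointwise` is a hypothetical name, and the layout is exactly the ChatML string described above):

```python
def render_pointwise(query: str, document: str) -> str:
    """Mirror of the new query/document branch in chat_template.jinja:
    the query fills the system turn, the document fills the user turn,
    followed by the assistant generation prompt."""
    return (
        f"<|im_start|>system\n{query}<|im_end|>\n"
        f"<|im_start|>user\n{document}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )

print(render_pointwise("What is 2+2?", "4"))
```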
`CrossEncoder("zeroentropy/zerank-2").predict(...)` loads in bf16 automatically (picked up from `config.json`) and matches the reference baseline numbers for the README example: `[5.40625, -4.5]` raw, `[0.7461, 0.2891]` after `(scores / 5).sigmoid()`.
Added files:
- `modules.json`: the `Transformer` -> `LogitScore` pipeline using ST's built-in classes (`sentence_transformers.base.modules.transformer.Transformer` and `sentence_transformers.cross_encoder.modules.logit_score.LogitScore`).
- `sentence_bert_config.json`: declares `transformer_task: text-generation`, the `text`/`message` modality config (with `format: flat`), `module_output_name: causal_logits`, and `processing_kwargs.chat_template.add_generation_prompt: true`.
- `config_sentence_transformers.json`: `model_type: CrossEncoder` with `activation_fn: torch.nn.modules.linear.Identity` and `prompts: {}`.
- `1_LogitScore/config.json`: stores `true_token_id: 9454` (the "Yes" token), with `false_token_id` left null so the score is the raw "Yes" logit rather than a "Yes"-vs-"No" log-odds.
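For concreteness, `modules.json` then looks roughly like this, following the standard Sentence Transformers `idx`/`name`/`path`/`type` layout (a sketch of the shape, not a verbatim copy of the file in this PR):

```json
[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.base.modules.transformer.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_LogitScore",
    "type": "sentence_transformers.cross_encoder.modules.logit_score.LogitScore"
  }
]
```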
Modified files:
- `chat_template.jinja`: prepended a `query`/`document` branch as described above; the original Qwen3 ChatML logic is preserved unchanged below it.
- `README.md`: expanded the "How to Use" section with a `pip install`, expected `predict` output, the `sigmoid(score / 5)` post-processing one-liner, and a short `model.rank` example.
The main changes
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("zeroentropy/zerank-2", revision="refs/pr/8")
query_documents = [
    ("What is 2+2?", "4"),
    ("What is 2+2?", "The answer is definitely 1 million"),
]
scores = model.predict(query_documents, convert_to_tensor=True)
print(scores)
# tensor([ 5.4062, -4.5000], device='cuda:0', dtype=torch.bfloat16)
probabilities = (scores / 5).sigmoid()
print(probabilities)
# tensor([0.7461, 0.2891], device='cuda:0', dtype=torch.bfloat16)
```
Note that `revision="refs/pr/8"` means you can already test the PR branch without having to clone anything.
A few small deltas vs. the original predict() to be aware of (rankings are unaffected in all of them):
- Padding side: Sentence Transformers forces `padding_side="left"` for causal LMs, while the original `modeling_zeranker.py` uses `padding_side="right"`. This slightly shifts RoPE position IDs on padded items in a batch, drifting per-pair scores by up to ~0.01 on the shortest item in a mixed-length batch. The longest item in any batch is bit-identical. A custom `Transformer` subclass that flips back to right-padding can match exactly, at the cost of needing `trust_remote_code=True`.
- Sigmoid lives client-side: the original `predict()` returns `sigmoid(yes_logit / 5)` directly; here the caller applies `(scores / 5).sigmoid()`. The scaled sigmoid isn't possible in Sentence Transformers without `trust_remote_code`, but this could also be fixed to output 0...1 if we go the `trust_remote_code=True` route.
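To make the padding-side delta concrete, here is a toy illustration (plain Python, no tokenizer; `position_ids` is a hypothetical helper that assumes position IDs are a simple arange over the padded length, which is a simplification of the real modeling code):

```python
def position_ids(real_len: int, batch_max_len: int, padding_side: str) -> list:
    """Positions assigned to the *real* tokens of one sequence in a padded
    batch, assuming position IDs are an arange over the padded length."""
    if padding_side == "right":
        # Real tokens come first: positions 0 .. real_len-1, pads afterwards.
        return list(range(real_len))
    # Left padding: pad tokens occupy the leading positions, so the
    # real tokens of shorter sequences are shifted by the pad count.
    offset = batch_max_len - real_len
    return list(range(offset, offset + real_len))

# A 3-token item in a batch whose longest item has 5 tokens:
print(position_ids(3, 5, "right"))  # [0, 1, 2]
print(position_ids(3, 5, "left"))   # [2, 3, 4]  <- shifted by the pad count
```

The longest item in the batch has no padding, so its positions are identical under both sides, which is why it stays bit-identical while shorter items drift slightly.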
I kept the existing `modeling_zeranker.py`, although it's no longer used with this integration. Feel free to remove it as well.
Please let me know if you have any questions or feedback!
- Tom Aarsen