Skip to content

Support modular Sentence Transformers cross-encoder rerankers (e.g. ettin-reranker)#867

Open
hotchpotch wants to merge 18 commits into
huggingface:mainfrom
hotchpotch:ettin-st-reranker-support
Open

Support modular Sentence Transformers cross-encoder rerankers (e.g. ettin-reranker)#867
hotchpotch wants to merge 18 commits into
huggingface:mainfrom
hotchpotch:ettin-st-reranker-support

Conversation

@hotchpotch

Copy link
Copy Markdown

What does this PR do?

This PR adds support for serving modular Sentence Transformers cross-encoder rerankers such as the
ettin-reranker family in TEI.

Unlike the existing *ForSequenceClassification rerankers, these score with a post-pooling head:

Transformer backbone -> Pooling -> Dense -> LayerNorm -> Dense(scores)

This format is already part of Sentence Transformers, so more rerankers are likely to ship in this
shape. On the same GPU, TEI serves them about 1.5x faster than the Sentence Transformers CrossEncoder
with numerically equivalent scores (benchmark in a comment below).

Changes

  • Add BackendOutput::{Predict, Embed} (split from ModelType) so an embedding backbone can be routed
    to predict: a reranker is ModelType::Embedding(pool) + BackendOutput::Predict. Existing embedding
    backbones load unchanged, and Backend::new / CandleBackend::new keep their signatures.
  • Router: detect a modular reranker from modules.json + the final Dense config (out_features == 1,
    output "scores"); default single-label map when id2label/label2id are missing; reject
    unsupported post-pooling modules.
  • Candle: load the post-pooling head (Dense/LayerNorm), run it on the pooled embeddings in one
    batch, validate the [batch, 1] score shape. Post-pooling prediction is Candle-only (ORT/Python skip
    with a clear log).
  • ModernBert: support rope_parameters and nullable eos/bos_token_id, extending the pattern from
    Set pad_token_id as nullable & add support for rope_parameters #832 (which covered GTE/Qwen/Gemma/...) to ModernBert, which ettin requires.

Testing

  • Unit tests for reranker detection and modules.json parsing (legacy + current module strings), and a
    Candle integration test with an insta snapshot (cross-encoder/ettin-reranker-17m-v1).
  • Scores match the reference Sentence Transformers CrossEncoder (details in a comment below).

Happy to adjust the design/naming (e.g. the BackendOutput name) based on feedback.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? If applicable, did you include or update the insta snapshots?

Who can review?

cc @alvarobartt @Narsil — happy to take feedback

hotchpotch and others added 18 commits May 25, 2026 10:47
A transient download failure of the pooling/final-dense config used to
warn and continue, letting `detect_modular_reranker` fall back to the
embedding path and silently disable `/rerank`. Require those configs for
transformer -> pooling -> ... -> dense pipelines, and document why the
local-read fallbacks differ before vs. after the reranker is confirmed.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Wrapping the ORT and Python startup blocks in `else { ... }` re-indented
every original line, swamping the diff. Use `'label: { ... break }` early
guards so the unchanged backend bodies keep their original indentation and
the diff only shows the post-pooling-prediction skip.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The post-pooling prediction loader was nested inside the existing Dense
loop's `if`, pushing the original Dense-loading branch into an `else` and
re-indenting every line. Split it into `if use_post_pooling_prediction {
.. } else if let Some(dense_paths) ..` so the embedding Dense loop is
byte-identical to main, and fold the "requires at least one module" guard
into the new branch.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- Drop the unused `module_input_name`/`module_output_name` fields from the
  Candle `DenseConfig` (they were `#[allow(dead_code)]`); the detection-time
  copy lives in the router's `DenseDetectionConfig`.
- Document why `PredictionHeadModule` is kept separate from `DenseLayer`
  despite the identical signature.
- Reword the detection-config download failure so it is not reranker-specific:
  the same `transformer -> pooling -> ... -> dense` shape also covers embedding
  models ending in a Dense projection.
- Apply rustfmt to the prediction-head loader and detection guard so the
  pre-commit `fmt` hook passes.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Add the post-pooling-head reranker family (e.g. `cross-encoder/ettin-reranker-*`)
to the supported re-rankers list, and note that it is scored by a post-pooling
Dense head rather than a `*ForSequenceClassification` head.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Cover the post-pooling prediction-head path end to end with
`cross-encoder/ettin-reranker-17m-v1`, mirroring the existing
`gte-reranker-modernbert` classification test:

- `download_modular_reranker_artifacts` fetches the embedding backbone plus
  the post-pooling scoring head modules (`2_Dense`, `3_LayerNorm`, `4_Dense`),
  parsing `modules.json` leniently so both legacy and current module type
  strings work.
- `test_modernbert_modular_reranker` loads it via
  `new_with_post_pooling_prediction` (CLS pooling) and snapshots the score.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Reword the `PredictionHeadModule` doc comment and the post-pooling loader
comment to state what the code is (the reranker scoring head, distinct from
the embedding `DenseLayer` projection) rather than narrating which paths were
kept separate or left untouched.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Drop the doc comments added on private/internal helpers (the repo does not
`///` private fns or test helpers) and shorten the inline rationale comments,
keeping only the non-obvious "why" notes in the terse style used elsewhere.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Link the ettin reranker family collection instead of a single checkpoint in
the supported re-rankers table and intro sentence.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@hotchpotch

hotchpotch commented May 25, 2026

Copy link
Copy Markdown
Author

Some throughput numbers vs the reference Sentence Transformers CrossEncoder.

Benchmark target:

Both runners used the same GPU (NVIDIA RTX PRO 6000 Blackwell), float16, max_length=512, batch size 512,
after a 2,048-pair warmup; model/dataset load time is excluded from the timed region. Reference stack:
sentence-transformers==5.5.0, transformers==5.9.0, flash-attn==2.8.3, torch==2.9.0+cu128.

Speed

Runner Attention / API Elapsed Pairs/s
SentenceTransformers CrossEncoder SDPA 16.670 s 6,012
SentenceTransformers CrossEncoder FlashAttention 2 10.008 s 10,014
TEI /predict - 6.709 s 14,940

TEI is about 1.5x faster than CrossEncoder + FlashAttention 2 and about 2.5x faster than
CrossEncoder + SDPA on the same GPU and dtype. SentenceTransformers throughput is essentially the same
in float16 and bfloat16, so these float16 figures are representative regardless of dtype. TEI per-request
latency at batch 512 / concurrency 2: mean 67.9 ms, p99 82.8 ms.

Methodology note: the benchmark drives TEI's /predict with 512 pairs per request (not /rerank).
Natural Questions is a set of independent query-answer pairs, so batched /predict is the right
endpoint; /rerank is a single-query / many-texts API and 1-pair requests would be dominated by HTTP
overhead. Concurrency is fixed at 2 (at batch 512, concurrency 4 exhausts TEI's request permits and
returns 429).

Score agreement (full dataset)

All 100,231 pairs, raw scores:

Comparison Pairs Mean abs diff p99 abs diff Max abs diff Pearson
SentenceTransformers fp16 vs TEI fp16 100,231 0.00080 0.00781 0.01953 0.999998
SentenceTransformers bf16 vs TEI fp16 100,231 0.01676 0.04688 0.53125 0.999827

At matched float16 the mean absolute score difference over the full dataset is about 8e-4 with Pearson
0.999998. The bf16 row reflects the dtype difference, not a TEI defect, and still correlates at 0.9998.
(The SentenceTransformers fp16 row uses SDPA; SDPA vs FlashAttention 2 changes its scores by under 1e-4,
so it does not affect this comparison.)

@hotchpotch

Copy link
Copy Markdown
Author

Single-pair parity and the effect of FlashAttention

cross-encoder/ettin-reranker-17m-v1, input ("What is Deep Learning?", "Deep Learning is not..."),
raw score:

Configuration Score Δ vs reference
SentenceTransformers CrossEncoder (fp32, CPU) 5.5547919 -
TEI, GPU, FlashAttention OFF (fp16) 5.5546875 0.0001
TEI, CPU (fp32) 5.553596 0.0012
TEI, GPU, FlashAttention ON (fp16) 5.5507812 0.0040
  • With FlashAttention disabled, TEI matches the reference to within 1e-4.
  • The largest single source of divergence on GPU is FlashAttention (about 4e-3). FlashAttention is
    mathematically equivalent to standard attention, but its online softmax recomputes the softmax in a
    different order; combined with fp16 rounding this produces the gap. This is the generic
    FlashAttention/fp16 trade-off, not something specific to the post-pooling reranker head.
  • In all configurations the difference is at most 0.07% of the score and never changes the ranking.
    Use USE_FLASH_ATTENTION=false (or CPU fp32) when exact parity with Sentence Transformers is required.

Verification (selected)

  • Focused tests: cargo test -p text-embeddings-router --lib / cargo test -p text-embeddings-backend --features candle --lib
  • Feature builds: candle,http / http,ort --no-default-features / http,python --no-default-features
  • Runtime: ettin-reranker served natively and via the CUDA Docker image, /rerank and /predict checked.
  • Reference: CrossEncoder("cross-encoder/ettin-reranker-17m-v1").predict([pair], activation_fn=nn.Identity()).

@nikmall

nikmall commented Jun 9, 2026

Copy link
Copy Markdown

Deployed the fork using GPU: NVIDIA A10G (23 GB VRAM), CUDA Driver 13.0
Calling the predict I for batches larger than 1 I get the following Error:

ERROR rerank:predict{truncate=true truncation_direction=Right raw_scores=false}: text_embeddings_core::infer: core/src/infer.rs:450: MatMulUnexpectedStriding { lhs_l: Layout { shape: [2, 1024], stride: [31744, 1], start_offset: 0 }, rhs_l: Layout { shape: [1024, 1024], stride: [1, 1024], start_offset: 0 }, bmnk: (1, 2, 1024, 1024), msg: "non-contiguous lhs" }

@hotchpotch

@hotchpotch

hotchpotch commented Jun 11, 2026

Copy link
Copy Markdown
Author

Hello! @nikmall

Thanks for the report. I could not test on an A10G directly, but I verified the same Docker CUDA/FA2 build path on my local GPU.

My local environment:

GPU: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition
Compute capability: 12.0 / sm_120
Driver: 590.48.01
Host-reported CUDA: 13.1
Docker CUDA base: nvidia/cuda:12.9.1-*-ubuntu24.04
Container CUDA packages:
  cuda-cudart-12-9 12.9.79-1
  libcublas-12-9 12.9.1.4-1
  cuda-compat-12-9 575.57.08-0ubuntu1

I built the runtime image for my local GPU with FA2 enabled:

docker build -f Dockerfile-cuda \
  --build-arg CUDA_COMPUTE_CAP=120 \
  -t tei-cuda120-fa2-runtime .

Then I started TEI on GPU 0:

docker run -d --name tei-cuda120-fa2-repro \
  --gpus '"device=0"' \
  -p 18083:80 \
  -v ~/.cache/huggingface:/data \
  tei-cuda120-fa2-runtime \
  --model-id cross-encoder/ettin-reranker-17m-v1 \
  --dtype float16 \
  --max-batch-tokens 32768 \
  --max-batch-requests 64

The logs show that the CUDA FlashModernBert backend was used:

Starting FlashModernBert model on Cuda(...)
Loading post-pooling prediction module/s from path/s: ["2_Dense", "3_LayerNorm", "4_Dense"]
Ready

I tested /predict with batch size > 1:

curl -sS -w '\nHTTP %{http_code}\n' 127.0.0.1:18083/predict \
  -X POST -H 'Content-Type: application/json' \
  -d '{"inputs":[["What is Deep Learning?","Deep Learning is not..."],["What is Machine Learning?","Machine learning is a field of AI."]],"truncate":true,"truncation_direction":"Right","raw_scores":false}'

This returned HTTP 200. I also tested a batch of 16 and 32 concurrent /predict requests; all returned HTTP 200. I did not see MatMulUnexpectedStriding or non-contiguous lhs in the logs.

I also confirmed that the A10G-targeted Docker build completes successfully:

docker build -f Dockerfile-cuda \
  --build-arg CUDA_COMPUTE_CAP=86 \
  -t tei-cuda86-fa2-runtime .

I cannot run this sm_86 image locally because my available GPU is sm_120, and TEI correctly rejects the runtime/compile-time mismatch:

Runtime compute cap 120 is not compatible with compile time compute cap 86

Could you try building and running the A10G image with CUDA_COMPUTE_CAP=86 on the A10G machine?

docker build -f Dockerfile-cuda \
  --build-arg CUDA_COMPUTE_CAP=86 \
  -t tei-cuda86-fa2-runtime .
docker run --rm --name tei-a10g-repro \
  --gpus '"device=0"' \
  -p 18083:80 \
  -v ~/.cache/huggingface:/data \
  tei-cuda86-fa2-runtime \
  --model-id cross-encoder/ettin-reranker-17m-v1 \
  --dtype float16 \
  --max-batch-tokens 32768 \
  --max-batch-requests 64
curl -sS -w '\nHTTP %{http_code}\n' 127.0.0.1:18083/predict \
  -X POST -H 'Content-Type: application/json' \
  -d '{"inputs":[["What is Deep Learning?","Deep Learning is not..."],["What is Machine Learning?","Machine learning is a field of AI."]],"truncate":true,"truncation_direction":"Right","raw_scores":false}'

If this still fails on A10G with MatMulUnexpectedStriding, I think the fix should be generic rather than A10G-specific, probably around making the prediction-head input contiguous before passing it to the CUDA Linear/cuBLASLt path.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants