MaxSim
Introduction
A fast, memory-efficient exact MaxSim kernel for ColBERT-style late-interaction retrieval and training. It is a drop-in replacement for a naive PyTorch MaxSim (einsum + max) that is ~10× faster in training and cuts the retained backward state by ~98%. See Benchmarks.
The kernel is packaged as a Hugging Face kernels repo, so you can load it with kernels.get_kernel(...) in your own code with minimal dependencies.
The kernel computes
score(q, d) = sum_i max_j <q_i, d_j>
over batches of (query, document) pairs without materialising the full
[Lq, Ld] similarity matrix. Most of those values are discarded by the max anyway,
so in forward the kernel computes them on the fly without storing them, and keeps
only the winning argmax positions ([Lq] int32 per pair) for backward. This is where
the speed and memory savings come from.
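For intuition, here is a minimal pure-Python reference of the formula above for a single (query, document) pair. It is a sketch for checking small cases by hand, not the kernel's implementation, and the function names are hypothetical. It still computes every dot product; like the kernel, it just never stores the [Lq, Ld] matrix.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product <a, b>."""
    return sum(x * y for x, y in zip(a, b))


def maxsim_reference(query: List[Vector], document: List[Vector]) -> float:
    """score(q, d) = sum_i max_j <q_i, d_j> for one (query, document) pair.

    Every dot product is computed, but only a running maximum per
    query token is kept, so the [Lq, Ld] similarity matrix is never stored.
    """
    total = 0.0
    for q_i in query:
        best = float("-inf")
        for d_j in document:
            best = max(best, dot(q_i, d_j))
        total += best
    return total


# Lq = 2 query tokens, Ld = 3 document tokens, dim = 2.
query = [[1.0, 0.0], [0.0, 1.0]]
document = [[0.5, 0.5], [2.0, 0.0], [0.0, 1.5]]
print(maxsim_reference(query, document))  # 3.5 = 2.0 + 1.5
```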
Install
The kernel is ahead-of-time compiled and distributed as a Hugging Face kernel repo, so you can use it without any compile step and with minimal dependencies. Using the kernel is as easy as:
uv add kernels # or: pip install kernels
from kernels import get_kernel
maxsim = get_kernel("erikkaum/maxsim", version=2, trust_remote_code=True)
Common workflows
The MaxSim kernel has three entry points, each with an inference and a training variant:
- Contrastive: every query scored against every document. This is the classic ColBERT-style paradigm that's mostly used in training and fine-tuning.
- Padded: each query has a fixed number K of candidates; queries and documents are both padded to max-Lq / max-Ld. This is the most common inference workflow, used as a second-stage exact rerank.
- Packed: arbitrary (query, document) pair grids over ragged queries and documents. This is the lowest-level surface: the most flexible, but also the most boilerplate to set up, since you build the offset arrays and pair-id arrays yourself.
The two most common ways you will use the kernel are contrastive training and padded reranking. The packed API is lower-level and is covered in the reference section below.
Contrastive training
Use this when training or fine-tuning a ColBERT-style model with in-batch
negatives. score_contrastive_train scores every query in the batch against
every document in the batch, producing an [Nq, Nb] score matrix.
These are the abbreviations used in this section:
Nq = number of queries in the batch
Nb = number of documents in the batch
q = query index
d = document index
Lq = query length after padding/truncation
dim = embedding dimension from the encoder
Each entry of the score matrix is the MaxSim of one query against one document:
scores[q, d] = MaxSim(queries[q], documents[d])
In the common in-batch setup, the batch is arranged so that
queries[i] and documents[i] are the positive pair for training example i,
so the diagonal entries are positives and the off-diagonal entries are in-batch
negatives:
d=0 d=1 d=2 d=3
q=0 + - - -
q=1 - + - -
q=2 - - + -
q=3 - - - +
This is why the score matrix can be passed directly to cross-entropy with diagonal targets:
targets = torch.arange(Nq, device=scores.device)
loss = torch.nn.functional.cross_entropy(scores, targets)
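To make the diagonal-target setup concrete, here is a pure-Python sketch of what cross-entropy with its default mean reduction computes on a small [Nq, Nb] score matrix where row q's target is column q. The function name is hypothetical; it only illustrates the loss, it is not part of the kernel.

```python
import math
from typing import List


def in_batch_cross_entropy(scores: List[List[float]]) -> float:
    """Mean cross-entropy where row q's target class is column q (the diagonal)."""
    losses = []
    for q, row in enumerate(scores):
        # log-sum-exp with the row max subtracted for numerical stability
        m = max(row)
        log_norm = m + math.log(sum(math.exp(s - m) for s in row))
        losses.append(log_norm - row[q])  # -log softmax(row)[q]
    return sum(losses) / len(losses)


# 3 queries x 3 documents: the diagonal (positive) entries score highest.
scores = [
    [9.0, 2.0, 1.0],
    [1.5, 8.0, 2.5],
    [0.5, 1.0, 7.0],
]
print(round(in_batch_cross_entropy(scores), 4))
```

When the positives clearly dominate their rows the loss approaches 0; when all scores in a row are equal it is log(Nb).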
The contrastive API uses a mixed layout: queries are rectangular, while documents are packed.
Queries are usually short and padded/truncated to one fixed length for the batch. Each filled square is one token embedding:
queries: [Nq, Lq, dim]
t0 t1 t2 t3 ... t31
q0 ■ ■ ■ ■ ■
q1 ■ ■ ■ ■ ■
q2 ■ ■ ■ ■ ■
In contrastive mode, all Lq query positions participate in MaxSim; there is no
separate query_lengths tensor. This matches the usual ColBERT training setup
where queries are encoded to a fixed query length.
Documents are often longer and have more variable lengths. If we padded every document to the longest document in the batch, some of the work would go to padding:
padded documents: [Nb, max_Ld, dim]
t0 t1 t2 ... t79 t80 ... t255
d0 ■ ■ ■ ■ ■ ■ 256 tokens
d1 ■ ■ ■ ■ □ □ 80 tokens
d2 ■ ■ ■ ... □ □ □ 64 tokens
■ = real token embedding
□ = padding
Instead, contrastive mode packs only the real document token embeddings into one flat tensor:
documents: [total_d_tokens, dim]
idx 0 1 2 ... 255 | 256 257 ... 335 | 336 337 ... 399
token ■ ■ ■ ■ | ■ ■ ■ | ■ ■ ■
doc doc 0 | doc 1 | doc 2
The companion document_offsets tensor marks the boundaries between documents:
document_offsets = [0, 256, 336, 400]
document 0 = documents[0:256] # 256 tokens
document 1 = documents[256:336] # 80 tokens
document 2 = documents[336:400] # 64 tokens
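The offsets are an exclusive prefix sum of the per-document token counts. A pure-Python sketch, with hypothetical helper names, of building document_offsets from lengths and slicing one document back out of the flat token list:

```python
from itertools import accumulate
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def lengths_to_offsets(lengths: Sequence[int]) -> List[int]:
    """[len0, len1, ...] -> [0, len0, len0 + len1, ...] (N + 1 entries)."""
    return [0] + list(accumulate(lengths))


def document_slice(tokens: Sequence[T], offsets: Sequence[int], i: int) -> Sequence[T]:
    """Tokens of document i: tokens[offsets[i]:offsets[i + 1]]."""
    return tokens[offsets[i]:offsets[i + 1]]


doc_lengths = [256, 80, 64]
offsets = lengths_to_offsets(doc_lengths)
print(offsets)  # [0, 256, 336, 400]

# Stand-in for the flat [total_d_tokens, dim] tensor: one entry per token.
flat_tokens = list(range(offsets[-1]))
print(len(document_slice(flat_tokens, offsets, 1)))  # 80
```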
A typical training step looks like this:
import torch
import torch.nn.functional as F
from kernels import get_kernel
maxsim = get_kernel("erikkaum/maxsim", version=2, trust_remote_code=True)
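# Encoder outputs (encoder_q / encoder_d stand in for your own encoders).
queries = encoder_q(query_ids) # [Nq, Lq, dim], fp16/bf16 on CUDA
doc_emb = encoder_d(document_ids) # [Nb, Ld, dim] before packing
Nq, Nb = queries.shape[0], doc_emb.shape[0]
# [Nb] real token count per document. document_attention_mask is a placeholder
# for your tokenizer's [Nb, Ld] right-padded attention mask.
doc_lengths = document_attention_mask.sum(dim=1)
# Naive/illustrative packing according to real token lengths.
documents = torch.cat(
[doc_emb[i, :doc_lengths[i]] for i in range(Nb)],
dim=0,
) # [total_d_tokens, dim]
document_offsets = torch.zeros(Nb + 1, dtype=torch.int32, device=doc_emb.device)
document_offsets[1:] = doc_lengths.cumsum(0)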
scores = maxsim.score_contrastive_train(
queries,
documents,
document_offsets,
) # [Nq, Nb] fp32
# Diagonal positives: query i should match document i.
targets = torch.arange(Nq, device=scores.device)
# MaxSim scores grow with Lq, so training loops often use a temperature/scale.
temperature = 1.0
loss = F.cross_entropy(scores / temperature, targets)
loss.backward()
Typical ColBERT-style values are:
Nq, Nb = 16 to 64
Lq = 32
Ld = 80 to 256 document tokens before packing
dim = encoder output width, often 64 / 96 / 128
On CUDA, the contrastive training kernel currently uses the WMMA path and expects
fp16/bf16 inputs, Lq % 16 == 0, and dim % 16 == 0. Typical ColBERT shapes such as Lq=32, dim=128 satisfy this.
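A small pre-flight check can catch shapes that the CUDA contrastive path rejects before a training run starts. The helper below is hypothetical (not part of the kernel's API) and encodes only the constraints stated in this README: fp16/bf16 inputs, Lq % 16 == 0, and dim % 16 == 0. Dtypes are passed as strings to keep the sketch dependency-free.

```python
def cuda_contrastive_shape_ok(lq: int, dim: int, dtype: str) -> bool:
    """True if (Lq, dim, dtype) meets the documented CUDA contrastive constraints.

    dtype is a string such as "float16", "bfloat16" or "float32".
    """
    return dtype in ("float16", "bfloat16") and lq % 16 == 0 and dim % 16 == 0


print(cuda_contrastive_shape_ok(32, 128, "bfloat16"))  # True
print(cuda_contrastive_shape_ok(32, 96, "float16"))    # True
print(cuda_contrastive_shape_ok(30, 128, "float16"))   # False: Lq is not a multiple of 16
print(cuda_contrastive_shape_ok(32, 128, "float32"))   # False: fp32 is not accepted on this path
```

With PyTorch tensors you could pass str(queries.dtype).removeprefix("torch.") as the dtype string (Python 3.9+).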
The kernel saves only the winning document-token index per query token for
backward, rather than retaining the full [Nq, Nb, Lq, Ld] similarity tensor.
This is where most of the training-time memory saving comes from.
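Why is the argmax enough for backward? For each query token i only the winning document token j* contributes to the score, so dscore/dq_i = d_{j*}, and dscore/dd_j is the sum of q_i over the query tokens whose argmax is j (all other document tokens get zero). The pure-Python sketch below (hypothetical names, one pair, scalar upstream gradient grad_out) computes the argmax in forward and then both gradients from the argmax alone. It does not model ties or the kernel's memory layout.

```python
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def forward_with_argmax(query: List[Vector], document: List[Vector]) -> Tuple[float, List[int]]:
    """MaxSim score plus, per query token, the index of the winning document token."""
    score, argmax = 0.0, []
    for q_i in query:
        sims = [dot(q_i, d_j) for d_j in document]
        j_star = max(range(len(document)), key=lambda j: sims[j])
        argmax.append(j_star)
        score += sims[j_star]
    return score, argmax


def backward_from_argmax(query, document, argmax, grad_out=1.0):
    """Gradients of grad_out * score w.r.t. query and document, using only the argmax."""
    dim = len(query[0])
    grad_q = [[grad_out * x for x in document[j]] for j in argmax]
    grad_d = [[0.0] * dim for _ in document]
    for i, j in enumerate(argmax):
        # scatter-add: several query tokens can pick the same document token
        for k in range(dim):
            grad_d[j][k] += grad_out * query[i][k]
    return grad_q, grad_d


query = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
document = [[2.0, 0.0], [0.0, 3.0]]
score, argmax = forward_with_argmax(query, document)
grad_q, grad_d = backward_from_argmax(query, document, argmax)
print(score, argmax)  # 8.0 [0, 1, 1]
print(grad_d)         # [[1.0, 0.0], [1.0, 2.0]]
```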
Padded inference
A typical ColBERT/late-interaction retrieval pipeline has two stages:
- A retriever such as PLAID first gives you K candidate documents for each query.
- You then rerank those K candidates with exact MaxSim through score_candidates_padded(...). This computes the exact MaxSim for every query-candidate pair and returns the scores as a [B, K] matrix.
The inputs come from your retriever, e.g. PLAID:
queries [B, Lq, dim] fp16/bf16
candidates [B, K, Ld, dim] same dtype as queries
query_lengths [B] int32, real per-query token count (≤ Lq)
doc_lengths [B, K] int32, real per-doc token count (≤ Ld)
Lq and Ld are the padded lengths of the tensors, while query_lengths and
doc_lengths are the actual per-query and per-document token counts. Padding
positions are ignored when computing MaxSim. In practice, Lq is usually the
longest query length in the batch, and Ld is the longest document length in
the batch.
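A pure-Python sketch of turning ragged retriever output into these four inputs, with Lq and Ld taken as the longest query and document in the batch. The helper names are hypothetical, nested lists stand in for tensors, and zero vectors are used as padding; in practice you would build fp16/bf16 tensors and int32 length tensors with the shapes listed above.

```python
from typing import List

Vector = List[float]
Tokens = List[Vector]


def pad_tokens(seqs: List[Tokens], length: int, dim: int) -> List[Tokens]:
    """Right-pad every token sequence with zero vectors up to `length` tokens."""
    return [s + [[0.0] * dim for _ in range(length - len(s))] for s in seqs]


def build_padded_batch(queries: List[Tokens], candidates: List[List[Tokens]]):
    """Ragged input -> (queries [B][Lq][dim], candidates [B][K][Ld][dim],
    query_lengths [B], doc_lengths [B][K])."""
    dim = len(queries[0][0])
    lq = max(len(q) for q in queries)                      # longest query in the batch
    ld = max(len(d) for docs in candidates for d in docs)  # longest document in the batch
    query_lengths = [len(q) for q in queries]
    doc_lengths = [[len(d) for d in docs] for docs in candidates]
    padded_q = pad_tokens(queries, lq, dim)
    padded_c = [pad_tokens(docs, ld, dim) for docs in candidates]
    return padded_q, padded_c, query_lengths, doc_lengths


# B = 2 queries, K = 2 candidates each, dim = 2.
queries = [
    [[1.0, 0.0]],               # 1 token
    [[0.0, 1.0], [1.0, 1.0]],   # 2 tokens
]
candidates = [
    [[[1.0, 0.0]], [[0.0, 1.0], [0.5, 0.5], [1.0, 0.0]]],  # 1 and 3 tokens
    [[[1.0, 1.0], [0.0, 2.0]], [[2.0, 0.0]]],              # 2 and 1 tokens
]
pq, pc, query_lengths, doc_lengths = build_padded_batch(queries, candidates)
print(len(pq[0]), len(pc[0][0]))   # 2 3  (Lq, Ld)
print(query_lengths, doc_lengths)  # [1, 2] [[1, 3], [2, 1]]
```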
These are the abbreviations used in this section:
B = number of queries in the batch
K = number of candidate documents per query
Lq = padded query length
Ld = padded document length
dim = embedding dimension from your encoder, e.g. 64 / 96 / 128
We feed these into score_candidates_padded, which returns a [B, K] matrix
with dtype fp32, regardless of the input dtype.
# Exact MaxSim score per (query, candidate)
scores = maxsim.score_candidates_padded(queries, candidates, query_lengths, doc_lengths)
Then sort the scores to get the reranked order (indices into each query's K candidates):
reranked_ids = scores.argsort(dim=1, descending=True)
For B=32, the padded kernel runs ~1.1–1.8× faster than the PyTorch einsum+max baseline at K=50 and ~2.2–4.3× at K=100 on the H200 and A100 (about 3.5× and 5× on Apple Silicon). The advantage grows with the candidate count K. See Benchmarks.
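For small sanity checks it can help to have a pure-Python version of what score_candidates_padded computes: positions beyond query_lengths[b] and doc_lengths[b][k] are padding and are simply not visited. The sketch below (hypothetical name, nested lists instead of tensors, every length assumed to be at least 1) returns a [B][K] score grid and the per-query reranked order, mirroring scores.argsort(dim=1, descending=True).

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def score_candidates_padded_reference(queries, candidates, query_lengths, doc_lengths):
    """queries: [B][Lq][dim], candidates: [B][K][Ld][dim]; returns [B][K] scores.

    Only the first query_lengths[b] query tokens and the first doc_lengths[b][k]
    document tokens are used. Assumes every length is at least 1.
    """
    scores = []
    for b, query in enumerate(queries):
        row = []
        for k, doc in enumerate(candidates[b]):
            real_doc = doc[: doc_lengths[b][k]]
            row.append(sum(max(dot(q_i, d_j) for d_j in real_doc)
                           for q_i in query[: query_lengths[b]]))
        scores.append(row)
    return scores


PAD = [0.0, 0.0]
queries = [[[1.0, 0.0], [0.0, 1.0], PAD]]      # B=1, Lq=3, real length 2
candidates = [[
    [[0.5, 0.5], PAD, PAD],                    # candidate 0: 1 real token
    [[2.0, 0.0], [0.0, 1.0], PAD],             # candidate 1: 2 real tokens
    [[1.0, 0.0], [0.0, 3.0], [0.5, 0.5]],      # candidate 2: 3 real tokens
]]
scores = score_candidates_padded_reference(queries, candidates, [2], [[1, 2, 3]])
reranked = [sorted(range(len(row)), key=lambda k: row[k], reverse=True) for row in scores]
print(scores)    # [[1.0, 3.0, 4.0]]
print(reranked)  # [[2, 1, 0]]
```

Masking by length matters: if a document's real similarities are all negative, a zero padding token would otherwise win the max with 0.0.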
Reference
Inference
# Padded: per-query reranking with a fixed K candidates/documents each. Inputs are
# padded to max-Lq / max-Ld; *_lengths give the real per-row length.
scores = maxsim.score_candidates_padded(
queries, # [B, Lq, dim] fp32 / fp16 / bf16
documents, # [B, K, Ld, dim] same dtype as queries
query_lengths, # [B] int32 — real query length (≤ Lq)
doc_lengths, # [B, K] int32 — real doc length (≤ Ld)
) # -> [B, K] fp32
# Contrastive: every query × every document. Returns the full score
# matrix [Nq, Nb] — what most contrastive losses (InfoNCE, etc.) want.
scores = maxsim.score_contrastive(
queries, # [Nq, Lq, dim] fp16 / bf16 on CUDA; fp32 also on Metal
documents, # [total_d_tokens, dim] flat (all docs concatenated)
document_offsets, # [Nb + 1] int32 documents[offsets[i]:offsets[i+1]] = doc i
) # -> [Nq, Nb] fp32
# Packed: arbitrary (q, d) pair grids. You pick which pairs to score via
# pair_query_ids / pair_document_ids; the kernel computes only those.
scores = maxsim.score_pairs_packed(
queries, # [total_q_tokens, dim] fp32 / fp16 / bf16
query_offsets, # [num_queries + 1] int32 — CSR offsets into queries
documents, # [total_d_tokens, dim] same dtype as queries
document_offsets, # [num_documents + 1] int32 — CSR offsets into documents
pair_query_ids, # [num_pairs] int32 — which query to score
pair_document_ids, # [num_pairs] int32 — which document to score
) # -> [num_pairs] fp32
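The packed layout is CSR-style: query i is queries[query_offsets[i]:query_offsets[i+1]], documents work the same way, and pair p scores query pair_query_ids[p] against document pair_document_ids[p]. A pure-Python sketch of those semantics (hypothetical name; nested lists stand in for tensors), including how a B×K rerank grid flattens into pair ids:

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def score_pairs_packed_reference(queries, query_offsets, documents, document_offsets,
                                 pair_query_ids, pair_document_ids) -> List[float]:
    """Score only the requested (query, document) pairs over ragged, flat token lists."""
    scores = []
    for qi, di in zip(pair_query_ids, pair_document_ids):
        q = queries[query_offsets[qi]:query_offsets[qi + 1]]
        d = documents[document_offsets[di]:document_offsets[di + 1]]
        scores.append(sum(max(dot(q_t, d_t) for d_t in d) for q_t in q))
    return scores


# Two ragged queries (2 and 1 tokens) and three ragged documents (1, 2 and 1 tokens).
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query_offsets = [0, 2, 3]
documents = [[3.0, 0.0], [1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
document_offsets = [0, 1, 3, 4]

# A B=2, K=2 rerank grid: query 0 vs docs {0, 1}, query 1 vs docs {1, 2}, flattened row-major.
candidate_ids = [[0, 1], [1, 2]]
pair_query_ids = [b for b, row in enumerate(candidate_ids) for _ in row]
pair_document_ids = [d for row in candidate_ids for d in row]
print(pair_query_ids, pair_document_ids)  # [0, 0, 1, 1] [0, 1, 1, 2]
print(score_pairs_packed_reference(queries, query_offsets, documents, document_offsets,
                                   pair_query_ids, pair_document_ids))  # [3.0, 3.0, 2.0, 2.0]
```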
Argmax helpers
Each inference API also has a _with_argmax variant. It returns the same fp32
scores plus the document-token index that won the max for each query token.
These helpers are useful for debugging alignments, visualizing which document
tokens matched, and inspecting the compact state used by the training kernels.
scores, argmax = maxsim.score_candidates_padded_with_argmax(
queries, # [B, Lq, dim]
documents, # [B, K, Ld, dim]
query_lengths, # [B]
doc_lengths, # [B, K]
) # scores: [B, K] fp32; argmax: [B, K, Lq] int32
scores, argmax = maxsim.score_contrastive_with_argmax(
queries, # [Nq, Lq, dim]
documents, # [total_d_tokens, dim]
document_offsets, # [Nb + 1]
) # scores: [Nq, Nb] fp32; argmax: [Nq, Nb, Lq] int32
scores, argmax = maxsim.score_pairs_packed_with_argmax(
queries,
query_offsets,
documents,
document_offsets,
pair_query_ids,
pair_document_ids,
) # scores: [num_pairs] fp32; argmax: [num_pairs, max_q_len] int32
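Each argmax row above holds, for one pair, a document-token index per query token. A pure-Python sketch of two ways to use such a row: recompute the pair's score from it (the same property that lets backward work from the argmax alone) and print a readable alignment. All names are hypothetical and the token strings are placeholders.

```python
from typing import List, Sequence

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def argmax_row(query: Sequence[Vector], document: Sequence[Vector]) -> List[int]:
    """Per query token, the index of the document token with the highest similarity."""
    return [max(range(len(document)), key=lambda j: dot(q_i, document[j])) for q_i in query]


def score_from_argmax(query, document, argmax) -> float:
    """Recompute one pair's MaxSim score from its per-query-token argmax."""
    return sum(dot(q_i, document[j]) for q_i, j in zip(query, argmax))


def alignments(query_tokens, doc_tokens, argmax) -> List[str]:
    """Readable 'query token -> winning document token' pairs."""
    return [f"{q} -> {doc_tokens[j]}" for q, j in zip(query_tokens, argmax)]


query = [[1.0, 0.0], [0.0, 1.0]]
document = [[0.25, 0.125], [0.75, 0.0], [0.0, 0.5]]
argmax = argmax_row(query, document)
print(argmax)                                                # [1, 2]
print(score_from_argmax(query, document, argmax))            # 1.25
print(alignments(["q0", "q1"], ["d0", "d1", "d2"], argmax))  # ['q0 -> d1', 'q1 -> d2']
```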
Training
Each inference API has a _train variant with the same input layout and output
shape, but connected to PyTorch autograd. The _train variants return fp32
scores and propagate gradients to queries and documents; length, offset, and
pair-id tensors are metadata and do not receive gradients. Backward uses fp32
accumulation.
# Padded training.
scores = maxsim.score_candidates_padded_train(
queries, # [B, Lq, dim] fp32 / fp16 / bf16
documents, # [B, K, Ld, dim] same dtype as queries
query_lengths, # [B] int32 / int64
doc_lengths, # [B, K] int32 / int64
) # -> [B, K] fp32
# Contrastive training.
scores = maxsim.score_contrastive_train(
queries, # [Nq, Lq, dim] fp16 / bf16 on CUDA; fp32 also on Metal
documents, # [total_d_tokens, dim] same dtype as queries
document_offsets, # [Nb + 1] int32 / int64
) # -> [Nq, Nb] fp32
# Packed training.
scores = maxsim.score_pairs_packed_train(
queries, # [total_q_tokens, dim] fp32 / fp16 / bf16
query_offsets, # [num_queries + 1] int32 / int64
documents, # [total_d_tokens, dim] same dtype as queries
document_offsets, # [num_documents + 1] int32 / int64
pair_query_ids, # [num_pairs] int32 / int64
pair_document_ids, # [num_pairs] int32 / int64
) # -> [num_pairs] fp32
Example scripts
The examples/ directory is the quickest way to run some code locally:
just example 01 # quickstart: padded inference + padded training
just example 02 # argmax / alignment visualization
just example 03 # tiny contrastive InfoNCE training loop
just example 04 # retained-memory showcase vs naive PyTorch
Benchmarks
All numbers below are medians over 30 iterations × 3 independent runs
(MAXSIM_BENCH_REPEATS=3), generated by just cuda-bench-matrix. The baseline
is a straightforward PyTorch implementation that uses torch.einsum to
materialise the full similarity tensor before max. Raw outputs are committed
under bench_results/v2/.
ColBERT training on H200
The headline workload uses a common ColBERT-style in-batch contrastive shape:
Nq=Nb=32, Lq=32, Ld=80, dim=128. MaxSim gives a ~10× faster training step with
1/80th of the retained backward state.
| dtype | maxsim step | naive step | full step× | bwd alone× | peak mem× | retained state |
|---|---|---|---|---|---|---|
| fp16 | 0.261 ms | 2.697 ms | 10.33× | 12.86× | 1.27× | 1/80 |
| bf16 | 0.272 ms | 2.704 ms | 9.94× | 12.22× | 1.27× | 1/80 |
NVIDIA H200 (PTX-JIT). Run-to-run spread over 3 repeats: 0.33× (bf16), 1.20× (fp16). The fp16
step is noisier at this small grid on Hopper. Raw JSON artifact:
bench_results/v2/h200.json.
The backward speedup matters most: training is backward-dominated, and that is where the kernel's argmax-only save pays off:
naive contrastive backward: [Nq, Nb, Lq, Ld] fp32 similarities
maxsim contrastive backward: [Nq, Nb, Lq] int32 argmax
maxsim retained state = 1/Ld of naive (fp32 and int32 are both 4 bytes)
Retained backward state at Nq = Nb = Lq = 32:
| Ld | naive retained | maxsim retained | maxsim / naive |
|---|---|---|---|
| 80 | 10.5 MB | 0.13 MB | 1/80 |
| 128 | 16.8 MB | 0.13 MB | 1/128 |
| 256 | 33.6 MB | 0.13 MB | 1/256 |
| 512 | 67.1 MB | 0.13 MB | 1/512 |
| 1024 | 134.2 MB | 0.13 MB | 1/1024 |
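The table follows directly from element counts. A short sketch that reproduces it, assuming Nq = Nb = Lq = 32 (the headline shape), 4 bytes per element for both the fp32 similarities and the int32 argmax, and decimal megabytes; the function name is hypothetical.

```python
from typing import Tuple


def retained_bytes(nq: int, nb: int, lq: int, ld: int) -> Tuple[int, int]:
    """(naive, maxsim) retained backward state in bytes.

    naive:  [Nq, Nb, Lq, Ld] fp32 similarities (4 bytes each)
    maxsim: [Nq, Nb, Lq] int32 argmax (4 bytes each)
    """
    naive = nq * nb * lq * ld * 4
    maxsim = nq * nb * lq * 4
    return naive, maxsim


for ld in (80, 128, 256, 512, 1024):
    naive, maxsim = retained_bytes(32, 32, 32, ld)
    print(f"Ld={ld:5d}  naive {naive / 1e6:6.1f} MB  maxsim {maxsim / 1e6:.2f} MB  ratio 1/{naive // maxsim}")
```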
Cross-Device Contrastive fp16
This table is rendered from the GPU JSON artifacts available under
bench_results/v2/.
| GPU | sm | maxsim step | naive step | speedup |
|---|---|---|---|---|
| H200 (PTX-JIT) | sm_90 | 0.261 ms | 2.697 ms | 10.33× |
| A100 SXM4 80GB | sm_80 | 0.378 ms | 4.516 ms | 11.94× |
| L4 | sm_89 | 0.360 ms | 3.340 ms | 9.28× |
| A10G | sm_86 | 0.433 ms | 4.086 ms | 9.43× |
Full H200 benchmark matrix
The full matrix includes every benchmark workload present in the rendered JSON artifact.
| Surface | Preset | Shape | dtype | maxsim | PyTorch | speedup | padded | bwd× | peak× | retained state |
|---|---|---|---|---|---|---|---|---|---|---|
| contrastive_train | Contrastive | Nq=32, Nb=32, Lq=32, Ld=80, D=128 | fp16 | 0.261 ms | 2.697 ms | 10.33× | — | 12.86× | 1.27× | 1/80 |
| contrastive_train | Contrastive | Nq=32, Nb=32, Lq=32, Ld=80, D=128 | bf16 | 0.272 ms | 2.704 ms | 9.94× | — | 12.22× | 1.27× | 1/80 |
| contrastive_train | LongDocs | Nq=32, Nb=32, Lq=32, Ld=512, D=128 | fp16 | 0.270 ms | 2.703 ms | 10.01× | — | 25.69× | 2.22× | 1/512 |
| contrastive_train | LongDocs | Nq=32, Nb=32, Lq=32, Ld=512, D=128 | bf16 | 0.263 ms | 2.713 ms | 10.32× | — | 32.21× | 2.22× | 1/512 |
| contrastive_train | BigBatch | Nq=64, Nb=64, Lq=32, Ld=128, D=128 | fp16 | 0.320 ms | 4.906 ms | 15.34× | — | 43.90× | 2.47× | 1/128 |
| contrastive_train | BigBatch | Nq=64, Nb=64, Lq=32, Ld=128, D=128 | bf16 | 0.321 ms | 4.911 ms | 15.30× | — | 44.26× | 2.47× | 1/128 |
| padded_infer | Rerank | B=32, K=50, Lq=32, Ld=180, D=128 | fp16 | 0.327 ms | 0.381 ms | 1.17× | — | — | 2.32× | — |
| padded_infer | Rerank | B=32, K=50, Lq=32, Ld=180, D=128 | bf16 | 0.338 ms | 0.382 ms | 1.13× | — | — | 2.32× | — |
| padded_infer | HeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | fp16 | 0.335 ms | 0.826 ms | 2.47× | — | — | 2.89× | — |
| padded_infer | HeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | bf16 | 0.374 ms | 0.830 ms | 2.22× | — | — | 2.89× | — |
| packed_infer | PackedRerank | B=32, K=50, Lq=32, Ld=180, D=128 | fp16 | 5.527 ms | 0.380 ms | 0.07× | 0.328 ms | — | 2.32× | — |
| packed_infer | PackedRerank | B=32, K=50, Lq=32, Ld=180, D=128 | bf16 | 5.533 ms | 0.382 ms | 0.07× | 0.338 ms | — | 2.32× | — |
| packed_infer | PackedHeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | fp16 | 14.640 ms | 0.825 ms | 0.06× | 0.333 ms | — | 2.89× | — |
| packed_infer | PackedHeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | bf16 | 14.654 ms | 0.830 ms | 0.06× | 0.372 ms | — | 2.89× | — |
Full A100 benchmark matrix
| Surface | Preset | Shape | dtype | maxsim | PyTorch | speedup | padded | bwd× | peak× | retained state |
|---|---|---|---|---|---|---|---|---|---|---|
| contrastive_train | Contrastive | Nq=32, Nb=32, Lq=32, Ld=80, D=128 | fp16 | 0.378 ms | 4.516 ms | 11.94× | — | 18.73× | 1.78× | 1/80 |
| contrastive_train | Contrastive | Nq=32, Nb=32, Lq=32, Ld=80, D=128 | bf16 | 0.378 ms | 4.512 ms | 11.95× | — | 18.66× | 1.78× | 1/80 |
| contrastive_train | LongDocs | Nq=32, Nb=32, Lq=32, Ld=512, D=128 | fp16 | 0.422 ms | 4.517 ms | 10.69× | — | 47.36× | 3.24× | 1/512 |
| contrastive_train | LongDocs | Nq=32, Nb=32, Lq=32, Ld=512, D=128 | bf16 | 0.424 ms | 4.510 ms | 10.64× | — | 48.17× | 3.24× | 1/512 |
| contrastive_train | BigBatch | Nq=64, Nb=64, Lq=32, Ld=128, D=128 | fp16 | 0.505 ms | 8.286 ms | 16.41× | — | 48.09× | 4.15× | 1/128 |
| contrastive_train | BigBatch | Nq=64, Nb=64, Lq=32, Ld=128, D=128 | bf16 | 0.515 ms | 8.260 ms | 16.04× | — | 47.61× | 4.15× | 1/128 |
| padded_infer | Rerank | B=32, K=50, Lq=32, Ld=180, D=128 | fp16 | 0.543 ms | 0.961 ms | 1.77× | — | — | 3.01× | — |
| padded_infer | Rerank | B=32, K=50, Lq=32, Ld=180, D=128 | bf16 | 0.540 ms | 0.964 ms | 1.79× | — | — | 3.01× | — |
| padded_infer | HeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | fp16 | 0.592 ms | 2.469 ms | 4.17× | — | — | 3.30× | — |
| padded_infer | HeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | bf16 | 0.576 ms | 2.475 ms | 4.30× | — | — | 3.30× | — |
| packed_infer | PackedRerank | B=32, K=50, Lq=32, Ld=180, D=128 | fp16 | 8.976 ms | 0.961 ms | 0.11× | 0.542 ms | — | 3.01× | — |
| packed_infer | PackedRerank | B=32, K=50, Lq=32, Ld=180, D=128 | bf16 | 8.970 ms | 0.964 ms | 0.11× | 0.540 ms | — | 3.01× | — |
| packed_infer | PackedHeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | fp16 | 24.709 ms | 2.468 ms | 0.10× | 0.589 ms | — | 3.30× | — |
| packed_infer | PackedHeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | bf16 | 24.718 ms | 2.473 ms | 0.10× | 0.573 ms | — | 3.30× | — |
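Packed rows compare score_pairs_packed against both the equivalent padded
kernel layout (padded column) and the PyTorch einsum baseline (PyTorch column).
On CUDA the packed path is still scalar (see Limitations and edge cases), which is why it trails both.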
Apple Silicon (Metal) benchmark matrix
MaxSim also ships a Metal backend. The table below covers the same workloads as above,
measured on the MPS build against the same naive PyTorch baseline. Peak-memory figures
aren't captured on MPS, so the peak× column shows —.
| Surface | Preset | Shape | dtype | maxsim | PyTorch | speedup | padded | bwd× | peak× | retained state |
|---|---|---|---|---|---|---|---|---|---|---|
| contrastive_train | Contrastive | Nq=32, Nb=32, Lq=32, Ld=80, D=128 | fp16 | 0.925 ms | 4.945 ms | 5.34× | — | — | — | 1/80 |
| contrastive_train | Contrastive | Nq=32, Nb=32, Lq=32, Ld=80, D=128 | bf16 | 0.863 ms | 4.738 ms | 5.49× | — | — | — | 1/80 |
| contrastive_train | LongDocs | Nq=32, Nb=32, Lq=32, Ld=512, D=128 | fp16 | 2.471 ms | 15.966 ms | 6.46× | — | 61.52× | — | 1/512 |
| contrastive_train | LongDocs | Nq=32, Nb=32, Lq=32, Ld=512, D=128 | bf16 | 2.455 ms | 15.830 ms | 6.45× | — | 46.16× | — | 1/512 |
| contrastive_train | BigBatch | Nq=64, Nb=64, Lq=32, Ld=128, D=128 | fp16 | 2.501 ms | 15.527 ms | 6.21× | — | 56.26× | — | 1/128 |
| contrastive_train | BigBatch | Nq=64, Nb=64, Lq=32, Ld=128, D=128 | bf16 | 2.485 ms | 15.485 ms | 6.23× | — | 71.39× | — | 1/128 |
| padded_infer | Rerank | B=32, K=50, Lq=32, Ld=180, D=128 | fp16 | 1.693 ms | 6.229 ms | 3.68× | — | — | — | — |
| padded_infer | Rerank | B=32, K=50, Lq=32, Ld=180, D=128 | bf16 | 1.688 ms | 5.867 ms | 3.48× | — | — | — | — |
| padded_infer | HeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | fp16 | 3.255 ms | 16.490 ms | 5.07× | — | — | — | — |
| padded_infer | HeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | bf16 | 3.311 ms | 16.440 ms | 4.96× | — | — | — | — |
| packed_infer | PackedRerank | B=32, K=50, Lq=32, Ld=180, D=128 | fp16 | 2.529 ms | 5.897 ms | 2.33× | 1.697 ms | — | — | — |
| packed_infer | PackedRerank | B=32, K=50, Lq=32, Ld=180, D=128 | bf16 | 2.566 ms | 6.085 ms | 2.37× | 1.686 ms | — | — | — |
| packed_infer | PackedHeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | fp16 | 4.335 ms | 16.524 ms | 3.81× | 3.324 ms | — | — | — |
| packed_infer | PackedHeavyRerank | B=32, K=100, Lq=32, Ld=256, D=128 | bf16 | 4.306 ms | 16.533 ms | 3.84× | 3.273 ms | — | — | — |
Reproducing
just cuda-bench-matrix a100-large 3 30 # flavor, repeats, iters
just cuda-bench-matrix l40sx1 3 30 # optional additional GPU
just cuda-bench-matrix a10g-small 3 30 # optional additional GPU
just bench-matrix-metal 3 30 # optional Apple Silicon / MPS run
just bench-render # regenerate README tables from JSON
The CUDA bench recipe submits an HF Jobs run; bench-matrix-metal runs the same
matrix against the Metal/MPS build. Both write one JSON artifact per device
class to bench_results/v2/. The artifacts capture the
kernel commit, device/torch/CUDA versions, timing and variance per workload, and a
structured shape per row. just bench-render rewrites the tables above from
those JSONs, so the tables always reflect the same data without copy-paste drift.
Supported
| Backend | Devices | Input dtypes | Backward accumulation |
|---|---|---|---|
| Metal | Apple Silicon (MPS, MSL 3.1+) | fp32 / fp16 / bf16 | fp32 |
| CUDA | sm_80, sm_86, sm_89, sm_90 (PTX JIT) | fp32 / fp16 / bf16 | fp32 |
The table is package-level support. A few API-specific constraints are listed below.
Limitations and edge cases
- CUDA contrastive APIs (score_contrastive, score_contrastive_train, and score_contrastive_with_argmax) currently use the WMMA path directly. They require fp16/bf16 inputs, Lq % 16 == 0, and dim % 16 == 0.
- Padded inference (score_candidates_padded) accepts fp32/fp16/bf16 and arbitrary Lq, Ld, and dim on both backends. CUDA uses the WMMA fast path for fp16/bf16 when Lq % 16 == 0 and dim % 16 == 0; otherwise it uses the scalar path.
- Padded training on CUDA uses the argmax WMMA path and currently requires fp16/bf16 inputs, Lq >= 16, Lq % 16 == 0, and dim % 16 == 0.
- Metal handles arbitrary Lq for padded and contrastive paths. Its simdgroup_matrix path fires when dim % 8 == 0; otherwise it uses scalar work. Very large Lq * dim shapes can hit Metal threadgroup-memory limits.
- Packed CUDA paths are scalar today: score_pairs_packed uses a scalar score-only kernel, and packed training uses scalar argmax/backward kernels rather than WMMA. As a result, packed offers no speedup over native PyTorch on CUDA; in the H200 and A100 benchmarks it runs at 0.06–0.11× the speed of the einsum baseline. Use packed for arbitrary or sparse pair grids; for fixed-K reranking, prefer the padded API.
- Backward kernels accumulate gradients with fp32 atomics. The results are numerically stable for training, but not bit-exact deterministic across repeated backward passes, because floating-point additions can arrive in a different order.
- NaN inputs do not follow PyTorch's exact torch.max NaN propagation semantics. Kernel comparisons use v > max, so NaN candidates are skipped. A NaN in a query token usually drives the whole score for that pair to -inf; a NaN in one document token may simply be ignored if another document token wins the max. Treat non-finite encoder outputs as upstream bugs and pre-screen with torch.isfinite(...) if you need strict behavior.
- Hopper (sm_90) currently runs via PTX forward compatibility. Native WGMMA tuning is future work.
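The NaN behaviour follows from how a running maximum built on v > max treats NaN: every comparison with NaN is false, so a NaN candidate never replaces the current maximum. The pure-Python sketch below mimics that comparison rule (it is not the kernel's code; the function name is hypothetical), and the last line shows why accumulating the same values in a different order, as atomic adds may, need not be bit-exact.

```python
def running_max_gt(values):
    """Running maximum using only `v > best`, starting from -inf (NaN never wins)."""
    best = float("-inf")
    for v in values:
        if v > best:
            best = v
    return best


nan = float("nan")

# NaN in one document token: it is skipped and another token wins.
print(running_max_gt([0.3, nan, 0.7]))  # 0.7

# NaN in a query token: every similarity is NaN, so that token's max stays -inf,
# and adding it to the other tokens' maxima drives the pair's score to -inf.
print(0.5 + running_max_gt([nan, nan, nan]))  # -inf

# Floating-point addition is not associative, so accumulation order can
# change the last bits of a gradient.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```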
Related projects and acknowledgements
- Flash-MaxSim is a fused Triton MaxSim kernel for ColBERT/ColPali scoring.
Source
Source: https://github.com/erikkaum/maxsim.
Kernel repo: https://huggingface.co/kernels/erikkaum/maxsim.
License: Apache-2.0.
- OS: macOS, Linux
- Arch: x86_64, aarch64






