SpTransformer

Transformer network for predicting tissue-specific splicing from pre-mRNA sequences.

Disclaimer

This is an UNOFFICIAL implementation of "SpliceTransformer predicts tissue-specific splicing linked to human diseases" by Ningyuan You, Chang Liu, Yuxin Gu, et al. and Ning Shen.

The OFFICIAL repository of SpliceTransformer (SpTransformer) is at ShenLab-Genomics/SpliceTransformer.

The MultiMolecule team has confirmed that the provided model and checkpoints produce the same intermediate representations as the original implementation.

The team that released SpTransformer did not write a model card for it, so this model card was written by the MultiMolecule team.

Model Details

SpTransformer (SpliceTransformer) is a deep neural network that predicts tissue-specific splicing from primary pre-mRNA sequence. It combines two pretrained SpliceAI-style dilated-residual convolutional feature extractors with a trainable input-projection path; the concatenated features are processed by a Sinkhorn transformer attention block with axial positional embeddings. For each position the network predicts a 3-channel splice-site score (no-splice / acceptor / donor) and a per-position splice-site usage score across 15 human tissues. The model uses a fixed flanking context of 4,000 nucleotides on each side of every predicted position. SpTransformer is typically used to estimate the effect of genetic variants on tissue-specific splicing by scoring reference and alternate sequences and taking the difference. Please refer to the Training Details section for more information on the training process.
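The variant-effect recipe mentioned above (score the reference and alternate sequences, then take the difference) can be sketched in plain Python. This is only an illustration: the helper names `delta_scores` and `max_abs_delta` are hypothetical, the score tables are made-up numbers rather than model output, and summarising by the largest absolute change is an assumption, not necessarily the aggregation used by the SpliceTransformer authors. It also assumes a substitution, so both tables have the same length.

```python
def delta_scores(ref_scores, alt_scores):
    """Return alt - ref for every position and channel.

    Both inputs are lists of per-position score rows with identical shape
    (for example, one row per nucleotide and one column per channel).
    """
    if len(ref_scores) != len(alt_scores):
        raise ValueError("reference and alternate score tables differ in length")
    return [
        [alt - ref for ref, alt in zip(ref_row, alt_row)]
        for ref_row, alt_row in zip(ref_scores, alt_scores)
    ]


def max_abs_delta(delta):
    """Summarise a delta table by its largest absolute change (an assumed summary)."""
    return max(abs(value) for row in delta for value in row)


# Made-up scores for two positions and three channels (no-splice / acceptor / donor).
ref = [[0.9, 0.05, 0.05], [0.1, 0.8, 0.1]]
alt = [[0.9, 0.05, 0.05], [0.6, 0.3, 0.1]]

delta = delta_scores(ref, alt)
summary = max_abs_delta(delta)  # about 0.5: the acceptor score at position 1 drops
```

In practice the two tables would come from running the model on the reference and alternate sequences; the same subtraction applies to the per-tissue usage channels.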

Model Specification

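| Num Layers | Hidden Size | Num Heads | Intermediate Size | Max Seq Len | Num Parameters (M) | FLOPs (G) | MACs (G) | Context |
| ---------- | ----------- | --------- | ----------------- | ----------- | ------------------ | --------- | -------- | ------- |
| 8          | 256         | 8         | 1024              | 8192        | 17.07              | 290.72    | 144.65   | 4000    |
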
  • Num Layers / Hidden Size / Num Heads / Intermediate Size / Max Seq Len describe the Sinkhorn transformer attention block.
  • The two SpliceAI-style feature extractors use hidden sizes 128 and 64; Num Parameters counts the full checkpoint.
  • Context is the fixed flanking context (in nucleotides) consumed on each side of every predicted position.
  • FLOPs and MACs are measured on a 100-nucleotide input.

Links

Usage

The model file depends on the multimolecule library. You can install it using pip:

pip install multimolecule

Direct Use

RNA Splicing Site Prediction

You can use this model directly to predict per-nucleotide tissue-specific splicing of a pre-mRNA sequence:

>>> from multimolecule import DnaTokenizer, SpTransformerModel

>>> tokenizer = DnaTokenizer.from_pretrained("multimolecule/sptransformer")
>>> model = SpTransformerModel.from_pretrained("multimolecule/sptransformer")
>>> output = model(tokenizer("AGCAGTCATTATGGCGAA", return_tensors="pt")["input_ids"])

>>> output.keys()
odict_keys(['last_hidden_state', 'logits'])

The logits tensor reproduces the original SpTransformer output: a 3-channel splice-site score (no-splice / acceptor / donor) and a per-tissue (15 tissues) splice-site usage score for each position.
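As a rough illustration of how the two groups of scores might be separated, the sketch below splits each per-position row into splice-site and tissue-usage parts. It assumes that every position carries 18 values and that the 3 splice-site channels come before the 15 tissue channels. That ordering is an assumption made here for illustration, so check it against the actual `logits` shape before relying on it. The helper `split_scores` is hypothetical and works on plain nested lists rather than tensors.

```python
NUM_SITE_CHANNELS = 3   # no-splice / acceptor / donor
NUM_TISSUES = 15        # per-tissue splice-site usage


def split_scores(rows):
    """Split per-position rows into (splice-site scores, tissue-usage scores).

    Assumes each row holds NUM_SITE_CHANNELS values followed by NUM_TISSUES
    values; this channel order is an assumption, not confirmed by the card.
    """
    site, usage = [], []
    for row in rows:
        if len(row) != NUM_SITE_CHANNELS + NUM_TISSUES:
            raise ValueError(f"expected 18 channels per position, got {len(row)}")
        site.append(list(row[:NUM_SITE_CHANNELS]))
        usage.append(list(row[NUM_SITE_CHANNELS:]))
    return site, usage


# Placeholder rows standing in for two positions of model output.
rows = [[float(channel) for channel in range(18)] for _ in range(2)]
site_scores, usage_scores = split_scores(rows)
```

With a real output, the rows would come from something like `output.logits[0].tolist()`.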

Downstream Use

Token Prediction

You can fine-tune SpTransformer for per-nucleotide tissue-specific splicing regression with [SpTransformerForTokenPrediction][multimolecule.models.SpTransformerForTokenPrediction], which adds a shared token prediction head on top of the backbone.

Interpretability: Faithful Sparse-Attention Exposure

SpTransformer's attention block does not compute dense self-attention. Each layer ([SpTransformerSelfAttention][multimolecule.models.sptransformer.modeling_sptransformer.SpTransformerSelfAttention]) splits its heads into two groups with fundamentally different sparse-attention structures:

  • Windowed-local heads: each window of bucket_size tokens attends only to itself plus the immediately preceding and following window (a look_backward=1, look_forward=1 look-around). Boundary positions are masked.
  • Sinkhorn sorted-bucket heads: each query bucket attends to the concatenation of (a) one sorted / reordered key bucket selected by a parameter-free attention-sort net (differentiable_topk(R, k=1)) and (b) its own local bucket.
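To make the windowed look-around concrete, the following sketch lists, for each window, the absolute positions its keys would come from. It is a hand-written illustration and not the library's code: the function name `look_around_positions` is hypothetical, the column order (previous window, own window, next window) is an assumption, and -1 is used as a placeholder for out-of-range (masked) columns.

```python
def look_around_positions(sequence_length, bucket_size, look_backward=1, look_forward=1):
    """For each window, list the absolute key positions it can attend to.

    Columns are laid out as [backward windows | own window | forward windows]
    (an assumed order); positions outside the sequence are marked with -1.
    """
    if sequence_length % bucket_size != 0:
        raise ValueError("sequence_length must be a multiple of bucket_size")
    num_windows = sequence_length // bucket_size
    table = []
    for window in range(num_windows):
        row = []
        for offset in range(-look_backward, look_forward + 1):
            source = window + offset
            for j in range(bucket_size):
                if 0 <= source < num_windows:
                    row.append(source * bucket_size + j)
                else:
                    row.append(-1)  # boundary column, masked out
        table.append(row)
    return table


positions = look_around_positions(sequence_length=16, bucket_size=4)
# positions[0] -> [-1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7]
# positions[3] -> [8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1]
```

Each window sees (look_backward + 1 + look_forward) * bucket_size key columns, which is 12 for the toy sizes used here.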

Because these two patterns operate on different key axes, there is no single dense (batch, heads, sequence, sequence) tensor that faithfully represents the computation. Materialising a zero-filled sequence x sequence grid would be a misleading interpretability artifact, so this model does not expose one.

Instead, attention recording is opt-in and faithful. Passing output_attentions=True (or setting config.output_attentions=True) returns, for every attention layer, a [SpTransformerAttentionMap][multimolecule.models.SpTransformerAttentionMap] holding the actual softmax weights used in the forward pass plus the indexing/permutation needed to map them back to absolute sequence positions:

  • local_attentions (B, num_local_heads, num_windows, W, (look_backward + 1 + look_forward) * W): the real per-window softmax weights; padded look-around columns carry weight 0.
  • local_key_positions (num_windows, (look_backward + 1 + look_forward) * W): absolute source position of every local key-axis column (-1 marks padded columns).
  • sinkhorn_attentions (B, num_sinkhorn_heads, num_buckets, W, 2 * W): the real per-bucket softmax weights over the [reordered-bucket | own-bucket] key axis.
  • sinkhorn_reorder (B, num_sinkhorn_heads, num_buckets, num_buckets): the exact bucket-permutation matrix; for query bucket u, the nonzero column v of row u says the reordered key bucket (columns 0:W of sinkhorn_attentions) is source bucket v (absolute positions v*W : v*W + W).
  • scalar metadata: bucket_size, look_backward, look_forward, num_local_heads, num_sinkhorn_heads, sequence_length.

W is bucket_size; local heads come first along the head axis, Sinkhorn heads second. These are structured block weights, not dense attention matrices; re-deriving the per-type attention output by contracting these exact weights with the (block-gathered) values reproduces the layer output exactly. Recording is opt-in, so the default forward path and its numerics are byte-for-byte unchanged.

>>> import torch
>>> from multimolecule import SpTransformerConfig, SpTransformerModel
>>> config = SpTransformerConfig(bucket_size=4, max_seq_len=16, context=2, num_hidden_layers=2)
>>> model = SpTransformerModel(config).eval()
>>> output = model(torch.randint(config.vocab_size, (1, 16)), output_attentions=True)
>>> layer0 = output.attentions[0]
>>> layer0.local_attentions.shape
torch.Size([1, 2, 4, 4, 12])
>>> layer0.sinkhorn_attentions.shape
torch.Size([1, 6, 4, 4, 8])
>>> layer0.sinkhorn_reorder.shape
torch.Size([1, 6, 4, 4])
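The bucket permutation can be turned back into absolute key positions with a few lines of plain Python. The sketch below takes the reorder matrix for a single batch element and head as a nested list and returns, for every query bucket, the positions behind the 2 * W key columns. The helper name `sinkhorn_key_positions` and the example matrix are hypothetical, and the own-bucket half is assumed to cover positions u*W : u*W + W by analogy with the reordered half.

```python
def sinkhorn_key_positions(reorder, bucket_size):
    """Map each query bucket's 2 * W key columns to absolute positions.

    `reorder` is a num_buckets x num_buckets matrix (nested lists) with one
    nonzero entry per row, as described for sinkhorn_reorder. Columns 0:W
    are the reordered source bucket, columns W:2W the query's own bucket.
    """
    table = []
    for u, row in enumerate(reorder):
        nonzero = [v for v, weight in enumerate(row) if weight != 0]
        if len(nonzero) != 1:
            raise ValueError(f"row {u} should select exactly one source bucket")
        v = nonzero[0]
        reordered = list(range(v * bucket_size, (v + 1) * bucket_size))
        own = list(range(u * bucket_size, (u + 1) * bucket_size))  # assumed layout
        table.append(reordered + own)
    return table


# Made-up permutation: query buckets 0..3 select source buckets 2, 0, 3, 1.
reorder = [
    [0, 0, 1, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 1],
    [0, 1, 0, 0],
]
key_positions = sinkhorn_key_positions(reorder, bucket_size=4)
# key_positions[0] -> [8, 9, 10, 11, 0, 1, 2, 3]
```

With real output, the matrix for one batch element and head could be obtained with something like `layer0.sinkhorn_reorder[0, 0].tolist()`.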

Training Details

SpTransformer was trained to predict tissue-specific splicing from primary pre-mRNA sequence.

Training Data

SpTransformer was trained on splicing measurements derived from RNA-seq data across 15 human tissues, using gene annotations from GENCODE, together with multi-species sequence data. The two convolutional feature extractors were pre-trained as SpliceAI-style splice-site predictors; MultiMolecule exposes them as trainable submodules for downstream fine-tuning. For each predicted nucleotide, a sequence window centered on that nucleotide was used, with the flanking context padded with N (unknown nucleotide) when near transcript ends.
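The windowing with N padding described above can be sketched as follows. This is an illustrative helper, not the preprocessing code used by the authors or by MultiMolecule: the name `context_window` is hypothetical, and it treats the sequence as a plain string with a 0-based position.

```python
def context_window(sequence, position, context=4000):
    """Return the window centered on `position`, padded with N at the ends.

    The result always has length 2 * context + 1: `context` nucleotides on
    each side of the target position, with N filling whatever falls outside
    the sequence.
    """
    if not 0 <= position < len(sequence):
        raise IndexError("position is outside the sequence")
    start = position - context
    end = position + context + 1
    left_pad = max(0, -start)
    right_pad = max(0, end - len(sequence))
    core = sequence[max(0, start):min(len(sequence), end)]
    return "N" * left_pad + core + "N" * right_pad


# Toy example with a 3-nucleotide context instead of 4,000.
window_start = context_window("ACGTACGT", position=1, context=3)  # "NNACGTA"
window_end = context_window("ACGTACGT", position=7, context=3)    # "ACGTNNN"
```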

Training Procedure

Pre-training

The model was trained to minimize a combination of cross-entropy loss over splice-site classification and a regression loss over per-tissue splice-site usage, comparing predictions against measurements derived from RNA-seq.
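A combined objective of this kind can be sketched for a single position as below. The card does not state the exact regression loss or how the two terms are weighted, so mean squared error and a weight of 1.0 are assumptions chosen purely for illustration, and the function names are hypothetical.

```python
import math


def cross_entropy(probabilities, target_index):
    """Negative log-likelihood of the true class under predicted probabilities."""
    return -math.log(probabilities[target_index])


def mean_squared_error(predicted, target):
    """Assumed regression loss; the loss actually used may differ."""
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(predicted)


def combined_loss(site_probabilities, site_label, usage_predicted, usage_target,
                  usage_weight=1.0):
    """Classification loss plus weighted regression loss for one position.

    `usage_weight` is a hypothetical knob; the real weighting is not given.
    """
    classification = cross_entropy(site_probabilities, site_label)
    regression = mean_squared_error(usage_predicted, usage_target)
    return classification + usage_weight * regression


# One position: true class is index 0, two tissues shown for brevity.
loss = combined_loss([0.5, 0.25, 0.25], 0, [0.5, 0.5], [0.5, 0.0])
# loss equals log(2) + 0.125
```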

Citation

@article{You2024,
  author    = {You, Ningyuan and Liu, Chang and Gu, Yuxin and Wang, Rong and Jia, Hanying and Zhang, Tianyun and Jiang, Song and Shi, Jinsong and Chen, Ming and Guan, Min-Xin and Sun, Siqi and Pei, Shanshan and Liu, Zhihong and Shen, Ning},
  title     = {{SpliceTransformer predicts tissue-specific splicing linked to human diseases}},
  journal   = {Nature Communications},
  year      = {2024},
  volume    = {15},
  number    = {1},
  pages     = {9129},
  month     = {oct},
  doi       = {10.1038/s41467-024-53088-6},
  issn      = {2041-1723},
  url       = {https://doi.org/10.1038/s41467-024-53088-6}
}

The artifacts distributed in this repository are part of the MultiMolecule project. If you use MultiMolecule in your research, you must cite the MultiMolecule project as follows:

@software{chen_2024_12638419,
  author    = {Chen, Zhiyuan and Zhu, Sophia Y.},
  title     = {MultiMolecule},
  doi       = {10.5281/zenodo.12638419},
  publisher = {Zenodo},
  url       = {https://doi.org/10.5281/zenodo.12638419},
  year      = 2024,
  month     = may,
  day       = 4
}

Contact

Please use GitHub issues of MultiMolecule for any questions or comments on the model card.

Please contact the authors of the SpliceTransformer paper for questions or comments on the paper/model.

License

This model implementation is licensed under the GNU Affero General Public License.

For additional terms and clarifications, please refer to our License FAQ.

SPDX-License-Identifier: AGPL-3.0-or-later