---
tags:
- text-diffusion
- machine-translation
- en-de
- masked-diffusion
- from-scratch
language:
- en
- de
datasets:
- wmt/wmt14
license: apache-2.0
---
# Text Diffusion Model for EN→DE Translation
A **masked discrete diffusion** model for English-to-German machine translation, trained from scratch on WMT14 EN-DE.
## Architecture
| Component | Detail |
|---|---|
| **Type** | Masked Discrete Diffusion |
| **Backbone** | DiT (Diffusion Transformer) with adaLN |
| **Parameters** | ~72M |
| **Blocks** | 12 DiT blocks |
| **Hidden dim** | 512, 8 attention heads |
| **Attention** | Bidirectional (no causal mask) with RoPE |
| **Conditioning** | Timestep via sinusoidal embeddings + adaLN; Segment embeddings for src/tgt |
| **Weight tying** | Input embeddings tied to output projection |
| **Tokenizer** | [Helsinki-NLP/opus-mt-en-de](https://huggingface.co/Helsinki-NLP/opus-mt-en-de) (~58K vocab) |
| **Max sequence** | 128 src + 128 tgt tokens |
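The adaLN conditioning listed above can be pictured with a short sketch. This is illustrative only (RoPE and segment embeddings are omitted, and the names and shapes are assumptions rather than the exact code in `train.py`):

```python
import torch
import torch.nn as nn

class AdaLNBlock(nn.Module):
    """One DiT-style block: attention + MLP, both modulated by the timestep embedding."""
    def __init__(self, dim=512, heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)  # bidirectional: no causal mask
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.ada = nn.Linear(dim, 6 * dim)  # adaLN: timestep embedding -> shift/scale/gate per sub-layer

    def forward(self, x, t_emb):  # x: (B, L, dim), t_emb: (B, dim)
        sh1, sc1, g1, sh2, sc2, g2 = (p.unsqueeze(1) for p in self.ada(t_emb).chunk(6, dim=-1))
        h = self.norm1(x) * (1 + sc1) + sh1
        x = x + g1 * self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + sc2) + sh2
        return x + g2 * self.mlp(h)
```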
### Inspired by
- **[MDLM](https://arxiv.org/abs/2406.07524)** — DiT backbone architecture, masked diffusion objective
- **[LLaDA](https://arxiv.org/abs/2502.09992)** — Conditional generation via SFT (keep prompt unmasked, mask only target), 1/t ELBO weighting
- **[DiNoiSer](https://arxiv.org/abs/2302.10025)** — Noise manipulation for conditional seq2seq diffusion
## How It Works
### Training (Forward Diffusion)
1. Source (EN) and target (DE) tokens are concatenated: `[source | target]`
2. A random masking rate `t ~ Uniform(0, 1)` is sampled per example
3. Each target token is independently masked with probability `t`
4. The bidirectional DiT predicts all masked tokens simultaneously
5. Loss = cross-entropy on masked positions only, weighted by `1/t` (continuous-time ELBO)
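A minimal sketch of this training step, assuming a `model(input_ids, t)` signature for the timestep conditioning (names here are illustrative, not the exact code in `train.py`):

```python
import torch
import torch.nn.functional as F

def diffusion_loss(model, src_ids, tgt_ids, mask_token_id):
    B, T = tgt_ids.shape
    # t ~ Uniform(0, 1), one masking rate per example, clamped away from 0 for the 1/t weight
    t = torch.rand(B, 1, device=tgt_ids.device).clamp(min=1e-3)
    is_masked = torch.rand(B, T, device=tgt_ids.device) < t  # mask each target token w.p. t
    noisy_tgt = torch.where(is_masked, torch.full_like(tgt_ids, mask_token_id), tgt_ids)
    # bidirectional DiT over [source | noisy target]
    logits = model(torch.cat([src_ids, noisy_tgt], dim=1), t.squeeze(-1))
    tgt_logits = logits[:, src_ids.size(1):]                 # scores for target positions only
    ce = F.cross_entropy(tgt_logits.transpose(1, 2), tgt_ids, reduction="none")  # (B, T)
    # cross-entropy on masked positions only, with the continuous-time 1/t ELBO weight
    return (ce * is_masked / t).sum() / is_masked.sum().clamp(min=1)
```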
### Inference (Reverse Diffusion)
1. Start with source tokens + fully masked target: `[source | MASK MASK ... MASK]`
2. Over 50 denoising steps, iteratively predict and unmask tokens
3. At each step `t → s`: predict all masked tokens, then re-mask each fresh prediction with probability `s/t`, so the expected mask rate drops from `t` to `s`
4. Final step: all remaining masks are filled with predictions
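The same sketch style for sampling (same assumed `model(input_ids, t)` signature; the real loop lives in `generate` in `train.py`):

```python
import torch

@torch.no_grad()
def iterative_unmask(model, src_ids, tgt_len, mask_token_id, num_steps=50):
    B, device = src_ids.size(0), src_ids.device
    tgt = torch.full((B, tgt_len), mask_token_id, dtype=torch.long, device=device)
    for i in range(num_steps):
        t = 1.0 - i / num_steps          # current mask level
        s = 1.0 - (i + 1) / num_steps    # next (lower) mask level
        logits = model(torch.cat([src_ids, tgt], dim=1), torch.full((B,), t, device=device))
        pred = logits[:, src_ids.size(1):].argmax(dim=-1)
        still_masked = tgt == mask_token_id
        tgt = torch.where(still_masked, pred, tgt)   # fill every masked slot ...
        if s > 0:                                    # ... then re-mask each w.p. s/t
            remask = still_masked & (torch.rand(B, tgt_len, device=device) < s / t)
            tgt[remask] = mask_token_id
    return tgt  # at s = 0 nothing is re-masked, so the output is fully unmasked
```

Because every masked position is predicted in parallel at each step, decoding costs `num_steps` forward passes regardless of target length.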
## Training Details
| Setting | Value |
|---|---|
| **Dataset** | WMT14 EN-DE (~4.5M parallel sentence pairs) |
| **Optimizer** | AdamW (lr=3e-4, β₁=0.9, β₂=0.98, wd=0.01) |
| **Schedule** | Cosine with 4K linear warmup |
| **Effective batch size** | 256 (64 × 4 gradient accumulation) |
| **Max steps** | 200,000 |
| **Mixed precision** | FP16 |
| **Gradient clipping** | max_norm=1.0 |
| **Evaluation** | SacreBLEU on WMT14 test set every 20K steps |
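The optimizer rows translate to roughly the following, assuming `model` is the instantiated `DiffusionTranslator` (a sketch, not the exact code in `train.py`):

```python
import math
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.98), weight_decay=0.01)

def lr_lambda(step, warmup=4_000, max_steps=200_000):
    """Linear warmup to the peak LR, then cosine decay to zero."""
    if step < warmup:
        return step / warmup
    progress = (step - warmup) / (max_steps - warmup)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# per step: loss.backward(); torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0);
# optimizer.step(); scheduler.step()
```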
## Quick Start
### Install dependencies
```bash
pip install torch transformers datasets trackio sacrebleu sacremoses sentencepiece protobuf
```
### Train
```bash
git clone https://huggingface.co/vedkdev/text-diffusion-en-de
cd text-diffusion-en-de
python train.py
```
The script will:
- Download WMT14 EN-DE automatically
- Train for 200K steps with logging via [Trackio](https://huggingface.co/docs/trackio)
- Evaluate SacreBLEU periodically
- Push checkpoints to this repo
### Adjusting for your hardware
Edit the `TRAIN_CONFIG` dict in `train.py`:
| GPU VRAM | Recommended `batch_size` | `gradient_accumulation_steps` |
|---|---|---|
| 24GB (A10G/3090/4090) | 64 | 4 |
| 16GB (T4/V100) | 32 | 8 |
| 12GB (3060) | 16 | 16 |
| 8GB (3070) | 8 | 32 |
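For example, to fit a 16GB T4 while keeping the effective batch size at 256 (assuming `TRAIN_CONFIG` uses the key names shown in the table):

```python
# in train.py: 32 × 8 = 256, same effective batch size as the default 64 × 4
TRAIN_CONFIG["batch_size"] = 32
TRAIN_CONFIG["gradient_accumulation_steps"] = 8
```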
### Inference (after training)
```python
import torch, json
from train import DiffusionTranslator, DiffusionTranslatorConfig, generate
from transformers import AutoTokenizer
# Load checkpoint
config = DiffusionTranslatorConfig(**json.load(open("checkpoints/best/config.json")))
model = DiffusionTranslator(config)
model.load_state_dict(torch.load("checkpoints/best/model.pt", map_location="cpu"))
model.eval()
tokenizer = AutoTokenizer.from_pretrained("checkpoints/best/")
# Translate
text = "The weather is nice today."
src = tokenizer(f"translate English to German: {text}",
                max_length=128, truncation=True, padding="max_length",
                return_tensors="pt")
gen_ids = generate(model, src["input_ids"], torch.zeros_like(src["input_ids"]),
                   config, num_steps=50, device="cpu")
print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))
```
## Expected Results
Based on published literature for similar architectures on WMT14 EN→DE:
| Model | BLEU | Reference |
|---|---|---|
| Autoregressive Transformer | ~27 | Vaswani et al. 2017 |
| DiNoiSer (continuous diffusion) | 24.6 | Ye et al. 2023 |
| SeqDiffuSeq | 19.8 | Yuan et al. 2022 |
| E2D2 (discrete diffusion) | 24.8 | Kuleshov et al. 2024 |
| **This model (target)** | **15-20** | ~72M params, no KD |
> Note: Text diffusion models typically score 2-5 BLEU below autoregressive transformers of similar size. Knowledge distillation (KD) from an AR teacher can close the gap by ~1-2 BLEU.
## Citation
If you use this model, please cite the foundational papers:
```bibtex
@article{sahoo2024mdlm,
title={Simple and Effective Masked Diffusion Language Models},
  author={Sahoo, Subham Sekhar and Arriola, Marianne and Schiff, Yair and Gokaslan, Aaron and Marroquin, Edgar and Chiu, Justin T and Rush, Alexander and Kuleshov, Volodymyr},
journal={NeurIPS},
year={2024}
}
@article{nie2025llada,
title={Large Language Diffusion Models},
  author={Nie, Shen and Zhu, Fengqi and You, Zebin and Zhang, Xiaolu and Ou, Jingyang and Hu, Jun and Zhou, Jun and Lin, Yankai and Wen, Ji-Rong and Li, Chongxuan},
journal={arXiv preprint arXiv:2502.09992},
year={2025}
}
@article{ye2023dinoiser,
title={DiNoiSer: Diffused Conditional Sequence Learning by Manipulating Noises},
  author={Ye, Jiasheng and Zheng, Zaixiang and Bao, Yu and Qian, Lihua and Wang, Mingxuan},
journal={ACL},
year={2023}
}
```