---
tags:
- text-diffusion
- machine-translation
- en-de
- masked-diffusion
- from-scratch
language:
- en
- de
datasets:
- wmt/wmt14
license: apache-2.0
---

# Text Diffusion Model for EN→DE Translation

A **masked discrete diffusion** model for English-to-German machine translation, trained from scratch on WMT14 EN-DE.

## Architecture

| Component | Detail |
|---|---|
| **Type** | Masked Discrete Diffusion |
| **Backbone** | DiT (Diffusion Transformer) with adaLN |
| **Parameters** | ~72M |
| **Blocks** | 12 DiT blocks |
| **Hidden dim** | 512, 8 attention heads |
| **Attention** | Bidirectional (no causal mask) with RoPE |
| **Conditioning** | Timestep via sinusoidal embeddings + adaLN; segment embeddings for src/tgt |
| **Weight tying** | Input embeddings tied to output projection |
| **Tokenizer** | [Helsinki-NLP/opus-mt-en-de](https://huggingface.co/Helsinki-NLP/opus-mt-en-de) (~58K vocab) |
| **Max sequence** | 128 src + 128 tgt tokens |
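
To make the conditioning path concrete, here is a minimal sketch of a DiT-style block with adaLN timestep modulation (RoPE and the segment embeddings are omitted for brevity). The class and variable names (`DiTBlock`, `t_emb`, etc.) are illustrative and do not necessarily match those in `train.py`.

```python
import math
import torch
import torch.nn as nn

def timestep_embedding(t, dim):
    """Sinusoidal embedding of the masking rate t (shape: [batch])."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class DiTBlock(nn.Module):
    """One bidirectional transformer block with adaLN conditioning (illustrative)."""

    def __init__(self, dim=512, heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # adaLN: the timestep embedding predicts scale/shift/gate for both sub-layers
        self.ada = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))

    def forward(self, x, t_emb):
        # x: (B, L, dim), t_emb: (B, dim)
        shift1, scale1, gate1, shift2, scale2, gate2 = self.ada(t_emb).chunk(6, dim=-1)
        h = self.norm1(x) * (1 + scale1[:, None]) + shift1[:, None]
        x = x + gate1[:, None] * self.attn(h, h, h, need_weights=False)[0]  # no causal mask
        h = self.norm2(x) * (1 + scale2[:, None]) + shift2[:, None]
        x = x + gate2[:, None] * self.mlp(h)
        return x
```

The adaLN projection turns the timestep embedding into per-block scale, shift, and gate vectors, which is how the network can behave differently at high and low masking rates.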

### Inspired by
- **[MDLM](https://arxiv.org/abs/2406.07524)**: DiT backbone architecture, masked diffusion objective
- **[LLaDA](https://arxiv.org/abs/2502.09992)**: conditional generation via SFT (keep the prompt unmasked, mask only the target), 1/t ELBO weighting
- **[DiNoiSer](https://arxiv.org/abs/2302.10025)**: noise manipulation for conditional seq2seq diffusion

## How It Works

### Training (Forward Diffusion)
1. Source (EN) and target (DE) tokens are concatenated: `[source | target]`
2. A random masking rate `t ~ Uniform(0, 1)` is sampled per example
3. Each target token is independently masked with probability `t`
4. The bidirectional DiT predicts all masked tokens simultaneously
5. Loss = cross-entropy on masked positions only, weighted by `1/t` (continuous-time ELBO); a sketch of this computation follows below
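
The steps above map to a small amount of code. Below is a minimal sketch of the masking and 1/t-weighted loss, assuming a `model(noisy_ids, t)` call that returns per-token logits; `mask_id`, `tgt_mask`, and the final reduction are illustrative choices, not necessarily the exact ones in `train.py`.

```python
import torch
import torch.nn.functional as F

def masked_diffusion_loss(model, input_ids, tgt_mask, mask_id, eps=1e-3):
    """Forward diffusion + 1/t-weighted cross-entropy on masked target positions.

    input_ids: (B, L) concatenated [source | target] token IDs
    tgt_mask:  (B, L) bool, True on target positions (source stays clean)
    """
    B, L = input_ids.shape
    # 1) sample a masking rate per example (avoid t == 0 so the 1/t weight is finite)
    t = torch.rand(B, device=input_ids.device) * (1 - eps) + eps              # (B,)
    # 2) independently mask each target token with probability t
    masked = (torch.rand(B, L, device=input_ids.device) < t[:, None]) & tgt_mask
    noisy = torch.where(masked, torch.full_like(input_ids, mask_id), input_ids)
    # 3) the bidirectional model predicts every position at once
    logits = model(noisy, t)                                                  # (B, L, V)
    # 4) cross-entropy on masked positions only, weighted by 1/t
    ce = F.cross_entropy(logits.transpose(1, 2), input_ids, reduction="none") # (B, L)
    loss = (ce * masked / t[:, None]).sum() / masked.sum().clamp(min=1)
    return loss
```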

### Inference (Reverse Diffusion)
1. Start with source tokens + fully masked target: `[source | MASK MASK ... MASK]`
2. Over 50 denoising steps, iteratively predict and unmask tokens
3. At each step `t → s`: predict all masked tokens; each still-masked position stays masked with probability `s/t`, otherwise it is filled with its prediction
4. Final step: all remaining masks are filled with predictions (a simplified sampler is sketched below)
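
Under the same assumptions as the training sketch (an illustrative `model(ids, t)` interface and a `mask_id` token), a simplified greedy version of this sampler looks roughly like:

```python
import torch

@torch.no_grad()
def translate_diffusion(model, src_ids, tgt_len, mask_id, num_steps=50):
    """Reverse diffusion: start fully masked, progressively commit predictions."""
    B, src_len = src_ids.shape
    tgt = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=src_ids.device)
    for i in range(num_steps):
        t = 1.0 - i / num_steps          # current masking level
        s = 1.0 - (i + 1) / num_steps    # next (lower) masking level
        seq = torch.cat([src_ids, tgt], dim=1)
        logits = model(seq, torch.full((B,), t, device=src_ids.device))
        pred = logits[:, src_len:].argmax(dim=-1)        # greedy prediction per target slot
        still_masked = tgt == mask_id
        if s > 0:
            # each masked position stays masked with probability s/t,
            # otherwise it is filled with its prediction
            stay = torch.rand(B, tgt_len, device=src_ids.device) < (s / t)
            tgt = torch.where(still_masked & ~stay, pred, tgt)
        else:
            tgt = torch.where(still_masked, pred, tgt)   # final step: fill everything left
    return tgt
```

The actual `generate` in `train.py` may sample rather than take the argmax; the structure of the loop is the point here.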

## Training Details

| Setting | Value |
|---|---|
| **Dataset** | WMT14 EN-DE (~4.5M parallel sentence pairs) |
| **Optimizer** | AdamW (lr=3e-4, β₁=0.9, β₂=0.98, wd=0.01) |
| **Schedule** | Cosine decay with 4K-step linear warmup |
| **Effective batch size** | 256 (batch size 64 × 4 gradient accumulation steps) |
| **Max steps** | 200,000 |
| **Mixed precision** | FP16 |
| **Gradient clipping** | max_norm=1.0 |
| **Evaluation** | SacreBLEU on the WMT14 test set every 20K steps |
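
For reference, the optimizer and schedule rows correspond to something like the following sketch. The hyperparameters come from the table above; the function itself is illustrative, not lifted from `train.py`.

```python
import math
import torch

def make_optimizer_and_scheduler(model, max_steps=200_000, warmup=4_000):
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.98), weight_decay=0.01)

    def lr_lambda(step):
        if step < warmup:
            return step / max(1, warmup)                       # linear warmup
        progress = (step - warmup) / max(1, max_steps - warmup)
        return 0.5 * (1 + math.cos(math.pi * progress))        # cosine decay to 0

    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    return opt, sched
```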

## Quick Start

### Install dependencies

```bash
pip install torch transformers datasets trackio sacrebleu sacremoses sentencepiece protobuf
```

### Train

```bash
git clone https://huggingface.co/vedkdev/text-diffusion-en-de
cd text-diffusion-en-de
python train.py
```

The script will:
- Download WMT14 EN-DE automatically
- Train for 200K steps with logging via [Trackio](https://huggingface.co/docs/trackio)
- Evaluate SacreBLEU periodically
- Push checkpoints to this repo

### Adjusting for your hardware

Edit the `TRAIN_CONFIG` dict in `train.py`:

| GPU VRAM | Recommended `batch_size` | `gradient_accumulation_steps` |
|---|---|---|
| 24GB (A10G/3090/4090) | 64 | 4 |
| 16GB (T4/V100) | 32 | 8 |
| 12GB (3060) | 16 | 16 |
| 8GB (3070) | 8 | 32 |
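
For example, on a 16GB card the relevant entries would look roughly like this (assuming the dict uses keys named after the table columns; check `train.py` for the exact key names):

```python
# in train.py (key names assumed from the table above)
TRAIN_CONFIG = {
    # ...other settings unchanged...
    "batch_size": 32,                   # per-step batch size that fits in 16GB
    "gradient_accumulation_steps": 8,   # keeps the effective batch size at 256
}
```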

### Inference (after training)

```python
import torch, json
from train import DiffusionTranslator, DiffusionTranslatorConfig, generate
from transformers import AutoTokenizer

# Load checkpoint
config = DiffusionTranslatorConfig(**json.load(open("checkpoints/best/config.json")))
model = DiffusionTranslator(config)
model.load_state_dict(torch.load("checkpoints/best/model.pt", map_location="cpu"))
model.eval()

tokenizer = AutoTokenizer.from_pretrained("checkpoints/best/")

# Translate
text = "The weather is nice today."
src = tokenizer(f"translate English to German: {text}",
                max_length=128, truncation=True, padding="max_length",
                return_tensors="pt")

gen_ids = generate(model, src["input_ids"], torch.zeros_like(src["input_ids"]),
                   config, num_steps=50, device="cpu")
print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))
```

## Expected Results

Based on published literature for similar architectures on WMT14 EN→DE:

| Model | BLEU | Reference |
|---|---|---|
| Autoregressive Transformer | ~27 | Vaswani et al. 2017 |
| DiNoiSer (continuous diffusion) | 24.6 | Ye et al. 2023 |
| SeqDiffuSeq | 19.8 | Yuan et al. 2022 |
| E2D2 (discrete diffusion) | 24.8 | Kuleshov et al. 2024 |
| **This model (target)** | **15-20** | ~72M params, no KD |

> Note: Text diffusion models typically score 2-5 BLEU below autoregressive transformers of similar size. Knowledge distillation (KD) from an AR teacher can close the gap by ~1-2 BLEU.

## Citation

If you use this model, please cite the foundational papers:

```bibtex
@article{sahoo2024mdlm,
  title={Simple and Effective Masked Diffusion Language Models},
  author={Sahoo, Subham Sekhar and Arriola, Marianne and Schiff, Yair and Gokaslan, Aaron and Marroquin, Edgar and Kuleshov, Volodymyr},
  journal={NeurIPS},
  year={2024}
}

@article{nie2025llada,
  title={Large Language Diffusion Models},
  author={Nie, Shen and Zhu, Fengqi and You, Chao and Zhang, Xiaojun and Ou, Zhenguo and Zhu, Jun},
  journal={arXiv preprint arXiv:2502.09992},
  year={2025}
}

@article{ye2023dinoiser,
  title={DiNoiSer: Diffused Conditional Sequence Learning by Manipulating Noises},
  author={Ye, Jiasheng and Zheng, Zaixiang and Bao, Yu and Qian, Lihua and Gu, Quanquan},
  journal={ACL},
  year={2023}
}
```