---
tags:
- text-diffusion
- machine-translation
- en-de
- masked-diffusion
- from-scratch
language:
- en
- de
datasets:
- wmt/wmt14
license: apache-2.0
---
# Text Diffusion Model for EN→DE Translation
A **masked discrete diffusion** model for English-to-German machine translation, trained from scratch on WMT14 EN-DE.
## Architecture
| Component | Detail |
|---|---|
| **Type** | Masked Discrete Diffusion |
| **Backbone** | DiT (Diffusion Transformer) with adaLN |
| **Parameters** | ~72M |
| **Blocks** | 12 DiT blocks |
| **Hidden dim** | 512, 8 attention heads |
| **Attention** | Bidirectional (no causal mask) with RoPE |
| **Conditioning** | Timestep via sinusoidal embeddings + adaLN; segment embeddings distinguish src/tgt |
| **Weight tying** | Input embeddings tied to output projection |
| **Tokenizer** | [Helsinki-NLP/opus-mt-en-de](https://huggingface.co/Helsinki-NLP/opus-mt-en-de) (~58K vocab) |
| **Max sequence** | 128 src + 128 tgt tokens |
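The timestep conditioning in the table can be illustrated with the standard sinusoidal embedding used by DiT-style models. This is a minimal stdlib sketch; `max_period=10000` and the cosines-first layout are assumptions borrowed from common DiT/MDLM code, not read from `train.py`, and whether `train.py` rescales `t` before embedding it is not stated in this card.

```python
import math


def timestep_embedding(t: float, dim: int, max_period: float = 10000.0) -> list[float]:
    """Sinusoidal embedding of a scalar diffusion timestep, laid out as [cos | sin]."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to roughly 1/max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # zero-pad odd dimensions
    return emb


emb = timestep_embedding(0.5, 512)  # e.g. masking rate t = 0.5, hidden dim 512
print(len(emb))  # 512
```

In DiT-style blocks this vector is typically passed through a small MLP whose output supplies the per-block adaLN shift, scale, and gate parameters.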
### Inspired by
- **[MDLM](https://arxiv.org/abs/2406.07524)** — DiT backbone architecture, masked diffusion objective
- **[LLaDA](https://arxiv.org/abs/2502.09992)** — Conditional generation via SFT (keep prompt unmasked, mask only target), 1/t ELBO weighting
- **[DiNoiSer](https://arxiv.org/abs/2302.10025)** — Noise manipulation for conditional seq2seq diffusion
## How It Works
### Training (Forward Diffusion)
1. Source (EN) and target (DE) tokens are concatenated: `[source | target]`
2. A random masking rate `t ~ Uniform(0, 1)` is sampled per example
3. Each target token is independently masked with probability `t`
4. The bidirectional DiT predicts all masked tokens simultaneously
5. Loss = cross-entropy on masked positions only, weighted by `1/t` (continuous-time ELBO)
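The five training steps can be sketched in plain Python for a single example. `MASK_ID`, the toy token ids, and the uniform stand-in logits are hypothetical; dividing the loss by the target length and clamping `t` away from 0 (so `1/t` stays finite) are assumptions, and `train.py` may handle both differently.

```python
import math
import random

MASK_ID = 0  # hypothetical mask-token id; the real one comes from the tokenizer/config


def mask_target(src, tgt, t, rng):
    """Steps 1 and 3: build [source | target], masking each target token with probability t."""
    noisy_tgt, is_masked = [], []
    for tok in tgt:
        masked = rng.random() < t
        noisy_tgt.append(MASK_ID if masked else tok)
        is_masked.append(masked)
    # The source (prompt) is never masked, as in LLaDA-style conditional training.
    return list(src) + noisy_tgt, is_masked


def log_softmax(logits):
    """Numerically stable log-softmax over one position's vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def weighted_masked_ce(tgt_logits, tgt, is_masked, t):
    """Step 5: cross-entropy on masked target positions only, weighted by 1/t.

    Dividing by the target length is an assumption; train.py may normalize differently.
    """
    nll = sum(
        -log_softmax(logits)[tok]
        for logits, tok, masked in zip(tgt_logits, tgt, is_masked)
        if masked
    )
    return nll / (t * len(tgt))


rng = random.Random(0)
t = max(rng.random(), 1e-3)  # step 2: t ~ Uniform(0, 1), clamped away from 0 (assumption)
x, is_masked = mask_target([5, 6, 7], [8, 9, 10, 11], t, rng)
fake_logits = [[0.0] * 16 for _ in range(4)]  # stand-in for the DiT's target outputs (step 4)
print(x, weighted_masked_ce(fake_logits, [8, 9, 10, 11], is_masked, t))
```

Only target positions ever receive `MASK_ID`, so the model always sees the full source while it learns to fill in the target.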
### Inference (Reverse Diffusion)
1. Start with source tokens + fully masked target: `[source | MASK MASK ... MASK]`
2. Over 50 denoising steps, iteratively predict and unmask tokens
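3. At each step from noise level `t` to `s < t`: predict all masked tokens, then re-mask each newly predicted token independently with probability `s/t`, so a fraction `s/t` of them stays masked in expectation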
4. Final step: all remaining masks are filled with predictions
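The reverse process above can be sketched as a loop over `num_steps` evenly spaced noise levels. The `predict` callable is a hypothetical stand-in for the model (the real `generate` in `train.py` takes different arguments), and keeping already-revealed tokens fixed follows step 3, where only masked positions are predicted and re-masked.

```python
import random

MASK_ID = 0  # hypothetical mask-token id


def reverse_diffusion(predict, src, tgt_len, num_steps=50, seed=0):
    """Unmask a fully masked target over num_steps steps, from t = 1 down to t = 0.

    `predict(x)` stands in for the model: it returns one token id per position of x
    (e.g. the argmax of the logits) and is assumed never to return MASK_ID.
    """
    rng = random.Random(seed)
    n_src = len(src)
    x = list(src) + [MASK_ID] * tgt_len  # step 1: [source | MASK ... MASK]
    for i in range(num_steps, 0, -1):
        t, s = i / num_steps, (i - 1) / num_steps
        preds = predict(x)  # step 3: predict every position in one forward pass
        for j in range(n_src, len(x)):
            # Already-revealed tokens are kept; each masked one stays masked with prob. s/t.
            # At the last step s = 0, so every remaining mask is filled (step 4).
            if x[j] == MASK_ID and rng.random() >= s / t:
                x[j] = preds[j]
    return x[n_src:]


# Toy "model" that predicts token 42 everywhere, just to exercise the loop.
out = reverse_diffusion(lambda x: [42] * len(x), src=[5, 6, 7], tgt_len=4, num_steps=10)
print(out)  # [42, 42, 42, 42]
```

Because `s/t` shrinks toward 0 as `t` approaches 0, few tokens are revealed in early steps and the remainder are committed near the end.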
## Training Details
| Setting | Value |
|---|---|
| **Dataset** | WMT14 EN-DE (~4.5M parallel sentence pairs) |
| **Optimizer** | AdamW (lr=3e-4, β₁=0.9, β₂=0.98, wd=0.01) |
| **Schedule** | Cosine decay with 4K-step linear warmup |
| **Effective batch size** | 256 (64 × 4 gradient accumulation) |
| **Max steps** | 200,000 |
| **Mixed precision** | FP16 |
| **Gradient clipping** | max_norm=1.0 |
| **Evaluation** | SacreBLEU on the WMT14 test set every 20K steps |
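The learning-rate schedule in the table can be written out as a function of the optimizer step. This sketch assumes the rate decays all the way to 0 at step 200,000 with no minimum floor; it has the same shape as `get_cosine_schedule_with_warmup` in `transformers` with its default `num_cycles=0.5`, though whether `train.py` uses that helper is an assumption.

```python
import math


def lr_at(step, base_lr=3e-4, warmup_steps=4_000, max_steps=200_000):
    """Linear warmup from 0 to base_lr, then half-cosine decay to 0 at max_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (0, 2_000, 4_000, 102_000, 200_000):
    print(step, lr_at(step))
```

The peak of 3e-4 is reached at step 4,000, and the rate is back to half the peak at step 102,000, midway through the decay phase.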
## Quick Start
### Install dependencies
```bash
pip install torch transformers datasets trackio sacrebleu sacremoses sentencepiece protobuf
```
### Train
```bash
git clone https://huggingface.co/vedkdev/text-diffusion-en-de
cd text-diffusion-en-de
python train.py
```
The script will:
- Download WMT14 EN-DE automatically
- Train for 200K steps with logging via [Trackio](https://huggingface.co/docs/trackio)
- Evaluate SacreBLEU periodically
- Push checkpoints to this repo
### Adjusting for your hardware
Edit the `TRAIN_CONFIG` dict in `train.py`:
| GPU VRAM | Recommended `batch_size` | `gradient_accumulation_steps` |
|---|---|---|
| 24GB (A10G/3090/4090) | 64 | 4 |
| 16GB (T4/V100) | 32 | 8 |
| 12GB (3060) | 16 | 16 |
| 8GB (3070) | 8 | 32 |
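Every row of the table keeps `batch_size × gradient_accumulation_steps = 256`, the effective batch size from Training Details. The helper below is hypothetical (not part of `train.py`), and the `TRAIN_CONFIG` dict is a stand-in that assumes the real dict uses the column names as keys.

```python
EFFECTIVE_BATCH = 256  # effective batch size from the Training Details table


def accumulation_steps(batch_size: int, effective_batch: int = EFFECTIVE_BATCH) -> int:
    """Gradient-accumulation steps needed to keep the effective batch size fixed."""
    if effective_batch % batch_size != 0:
        raise ValueError(f"batch_size={batch_size} does not divide {effective_batch}")
    return effective_batch // batch_size


# Stand-in for the dict in train.py (key names assumed from the table columns).
TRAIN_CONFIG = {"batch_size": 64, "gradient_accumulation_steps": 4}
TRAIN_CONFIG["batch_size"] = 32  # e.g. a 16GB T4/V100
TRAIN_CONFIG["gradient_accumulation_steps"] = accumulation_steps(TRAIN_CONFIG["batch_size"])
print(TRAIN_CONFIG)  # {'batch_size': 32, 'gradient_accumulation_steps': 8}
```

Inside a training loop, each micro-batch loss is typically divided by `gradient_accumulation_steps` before `backward()`, so the accumulated gradient matches that of a single 256-example batch.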
### Inference (after training)
```python
import json

import torch
from transformers import AutoTokenizer

from train import DiffusionTranslator, DiffusionTranslatorConfig, generate

# Load the best checkpoint saved during training
with open("checkpoints/best/config.json") as f:
    config = DiffusionTranslatorConfig(**json.load(f))
model = DiffusionTranslator(config)
model.load_state_dict(torch.load("checkpoints/best/model.pt", map_location="cpu"))
model.eval()
tokenizer = AutoTokenizer.from_pretrained("checkpoints/best/")

# Translate
text = "The weather is nice today."
src = tokenizer(
    f"translate English to German: {text}",
    max_length=128,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)
with torch.no_grad():  # inference only, no gradients needed
    gen_ids = generate(model, src["input_ids"], torch.zeros_like(src["input_ids"]),
                       config, num_steps=50, device="cpu")
print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))
```
## Expected Results
Published WMT14 EN→DE BLEU scores for related models, for context. The last row is this model's target range, not a measured result:
| Model | BLEU | Reference |
|---|---|---|
| Autoregressive Transformer | ~27 | Vaswani et al. |
| DiNoiSer (continuous diffusion) | 24.6 | Ye et al. 2023 |
| SeqDiffuSeq | 19.8 | Yuan et al. 2022 |
| E2D2 (discrete diffusion) | 24.8 | Kuleshov et al. 2024 |
| **This model (target)** | **15-20** | ~72M params, no KD |
> Note: Text diffusion models typically score 2-5 BLEU below autoregressive transformers of similar size. Knowledge distillation (KD) from an AR teacher can close the gap by ~1-2 BLEU.
## Citation
If you use this model, please cite the foundational papers:
```bibtex
@article{sahoo2024mdlm,
title={Simple and Effective Masked Diffusion Language Models},
author={Sahoo, Subham Sekhar and Arriola, Marianne and Schiff, Yair and Gokaslan, Aaron and Marroquin, Edgar and Chiu, Justin T. and Rush, Alexander M. and Kuleshov, Volodymyr},
journal={NeurIPS},
year={2024}
}
@article{nie2025llada,
title={Large Language Diffusion Models},
author={Nie, Shen and Zhu, Fengqi and You, Zebin and Zhang, Xiaolu and Ou, Jingyang and Hu, Jun and Zhou, Jun and Lin, Yankai and Wen, Ji-Rong and Li, Chongxuan},
journal={arXiv preprint arXiv:2502.09992},
year={2025}
}
@article{ye2023dinoiser,
title={DiNoiSer: Diffused Conditional Sequence Learning by Manipulating Noises},
author={Ye, Jiasheng and Zheng, Zaixiang and Bao, Yu and Qian, Lihua and Gu, Quanquan},
journal={ACL},
year={2023}
}
```