---
tags:
- text-diffusion
- machine-translation
- en-de
- masked-diffusion
- from-scratch
language:
- en
- de
datasets:
- wmt/wmt14
license: apache-2.0
---

# Text Diffusion Model for EN→DE Translation

A **masked discrete diffusion** model for English-to-German machine translation, trained from scratch on WMT14 EN-DE.

## Architecture

| Component | Detail |
|---|---|
| **Type** | Masked Discrete Diffusion |
| **Backbone** | DiT (Diffusion Transformer) with adaLN |
| **Parameters** | ~72M |
| **Blocks** | 12 DiT blocks |
| **Hidden dim** | 512, 8 attention heads |
| **Attention** | Bidirectional (no causal mask) with RoPE |
| **Conditioning** | Timestep via sinusoidal embeddings + adaLN; Segment embeddings for src/tgt |
| **Weight tying** | Input embeddings tied to output projection |
| **Tokenizer** | [Helsinki-NLP/opus-mt-en-de](https://huggingface.co/Helsinki-NLP/opus-mt-en-de) (~58K vocab) |
| **Max sequence** | 128 src + 128 tgt tokens |
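
For readers unfamiliar with adaLN conditioning, the sketch below shows the shape of one such block: the timestep embedding is projected to per-block shift, scale, and gate parameters that modulate the normalized activations. It is illustrative only; the class name, layer layout, and the plain `nn.MultiheadAttention` (RoPE and segment embeddings omitted for brevity) are assumptions, not the exact modules in `train.py`.

```python
import torch
import torch.nn as nn

class AdaLNBlock(nn.Module):
    """One DiT-style block with adaLN timestep conditioning (illustrative sketch)."""

    def __init__(self, dim: int = 512, heads: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)  # no causal mask: bidirectional
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # adaLN: project the timestep embedding to per-block shift/scale/gate params
        self.ada = nn.Linear(dim, 6 * dim)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, dim); t_emb: (batch, dim) from sinusoidal timestep embedding
        s1, b1, g1, s2, b2, g2 = self.ada(t_emb)[:, None, :].chunk(6, dim=-1)
        h = self.norm1(x) * (1 + s1) + b1
        x = x + g1 * self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + s2) + b2
        x = x + g2 * self.mlp(h)
        return x
```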

### Inspired by
- **[MDLM](https://arxiv.org/abs/2406.07524)** — DiT backbone architecture, masked diffusion objective
- **[LLaDA](https://arxiv.org/abs/2502.09992)** — Conditional generation via SFT (keep prompt unmasked, mask only target), 1/t ELBO weighting
- **[DiNoiSer](https://arxiv.org/abs/2302.10025)** — Noise manipulation for conditional seq2seq diffusion

## How It Works

### Training (Forward Diffusion)
1. Source (EN) and target (DE) tokens are concatenated: `[source | target]`
2. A random masking rate `t ~ Uniform(0, 1)` is sampled per example
3. Each target token is independently masked with probability `t`
4. The bidirectional DiT predicts all masked tokens simultaneously
5. Loss = cross-entropy on masked positions only, weighted by `1/t` (continuous-time ELBO); see the sketch below
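
Concretely, one training step looks roughly like this. It is a minimal sketch assuming a `model(xt, t)` interface that returns per-position logits; the function and argument names, and the normalization by the masked-token count, are illustrative, not the exact code in `train.py`.

```python
import torch
import torch.nn.functional as F

def training_step(model, src_ids, tgt_ids, mask_id, eps=1e-3):
    """One forward-diffusion training step (illustrative sketch)."""
    B, Lt = tgt_ids.shape
    # 1. Concatenate source and target: [source | target]
    x0 = torch.cat([src_ids, tgt_ids], dim=1)
    # 2. Sample a masking rate t ~ Uniform(eps, 1) per example
    t = torch.rand(B, device=x0.device) * (1 - eps) + eps
    # 3. Independently mask each target token with probability t (source stays clean)
    masked = torch.rand(B, Lt, device=x0.device) < t[:, None]
    xt = x0.clone()
    xt[:, -Lt:][masked] = mask_id
    # 4. The bidirectional DiT predicts every position in one pass
    logits = model(xt, t)[:, -Lt:]                    # (B, Lt, vocab)
    # 5. Cross-entropy on masked positions only, weighted by 1/t (continuous-time ELBO)
    ce = F.cross_entropy(logits.transpose(1, 2), tgt_ids, reduction="none")
    return (ce * masked / t[:, None]).sum() / masked.sum().clamp(min=1)
```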

### Inference (Reverse Diffusion)
1. Start with source tokens + fully masked target: `[source | MASK MASK ... MASK]`
2. Over 50 denoising steps, iteratively predict and unmask tokens
3. At each step `t → s`: predict all masked tokens, then re-mask each newly predicted token with probability `s/t`, so the expected masking rate decays from `t` to `s`
4. Final step (`s = 0`): all remaining masks are filled with predictions; see the sketch below
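
A bare-bones version of this loop, under the same hedges as above (the real `generate` in `train.py` may sample instead of taking the argmax and may handle padding differently):

```python
import torch

@torch.no_grad()
def generate_sketch(model, src_ids, mask_id, tgt_len=128, num_steps=50):
    """Reverse diffusion: start fully masked, progressively unmask (illustrative)."""
    B = src_ids.shape[0]
    xt = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=src_ids.device)
    ts = torch.linspace(1.0, 0.0, num_steps + 1)
    for i in range(num_steps):
        t, s = ts[i].item(), ts[i + 1].item()
        t_vec = torch.full((B,), t, device=src_ids.device)
        logits = model(torch.cat([src_ids, xt], dim=1), t_vec)[:, -tgt_len:]
        pred = logits.argmax(dim=-1)                 # greedy; sampling also works
        still_masked = xt == mask_id
        # Keep each prediction with probability 1 - s/t (i.e. re-mask with s/t),
        # so the masking rate decays t -> s; at the last step s = 0, all are kept.
        keep = torch.rand(B, tgt_len, device=src_ids.device) >= s / t
        xt = torch.where(still_masked & keep, pred, xt)
    return xt
```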

## Training Details

| Setting | Value |
|---|---|
| **Dataset** | WMT14 EN-DE (~4.5M parallel sentence pairs) |
| **Optimizer** | AdamW (lr=3e-4, β₁=0.9, β₂=0.98, wd=0.01) |
| **Schedule** | Cosine decay with 4K-step linear warmup |
| **Effective batch size** | 256 (64 × 4 gradient accumulation) |
| **Max steps** | 200,000 |
| **Mixed precision** | FP16 |
| **Gradient clipping** | max_norm=1.0 |
| **Evaluation** | SacreBLEU on WMT14 test set every 20K steps |
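
The optimizer and schedule rows translate to a few lines of PyTorch. The sketch below assumes a standard `LambdaLR`; the actual script may organize this differently.

```python
import math
import torch

def make_optimizer(model, lr=3e-4, max_steps=200_000, warmup=4_000):
    """AdamW with linear warmup then cosine decay, per the table above (sketch)."""
    opt = torch.optim.AdamW(model.parameters(), lr=lr,
                            betas=(0.9, 0.98), weight_decay=0.01)

    def lr_lambda(step):
        if step < warmup:
            return step / warmup                         # linear warmup over 4K steps
        progress = (step - warmup) / (max_steps - warmup)
        return 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay to 0

    return opt, torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
```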

## Quick Start

### Install dependencies

```bash
pip install torch transformers datasets trackio sacrebleu sacremoses sentencepiece protobuf
```

### Train

```bash
git clone https://huggingface.co/vedkdev/text-diffusion-en-de
cd text-diffusion-en-de
python train.py
```

The script will:
- Download WMT14 EN-DE automatically
- Train for 200K steps with logging via [Trackio](https://huggingface.co/docs/trackio)
- Evaluate SacreBLEU periodically
- Push checkpoints to this repo

### Adjusting for your hardware

Edit the `TRAIN_CONFIG` dict in `train.py`:

| GPU VRAM | Recommended `batch_size` | `gradient_accumulation_steps` |
|---|---|---|
| 24GB (A10G/3090/4090) | 64 | 4 |
| 16GB (T4/V100) | 32 | 8 |
| 12GB (3060) | 16 | 16 |
| 8GB (3070) | 8 | 32 |
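
For example, on a 16GB card the edit might look like this. The key names here are hypothetical; consult `TRAIN_CONFIG` in `train.py` for the real ones.

```python
# Hypothetical key names; check TRAIN_CONFIG in train.py for the actual ones.
TRAIN_CONFIG = {
    "batch_size": 32,                  # per-GPU batch for ~16GB VRAM
    "gradient_accumulation_steps": 8,  # 32 x 8 = 256 effective batch, unchanged
    "lr": 3e-4,
    "max_steps": 200_000,
}
```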

### Inference (after training)

```python
import json

import torch
from transformers import AutoTokenizer

from train import DiffusionTranslator, DiffusionTranslatorConfig, generate

# Rebuild the model from the saved config and weights
with open("checkpoints/best/config.json") as f:
    config = DiffusionTranslatorConfig(**json.load(f))
model = DiffusionTranslator(config)
model.load_state_dict(torch.load("checkpoints/best/model.pt", map_location="cpu"))
model.eval()

tokenizer = AutoTokenizer.from_pretrained("checkpoints/best/")

# Tokenize the source sentence, padded to the model's 128-token source length
text = "The weather is nice today."
src = tokenizer(f"translate English to German: {text}",
                max_length=128, truncation=True, padding="max_length",
                return_tensors="pt")

# Run 50 reverse-diffusion steps; the zeros tensor is a same-shape
# placeholder that generate() fills with the translated target ids
gen_ids = generate(model, src["input_ids"], torch.zeros_like(src["input_ids"]),
                   config, num_steps=50, device="cpu")
print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))
```

## Expected Results

Based on published literature for similar architectures on WMT14 EN→DE:

| Model | BLEU | Reference |
|---|---|---|
| Autoregressive Transformer | ~27 | Vaswani et al. 2017 |
| DiNoiSer (continuous diffusion) | 24.6 | Ye et al. 2023 |
| SeqDiffuSeq | 19.8 | Yuan et al. 2022 |
| E2D2 (discrete diffusion) | 24.8 | Kuleshov et al. 2024 |
| **This model (target)** | **15-20** | ~72M params, no KD |

> Note: Text diffusion models typically score 2-5 BLEU below autoregressive transformers of similar size. Knowledge distillation (KD) from an AR teacher can close the gap by ~1-2 BLEU.

## Citation

If you use this model, please cite the foundational papers:

```bibtex
@inproceedings{sahoo2024mdlm,
  title={Simple and Effective Masked Diffusion Language Models},
  author={Sahoo, Subham Sekhar and Arriola, Marianne and Schiff, Yair and Gokaslan, Aaron and Marroquin, Edgar and Chiu, Justin T and Rush, Alexander and Kuleshov, Volodymyr},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}

@article{nie2025llada,
  title={Large Language Diffusion Models},
  author={Nie, Shen and Zhu, Fengqi and You, Zebin and Zhang, Xiaolu and Ou, Jingyang and Hu, Jun and Zhou, Jun and Lin, Yankai and Wen, Ji-Rong and Li, Chongxuan},
  journal={arXiv preprint arXiv:2502.09992},
  year={2025}
}

@article{ye2023dinoiser,
  title={DiNoiSer: Diffused Conditional Sequence Learning by Manipulating Noises},
  author={Ye, Jiasheng and Zheng, Zaixiang and Bao, Yu and Qian, Lihua and Wang, Mingxuan},
  journal={arXiv preprint arXiv:2302.10025},
  year={2023}
}
```