| """ |
| Text Diffusion Model for EN→DE Machine Translation |
| Self-contained training script. |
| Architecture: Masked Discrete Diffusion with DiT backbone |
| Inspired by MDLM (kuleshov-group) + LLaDA conditional generation |
| Dataset: WMT14 EN-DE |
| |
| Usage: |
| pip install torch transformers datasets trackio sacrebleu sacremoses sentencepiece protobuf |
| python train.py |
| """ |
|
|
import os
import math
import time
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from dataclasses import dataclass
from datasets import load_dataset
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import trackio
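# trackio is Hugging Face's lightweight experiment tracker; this script assumes
# a wandb-style API surface (init / log / finish).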
|
|
|
|
# =============================================================================
# Model
# =============================================================================
|
|
@dataclass
class DiffusionTranslatorConfig:
    # Token-related defaults are placeholders; train() overwrites them with
    # values taken from the actual tokenizer.
    vocab_size: int = 32128
    max_src_len: int = 128
    max_tgt_len: int = 128
    hidden_dim: int = 512
    n_heads: int = 8
    n_blocks: int = 8
    dropout: float = 0.1
    cond_dim: int = 128
    mask_token_id: int = 32100
    pad_token_id: int = 0
|
|
|
|
class Rotary(nn.Module):
    """Rotary positional embedding (RoPE) with cached cos/sin tables."""

    def __init__(self, dim, base=10_000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            # Recompute the cached tables only when the sequence length changes.
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return self.cos_cached, self.sin_cached
|
|
|
|
def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    cos = cos[:q.shape[1], :]
    sin = sin[:q.shape[1], :]
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)
    q = (q * cos) + (rotate_half(q) * sin)
    k = (k * cos) + (rotate_half(k) * sin)
    return q, k
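# Shape notes for the RoPE helpers above:
#   q, k arrive as (batch, seq, heads, head_dim); cos/sin come in as (seq, head_dim).
#   After slicing and unsqueezing, cos/sin broadcast as (1, seq, 1, head_dim), so
#   feature i is paired with feature i + head_dim/2 and the pair is rotated by a
#   position-dependent angle.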
|
|
|
|
class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps into conditioning vectors."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal features of t, as in transformer position embeddings."""
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)
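# Worked example for timestep_embedding with dim=4, max_period=10000:
#   freqs = exp(-ln(10000) * [0, 1] / 2) = [1.0, 0.01]
#   for t = 0.5: args = [0.5, 0.005]
#   embedding = [cos(0.5), cos(0.005), sin(0.5), sin(0.005)]
# Later channels vary slowly in t; earlier ones resolve fine-grained changes.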
|
|
|
|
class LayerNorm(nn.Module):
    """Bias-free LayerNorm, computed in fp32 for stability under autocast."""

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones([dim]))
        self.dim = dim

    def forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            x = F.layer_norm(x.float(), [self.dim])
        return x * self.weight[None, None, :]
|
|
|
class DiTBlock(nn.Module):
    """Diffusion Transformer block with adaptive layer norm (adaLN)."""

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        self.norm1 = LayerNorm(dim)
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )
        self.dropout2 = nn.Dropout(dropout)

        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)
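        # With zero-initialized modulation, scale/shift/gate all start at 0, so
        # each block begins as (roughly) an identity mapping -- standard DiT
        # practice for stable early training. Note that _init_weights in the
        # model below deliberately re-applies this zero init after its generic
        # initialization pass.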
|
|
    def forward(self, x, rotary_cos_sin, c, attention_mask=None):
        batch_size, seq_len, dim = x.shape

        # Six modulation vectors from the conditioning signal, (B, 1, dim) each.
        mod = self.adaLN_modulation(c)[:, None, :].chunk(6, dim=2)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod

        x_skip = x
        x_norm = self.norm1(x) * (1 + scale_msa) + shift_msa

        q = self.q_proj(x_norm).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x_norm).view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = self.v_proj(x_norm).view(batch_size, seq_len, self.n_heads, self.head_dim)

        cos, sin = rotary_cos_sin
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Bidirectional (non-causal) attention over the full [src | tgt] sequence.
        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask,
            dropout_p=self.dropout1.p if self.training else 0.0
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, dim)

        x = x_skip + gate_msa * self.dropout1(self.attn_out(attn_output))

        x_skip = x
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x_skip + gate_mlp * self.dropout2(self.mlp(x_norm))

        return x
|
|
|
|
class DiffusionTranslator(nn.Module):
    """
    Masked discrete diffusion model for EN→DE translation.

    Input: [source_tokens | target_tokens], where the target tokens are
    partially masked. The backbone is a bidirectional transformer of DiT
    blocks with adaLN timestep conditioning.
    """

    def __init__(self, config: DiffusionTranslatorConfig):
        super().__init__()
        self.config = config

        self.vocab_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.sigma_map = TimestepEmbedder(config.cond_dim)
        self.rotary_emb = Rotary(config.hidden_dim // config.n_heads)
        self.segment_embed = nn.Embedding(2, config.hidden_dim)  # 0 = source, 1 = target

        self.blocks = nn.ModuleList([
            DiTBlock(config.hidden_dim, config.n_heads, config.cond_dim, dropout=config.dropout)
            for _ in range(config.n_blocks)
        ])

        self.final_norm = LayerNorm(config.hidden_dim)
        self.final_adaLN = nn.Linear(config.cond_dim, 2 * config.hidden_dim, bias=True)
        nn.init.zeros_(self.final_adaLN.weight)
        nn.init.zeros_(self.final_adaLN.bias)

        # Tie the output projection to the input embedding matrix.
        self.output_proj = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        self.output_proj.weight = self.vocab_embed.weight

        self._init_weights()
|
|
    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # The generic pass above would otherwise clobber the adaLN zero init
        # done in the constructors, so re-apply it here.
        for block in self.blocks:
            nn.init.zeros_(block.adaLN_modulation.weight)
            nn.init.zeros_(block.adaLN_modulation.bias)
        nn.init.zeros_(self.final_adaLN.weight)
        nn.init.zeros_(self.final_adaLN.bias)
|
|
    def forward(self, input_ids, segment_ids, timesteps):
        x = self.vocab_embed(input_ids) + self.segment_embed(segment_ids)
        c = F.silu(self.sigma_map(timesteps))
        rotary_cos_sin = self.rotary_emb(x)

        for block in self.blocks:
            x = block(x, rotary_cos_sin, c)

        shift, scale = self.final_adaLN(c)[:, None, :].chunk(2, dim=2)
        x = self.final_norm(x) * (1 + scale) + shift
        logits = self.output_proj(x)
        return logits

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
def compute_diffusion_loss(model, input_ids, segment_ids, target_ids, target_mask, config):
    """Compute the masked diffusion training loss (LLaDA-style ELBO)."""
    batch_size = input_ids.shape[0]
    device = input_ids.device

    # Sample a continuous masking level t in (eps, 1] per example.
    eps = 1e-5
    t = torch.rand(batch_size, device=device) * (1 - eps) + eps

    # Mask each *target* token independently with probability t.
    mask_prob = t[:, None].expand_as(target_mask)
    random_mask = torch.rand_like(mask_prob) < mask_prob
    diffusion_mask = random_mask & target_mask

    noised_input = input_ids.clone()
    noised_input[diffusion_mask] = config.mask_token_id

    logits = model(noised_input, segment_ids, t)

    logits_flat = logits.view(-1, config.vocab_size)
    targets_flat = target_ids.view(-1)

    if diffusion_mask.sum() == 0:
        # Keep the graph connected so backward() is a harmless no-op.
        zero = logits.sum() * 0.0
        return zero, zero

    ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')

    # Average CE over the masked positions of each example.
    masked_loss_2d = ce_loss.view(batch_size, -1) * diffusion_mask.float()
    per_example_counts = diffusion_mask.float().sum(dim=1).clamp(min=1.0)
    per_example_loss = masked_loss_2d.sum(dim=1) / per_example_counts

    # 1/t ELBO weighting; the unweighted CE is also returned for logging.
    weighted_loss = (per_example_loss / t).mean()
    unweighted_loss = per_example_loss.mean()

    return weighted_loss, unweighted_loss
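# Quick shape/loss sanity check (a minimal sketch; all sizes are illustrative):
#
#   cfg = DiffusionTranslatorConfig(vocab_size=64, mask_token_id=1, pad_token_id=0,
#                                   max_src_len=8, max_tgt_len=8, hidden_dim=64,
#                                   n_heads=4, n_blocks=2, cond_dim=32)
#   m = DiffusionTranslator(cfg)
#   ids = torch.randint(2, 64, (2, 16))
#   seg = torch.tensor([[0] * 8 + [1] * 8] * 2)
#   wl, uwl = compute_diffusion_loss(m, ids, seg, ids, seg.bool(), cfg)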
|
|
|
|
@torch.no_grad()
def generate(model, src_ids, src_segment_ids, config, num_steps=50, device='cuda'):
    """Generate a translation by iterative unmasking."""
    model.eval()
    batch_size = src_ids.shape[0]
    tgt_len = config.max_tgt_len

    # dtype=torch.long so the mask IDs can be embedded directly.
    tgt_ids = torch.full((batch_size, tgt_len), config.mask_token_id,
                         dtype=torch.long, device=device)
    tgt_segment_ids = torch.ones(batch_size, tgt_len, dtype=torch.long, device=device)

    input_ids = torch.cat([src_ids, tgt_ids], dim=1)
    segment_ids = torch.cat([src_segment_ids, tgt_segment_ids], dim=1)
    src_len = src_ids.shape[1]

    for step in range(num_steps, 0, -1):
        t = torch.tensor([step / num_steps], device=device).expand(batch_size)
        s = torch.tensor([(step - 1) / num_steps], device=device).expand(batch_size)

        logits = model(input_ids, segment_ids, t)
        tgt_logits = logits[:, src_len:, :]
        predicted_tokens = tgt_logits.argmax(dim=-1)

        current_tgt = input_ids[:, src_len:]
        still_masked = (current_tgt == config.mask_token_id)

        if step > 1:
            # Keep each still-masked position masked with probability s/t.
            remask_prob = s[0].item() / t[0].item() if t[0].item() > 0 else 0.0
            remask = torch.rand_like(predicted_tokens.float()) < remask_prob
            new_tgt = current_tgt.clone()
            unmask_positions = still_masked & ~remask
            new_tgt[unmask_positions] = predicted_tokens[unmask_positions]
        else:
            # Last step: commit predictions at every remaining masked position.
            new_tgt = current_tgt.clone()
            new_tgt[still_masked] = predicted_tokens[still_masked]

        input_ids = torch.cat([src_ids, new_tgt], dim=1)

    return input_ids[:, src_len:]
|
|
|
|
# =============================================================================
# Data
# =============================================================================
|
|
class WMT14EnDeDataset(Dataset):
    """Packs each EN-DE pair into a single [src | tgt] sequence."""

    def __init__(self, data, tokenizer, max_src_len=128, max_tgt_len=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_src_len = max_src_len
        self.max_tgt_len = max_tgt_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        en_text = item['translation']['en']
        de_text = item['translation']['de']

        src_enc = self.tokenizer(
            "translate English to German: " + en_text,
            max_length=self.max_src_len, truncation=True,
            padding='max_length', return_tensors=None,
        )
        tgt_enc = self.tokenizer(
            de_text,
            max_length=self.max_tgt_len, truncation=True,
            padding='max_length', return_tensors=None,
        )

        src_ids = src_enc['input_ids']
        tgt_ids = tgt_enc['input_ids']
        segment_ids = [0] * len(src_ids) + [1] * len(tgt_ids)
        full_ids = src_ids + tgt_ids
        # target_mask is 1 only at real (non-pad) target positions; the source
        # half and target padding are never masked or scored.
        target_mask = [0] * len(src_ids) + tgt_enc['attention_mask']

        return {
            'input_ids': torch.tensor(full_ids, dtype=torch.long),
            'segment_ids': torch.tensor(segment_ids, dtype=torch.long),
            'target_ids': torch.tensor(full_ids, dtype=torch.long),
            'target_mask': torch.tensor(target_mask, dtype=torch.bool),
        }
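# Layout of one packed example (a sketch; total length is max_src_len + max_tgt_len):
#
#   input_ids:   [ EN tokens, pad ... | DE tokens, pad ... ]
#   segment_ids: [ 0 0 0 ...          | 1 1 1 ...          ]
#   target_mask: [ 0 0 0 ...          | 1 ... 1 0 ... 0    ]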
|
|
|
|
# =============================================================================
# Configuration
# =============================================================================
|
|
MODEL_CONFIG = dict(
    vocab_size=None,      # filled in from the tokenizer in train()
    max_src_len=128,
    max_tgt_len=128,
    hidden_dim=512,
    n_heads=8,
    n_blocks=12,
    dropout=0.1,
    cond_dim=128,
    mask_token_id=None,   # filled in from the tokenizer in train()
    pad_token_id=None,    # filled in from the tokenizer in train()
)


TRAIN_CONFIG = dict(
    learning_rate=3e-4,
    weight_decay=0.01,
    warmup_steps=4000,
    max_steps=200_000,
    batch_size=64,
    gradient_accumulation_steps=4,
    eval_every=5000,
    save_every=10000,
    log_every=100,
    max_grad_norm=1.0,
    num_gen_steps=50,
    fp16=True,
    seed=42,
)
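# Effective batch arithmetic: 64 sequences x 4 accumulation steps = 256
# sequences per optimizer step; at 128 + 128 = 256 tokens per sequence, that
# is up to 256 * 256 = 65,536 tokens per update.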
|
|
HUB_MODEL_ID = "vedkdev/text-diffusion-en-de"
TOKENIZER_NAME = "Helsinki-NLP/opus-mt-en-de"
|
|
|
|
# =============================================================================
# Training & evaluation
# =============================================================================
|
|
def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(TRAIN_CONFIG['seed'])

    print(f"Device: {device}")
    print(f"Loading tokenizer: {TOKENIZER_NAME}")

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    if tokenizer.mask_token is None:
        tokenizer.add_special_tokens({'mask_token': '<mask>'})

    mask_token_id = tokenizer.mask_token_id
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    print(f"Vocab size: {len(tokenizer)}")
    print(f"Mask token ID: {mask_token_id}, Pad token ID: {pad_token_id}")

    MODEL_CONFIG['vocab_size'] = len(tokenizer)
    MODEL_CONFIG['mask_token_id'] = mask_token_id
    MODEL_CONFIG['pad_token_id'] = pad_token_id
    config = DiffusionTranslatorConfig(**MODEL_CONFIG)

    model = DiffusionTranslator(config).to(device)
    print(f"Model parameters: {model.count_parameters():,}")

    print("Loading WMT14 EN-DE dataset...")
    dataset = load_dataset("wmt/wmt14", "de-en", trust_remote_code=True)
    train_data = dataset['train']
    val_data = dataset['validation']
    print(f"Train: {len(train_data):,} | Val: {len(val_data):,}")

    train_dataset = WMT14EnDeDataset(train_data, tokenizer, config.max_src_len, config.max_tgt_len)
    val_dataset = WMT14EnDeDataset(val_data, tokenizer, config.max_src_len, config.max_tgt_len)

    train_loader = DataLoader(train_dataset, batch_size=TRAIN_CONFIG['batch_size'],
                              shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=TRAIN_CONFIG['batch_size'],
                            shuffle=False, num_workers=2, pin_memory=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=TRAIN_CONFIG['learning_rate'],
                                  weight_decay=TRAIN_CONFIG['weight_decay'], betas=(0.9, 0.98), eps=1e-8)
    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                num_warmup_steps=TRAIN_CONFIG['warmup_steps'],
                                                num_training_steps=TRAIN_CONFIG['max_steps'])

    scaler = torch.amp.GradScaler('cuda') if (TRAIN_CONFIG['fp16'] and device.type == 'cuda') else None

    trackio.init(project="text-diffusion-en-de", name="v1-wmt14-dit12-512d")

    global_step = 0
    best_val_loss = float('inf')
    accum_loss = 0.0
    accum_loss_uw = 0.0
    accum_count = 0

    eff_bs = TRAIN_CONFIG['batch_size'] * TRAIN_CONFIG['gradient_accumulation_steps']
    print("\n=== Starting Training ===")
    print(f"Effective batch size: {eff_bs} | Max steps: {TRAIN_CONFIG['max_steps']:,}")
    print(f"Warmup: {TRAIN_CONFIG['warmup_steps']:,} | LR: {TRAIN_CONFIG['learning_rate']}")

    model.train()
    optimizer.zero_grad()
    data_iter = iter(train_loader)
    start_time = time.time()

    # The loop below counts micro-batches; an optimizer step (global_step)
    # happens once every gradient_accumulation_steps micro-steps.
    total_micro_steps = TRAIN_CONFIG['max_steps'] * TRAIN_CONFIG['gradient_accumulation_steps']
    for step in range(1, total_micro_steps + 1):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            batch = next(data_iter)

        input_ids = batch['input_ids'].to(device)
        segment_ids = batch['segment_ids'].to(device)
        target_ids = batch['target_ids'].to(device)
        target_mask = batch['target_mask'].to(device)

        if scaler is not None:
            with torch.amp.autocast('cuda'):
                wl, uwl = compute_diffusion_loss(model, input_ids, segment_ids, target_ids, target_mask, config)
            loss = wl / TRAIN_CONFIG['gradient_accumulation_steps']
            scaler.scale(loss).backward()
        else:
            wl, uwl = compute_diffusion_loss(model, input_ids, segment_ids, target_ids, target_mask, config)
            loss = wl / TRAIN_CONFIG['gradient_accumulation_steps']
            loss.backward()

        accum_loss += wl.item()
        accum_loss_uw += uwl.item()
        accum_count += 1

        if step % TRAIN_CONFIG['gradient_accumulation_steps'] == 0:
            if scaler is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CONFIG['max_grad_norm'])
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CONFIG['max_grad_norm'])
                optimizer.step()

            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            # Logging
            if global_step % TRAIN_CONFIG['log_every'] == 0:
                avg_l = accum_loss / accum_count
                avg_uw = accum_loss_uw / accum_count
                elapsed = time.time() - start_time
                sps = global_step / elapsed
                lr = scheduler.get_last_lr()[0]
                print(f"step={global_step} | loss={avg_l:.4f} | ce_loss={avg_uw:.4f} | lr={lr:.2e} | steps/s={sps:.2f}")
                trackio.log({"train/loss_weighted": avg_l, "train/loss_ce": avg_uw,
                             "train/learning_rate": lr, "train/steps_per_sec": sps}, step=global_step)
                accum_loss = 0.0
                accum_loss_uw = 0.0
                accum_count = 0

            # Periodic validation (and occasional BLEU)
            if global_step % TRAIN_CONFIG['eval_every'] == 0:
                vl, vuw = evaluate(model, val_loader, config, device, scaler is not None)
                print(f"  [EVAL] step={global_step} | val_loss={vl:.4f} | val_ce={vuw:.4f}")
                trackio.log({"eval/loss_weighted": vl, "eval/loss_ce": vuw}, step=global_step)

                if global_step % (TRAIN_CONFIG['eval_every'] * 4) == 0:
                    bleu = evaluate_bleu(model, tokenizer, config, device, num_samples=100,
                                         num_steps=TRAIN_CONFIG['num_gen_steps'])
                    trackio.log({"eval/sacrebleu": bleu}, step=global_step)

                if vuw < best_val_loss:
                    best_val_loss = vuw
                    save_model(model, config, tokenizer, global_step, is_best=True)
                model.train()

            # Periodic checkpoint
            if global_step % TRAIN_CONFIG['save_every'] == 0:
                save_model(model, config, tokenizer, global_step, push_to_hub=True)

    # Final checkpoint and BLEU evaluation
    save_model(model, config, tokenizer, global_step, push_to_hub=True)
    print("\n=== Final BLEU Evaluation ===")
    bleu = evaluate_bleu(model, tokenizer, config, device, num_samples=200,
                         num_steps=TRAIN_CONFIG['num_gen_steps'])
    trackio.log({"eval/final_sacrebleu": bleu}, step=global_step)
    trackio.finish()  # close the run, assuming trackio's wandb-style finish()
    print(f"\n=== Training Complete === Final BLEU: {bleu:.2f}")
|
|
|
|
def evaluate(model, val_loader, config, device, use_fp16=True):
    model.eval()
    total_l = total_uw = 0.0
    count = 0
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            if i >= 50:
                # Cap validation at 50 batches to keep eval cheap; the loss is
                # stochastic (random masking), so it is an estimate regardless.
                break
            ids = batch['input_ids'].to(device)
            seg = batch['segment_ids'].to(device)
            tgt = batch['target_ids'].to(device)
            mask = batch['target_mask'].to(device)
            if use_fp16:
                with torch.amp.autocast('cuda'):
                    wl, uwl = compute_diffusion_loss(model, ids, seg, tgt, mask, config)
            else:
                wl, uwl = compute_diffusion_loss(model, ids, seg, tgt, mask, config)
            total_l += wl.item()
            total_uw += uwl.item()
            count += 1
    return total_l / max(count, 1), total_uw / max(count, 1)
|
|
|
|
def evaluate_bleu(model, tokenizer, config, device, num_samples=100, num_steps=50):
    import sacrebleu
    model.eval()
    ds = load_dataset("wmt/wmt14", "de-en", split="test", trust_remote_code=True)
    refs, hyps = [], []
    for i in range(min(num_samples, len(ds))):
        en = ds[i]['translation']['en']
        de_ref = ds[i]['translation']['de']
        enc = tokenizer("translate English to German: " + en, max_length=config.max_src_len,
                        truncation=True, padding='max_length', return_tensors='pt')
        src_ids = enc['input_ids'].to(device)
        src_seg = torch.zeros_like(src_ids)
        with torch.no_grad():
            if device.type == 'cuda':
                with torch.amp.autocast('cuda'):
                    gen = generate(model, src_ids, src_seg, config, num_steps=num_steps, device=device)
            else:
                gen = generate(model, src_ids, src_seg, config, num_steps=num_steps, device=device)
        hyp = tokenizer.decode(gen[0], skip_special_tokens=True)
        refs.append(de_ref)
        hyps.append(hyp)
        if i < 5:
            print(f"  EN:  {en[:100]}")
            print(f"  REF: {de_ref[:100]}")
            print(f"  GEN: {hyp[:100]}")
            print()
    bleu = sacrebleu.corpus_bleu(hyps, [refs])
    print(f"SacreBLEU: {bleu.score:.2f}")
    return bleu.score
|
|
|
|
def save_model(model, config, tokenizer, step, is_best=False, push_to_hub=False):
    save_dir = "checkpoints/best" if is_best else f"checkpoints/step-{step}"
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
    config_dict = {k: getattr(config, k) for k in [
        'vocab_size', 'max_src_len', 'max_tgt_len', 'hidden_dim',
        'n_heads', 'n_blocks', 'dropout', 'cond_dim', 'mask_token_id', 'pad_token_id'
    ]}
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)
    tokenizer.save_pretrained(save_dir)
    if push_to_hub:
        push_model_to_hub(save_dir, step, config)
    print(f"  Saved checkpoint to {save_dir}")
|
|
|
|
def push_model_to_hub(save_dir, step, config):
    from huggingface_hub import HfApi, upload_folder
    api = HfApi()
    try:
        api.create_repo(HUB_MODEL_ID, exist_ok=True, private=False)
    except Exception as e:
        print(f"  Warning creating repo: {e}")

    readme = f"""---
tags:
- text-diffusion
- machine-translation
- en-de
- masked-diffusion
language:
- en
- de
datasets:
- wmt/wmt14
---

# Text Diffusion Model for EN→DE Translation

A **masked discrete diffusion** model for English-to-German machine translation,
trained from scratch on WMT14 EN-DE.

## Architecture
- **Type**: Masked Discrete Diffusion (inspired by MDLM + LLaDA)
- **Backbone**: DiT (Diffusion Transformer) with adaptive LayerNorm (adaLN)
- **Parameters**: ~72M
- **Blocks**: {config.n_blocks} DiT blocks, hidden_dim={config.hidden_dim}, heads={config.n_heads}
- **Tokenizer**: {TOKENIZER_NAME} (~58K vocab)
- **Max sequence**: {config.max_src_len} src + {config.max_tgt_len} tgt tokens

## Training
- **Dataset**: WMT14 EN-DE (~4.5M pairs)
- **Method**: Masked discrete diffusion with ELBO weighting (1/t)
- **Optimizer**: AdamW, lr=3e-4, cosine schedule with 4K warmup
- **Effective batch size**: {TRAIN_CONFIG['batch_size'] * TRAIN_CONFIG['gradient_accumulation_steps']}
- **Training steps**: {step:,}

## How It Works
1. Source (EN) + target (DE) tokens concatenated: `[source | target]`
2. Training: target tokens randomly masked with prob `t ~ U(0,1)`, predict masked tokens
3. Inference: start fully masked → iteratively unmask over {TRAIN_CONFIG['num_gen_steps']} steps

## References
- [MDLM](https://arxiv.org/abs/2406.07524) | [LLaDA](https://arxiv.org/abs/2502.09992) | [DiNoiSer](https://arxiv.org/abs/2302.10025)
"""
    with open(os.path.join(save_dir, "README.md"), "w") as f:
        f.write(readme)

    upload_folder(repo_id=HUB_MODEL_ID, folder_path=save_dir,
                  commit_message=f"Checkpoint step {step}")
    print(f"  Pushed to hub: https://huggingface.co/{HUB_MODEL_ID}")
|
|
|
if __name__ == "__main__":
    train()
|
|