"""
Text Diffusion Model for EN→DE Machine Translation
Self-contained training script.
Architecture: Masked Discrete Diffusion with DiT backbone
Inspired by MDLM (kuleshov-group) + LLaDA conditional generation
Dataset: WMT14 EN-DE
Usage:
pip install torch transformers datasets trackio sacrebleu sacremoses sentencepiece protobuf
python train.py
"""
import os
import math
import time
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from dataclasses import dataclass
from datasets import load_dataset
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import trackio
# ═══════════════════════════════════════════════════════════════
# MODEL ARCHITECTURE
# ═══════════════════════════════════════════════════════════════
@dataclass
class DiffusionTranslatorConfig:
    # vocab_size, mask_token_id and pad_token_id below are T5-style
    # placeholders; train() overwrites them from the actual tokenizer.
    vocab_size: int = 32128
max_src_len: int = 128
max_tgt_len: int = 128
hidden_dim: int = 512
n_heads: int = 8
n_blocks: int = 8
dropout: float = 0.1
cond_dim: int = 128
mask_token_id: int = 32100
pad_token_id: int = 0
class Rotary(nn.Module):
def __init__(self, dim, base=10_000):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
def forward(self, x, seq_dim=1):
seq_len = x.shape[seq_dim]
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
return self.cos_cached, self.sin_cached
def rotate_half(x):
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
cos = cos[:q.shape[1], :]
sin = sin[:q.shape[1], :]
cos = cos.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2)
q = (q * cos) + (rotate_half(q) * sin)
k = (k * cos) + (rotate_half(k) * sin)
return q, k
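# Shape note for the rotary helpers above: q and k arrive as (batch, seq,
# heads, head_dim) while Rotary yields cos/sin of shape (seq, head_dim). The
# two unsqueezes broadcast them over batch and heads; rotate_half pairs
# channels (i, i + head_dim/2) and each pair is rotated by a position-dependent
# angle, so q.k ends up depending only on the relative offset between
# positions (RoFormer-style relative encoding).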
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
return self.mlp(t_freq)
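# Worked example: timestep_embedding(torch.tensor([0.5]), 256) returns a
# (1, 256) vector of [cos, sin] features at 128 frequencies geometrically
# spaced between 1 and ~1/max_period (the classic sinusoidal embedding),
# which the MLP then maps to the conditioning width consumed by adaLN.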
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.weight = nn.Parameter(torch.ones([dim]))
self.dim = dim
def forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
x = F.layer_norm(x.float(), [self.dim])
return x * self.weight[None, None, :]
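# Design note: the normalization runs in fp32 even under autocast for
# numerical stability, and it learns only a weight (no bias) because the
# adaLN modulation in each block supplies the per-sample shift and scale.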
class DiTBlock(nn.Module):
"""Diffusion Transformer block with adaptive layer norm (adaLN)."""
def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
super().__init__()
self.n_heads = n_heads
self.head_dim = dim // n_heads
self.norm1 = LayerNorm(dim)
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
self.attn_out = nn.Linear(dim, dim, bias=False)
self.dropout1 = nn.Dropout(dropout)
self.norm2 = LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_ratio * dim, bias=True),
nn.GELU(approximate='tanh'),
nn.Linear(mlp_ratio * dim, dim, bias=True),
)
self.dropout2 = nn.Dropout(dropout)
self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
nn.init.zeros_(self.adaLN_modulation.weight)
nn.init.zeros_(self.adaLN_modulation.bias)
def forward(self, x, rotary_cos_sin, c, attention_mask=None):
batch_size, seq_len, dim = x.shape
mod = self.adaLN_modulation(c)[:, None, :].chunk(6, dim=2)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod
x_skip = x
x_norm = self.norm1(x) * (1 + scale_msa) + shift_msa
q = self.q_proj(x_norm).view(batch_size, seq_len, self.n_heads, self.head_dim)
k = self.k_proj(x_norm).view(batch_size, seq_len, self.n_heads, self.head_dim)
v = self.v_proj(x_norm).view(batch_size, seq_len, self.n_heads, self.head_dim)
cos, sin = rotary_cos_sin
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Bidirectional attention (no causal mask)
attn_output = F.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask,
dropout_p=self.dropout1.p if self.training else 0.0
)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, dim)
x = x_skip + gate_msa * self.dropout1(self.attn_out(attn_output))
x_skip = x
x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
x = x_skip + gate_mlp * self.dropout2(self.mlp(x_norm))
return x
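# adaLN in brief: each block derives six vectors from the conditioning c
# (shift/scale/gate for the attention branch and for the MLP branch). Because
# adaLN_modulation is zero-initialized, all six are zero at step 0, the gates
# suppress both residual branches, and every block starts as the identity
# function (the DiT "adaLN-Zero" initialization).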
class DiffusionTranslator(nn.Module):
"""
Masked Discrete Diffusion model for EN→DE translation.
Input: [source_tokens | target_tokens] where target tokens are partially masked
Bidirectional transformer (DiT blocks with adaLN for timestep conditioning)
"""
def __init__(self, config: DiffusionTranslatorConfig):
super().__init__()
self.config = config
self.vocab_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
self.sigma_map = TimestepEmbedder(config.cond_dim)
self.rotary_emb = Rotary(config.hidden_dim // config.n_heads)
self.segment_embed = nn.Embedding(2, config.hidden_dim)
self.blocks = nn.ModuleList([
DiTBlock(config.hidden_dim, config.n_heads, config.cond_dim, dropout=config.dropout)
for _ in range(config.n_blocks)
])
self.final_norm = LayerNorm(config.hidden_dim)
self.final_adaLN = nn.Linear(config.cond_dim, 2 * config.hidden_dim, bias=True)
nn.init.zeros_(self.final_adaLN.weight)
nn.init.zeros_(self.final_adaLN.bias)
self.output_proj = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
self.output_proj.weight = self.vocab_embed.weight # Weight tying
self._init_weights()
    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # The loop above would clobber the zero init of the adaLN projections,
        # which DiT relies on so each block starts as the identity; restore it.
        for block in self.blocks:
            nn.init.zeros_(block.adaLN_modulation.weight)
            nn.init.zeros_(block.adaLN_modulation.bias)
        nn.init.zeros_(self.final_adaLN.weight)
        nn.init.zeros_(self.final_adaLN.bias)
def forward(self, input_ids, segment_ids, timesteps):
x = self.vocab_embed(input_ids) + self.segment_embed(segment_ids)
c = F.silu(self.sigma_map(timesteps))
        rotary_cos_sin = self.rotary_emb(x)
        # No attention mask is passed here: attention is fully bidirectional
        # and pad tokens stay visible; supervision is restricted to masked
        # target positions by the loss instead.
        for block in self.blocks:
            x = block(x, rotary_cos_sin, c)
shift, scale = self.final_adaLN(c)[:, None, :].chunk(2, dim=2)
x = self.final_norm(x) * (1 + scale) + shift
logits = self.output_proj(x)
return logits
def count_parameters(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
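# Shape sketch (illustrative; real ids come from the tokenizer in train()):
#   cfg = DiffusionTranslatorConfig()
#   m = DiffusionTranslator(cfg)
#   ids = torch.zeros(2, cfg.max_src_len + cfg.max_tgt_len, dtype=torch.long)
#   m(ids, torch.zeros_like(ids), torch.rand(2))  # -> (2, 256, cfg.vocab_size)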
def compute_diffusion_loss(model, input_ids, segment_ids, target_ids, target_mask, config):
"""Compute masked diffusion training loss (LLaDA-style ELBO)."""
batch_size = input_ids.shape[0]
device = input_ids.device
eps = 1e-5
t = torch.rand(batch_size, device=device) * (1 - eps) + eps
mask_prob = t[:, None].expand_as(target_mask)
random_mask = torch.rand_like(mask_prob) < mask_prob
diffusion_mask = random_mask & target_mask
    if diffusion_mask.sum() == 0:
        # Degenerate batch with nothing masked: skip the forward pass and
        # return a zero that participates in autograd without adding gradients.
        zero = torch.tensor(0.0, device=device, requires_grad=True)
        return zero, zero
    noised_input = input_ids.clone()
    noised_input[diffusion_mask] = config.mask_token_id
    logits = model(noised_input, segment_ids, t)
    logits_flat = logits.view(-1, config.vocab_size)
    targets_flat = target_ids.view(-1)
ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
masked_loss_2d = ce_loss.view(batch_size, -1) * diffusion_mask.float()
per_example_counts = diffusion_mask.float().sum(dim=1).clamp(min=1.0)
per_example_loss = masked_loss_2d.sum(dim=1) / per_example_counts
weighted_loss = (per_example_loss / t).mean()
unweighted_loss = per_example_loss.mean()
return weighted_loss, unweighted_loss
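# A small smoke test for the loss (a sketch; never called during training).
# It builds a random batch where every target position is supervised; for an
# untrained model the CE term should land roughly near log(vocab_size).
def _loss_smoke_test(model, config, device='cpu'):
    src = torch.randint(0, config.vocab_size, (2, config.max_src_len), device=device)
    tgt = torch.randint(0, config.vocab_size, (2, config.max_tgt_len), device=device)
    input_ids = torch.cat([src, tgt], dim=1)
    segment_ids = torch.cat([torch.zeros_like(src), torch.ones_like(tgt)], dim=1)
    target_mask = segment_ids.bool()
    wl, uwl = compute_diffusion_loss(model, input_ids, segment_ids, input_ids, target_mask, config)
    print(f"weighted={wl.item():.3f} ce={uwl.item():.3f} log|V|={math.log(config.vocab_size):.3f}")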
@torch.no_grad()
def generate(model, src_ids, src_segment_ids, config, num_steps=50, device='cuda'):
"""Generate translation using iterative unmasking."""
model.eval()
batch_size = src_ids.shape[0]
tgt_len = config.max_tgt_len
tgt_ids = torch.full((batch_size, tgt_len), config.mask_token_id, device=device)
tgt_segment_ids = torch.ones(batch_size, tgt_len, dtype=torch.long, device=device)
input_ids = torch.cat([src_ids, tgt_ids], dim=1)
segment_ids = torch.cat([src_segment_ids, tgt_segment_ids], dim=1)
src_len = src_ids.shape[1]
for step in range(num_steps, 0, -1):
t = torch.tensor([step / num_steps], device=device).expand(batch_size)
s = torch.tensor([(step - 1) / num_steps], device=device).expand(batch_size)
logits = model(input_ids, segment_ids, t)
tgt_logits = logits[:, src_len:, :]
predicted_tokens = tgt_logits.argmax(dim=-1)
current_tgt = input_ids[:, src_len:]
still_masked = (current_tgt == config.mask_token_id)
if step > 1:
remask_prob = s[0].item() / t[0].item() if t[0].item() > 0 else 0.0
remask = torch.rand_like(predicted_tokens.float()) < remask_prob
new_tgt = current_tgt.clone()
unmask_positions = still_masked & ~remask
new_tgt[unmask_positions] = predicted_tokens[unmask_positions]
else:
new_tgt = current_tgt.clone()
new_tgt[still_masked] = predicted_tokens[still_masked]
input_ids = torch.cat([src_ids, new_tgt], dim=1)
return input_ids[:, src_len:]
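# Schedule intuition: at step k (t = k/N, s = (k-1)/N) each still-masked
# position stays masked with probability s/t = (k-1)/k, so 1/k of the
# remaining masks are committed per step. In expectation the masked fraction
# after the step equals s, i.e. the reverse process tracks the forward noise
# schedule, and step 1 fills in everything that is left.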
# ═══════════════════════════════════════════════════════════════
# DATASET
# ═══════════════════════════════════════════════════════════════
class WMT14EnDeDataset(Dataset):
def __init__(self, data, tokenizer, max_src_len=128, max_tgt_len=128):
self.data = data
self.tokenizer = tokenizer
self.max_src_len = max_src_len
self.max_tgt_len = max_tgt_len
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
en_text = item['translation']['en']
de_text = item['translation']['de']
src_enc = self.tokenizer(
"translate English to German: " + en_text,
max_length=self.max_src_len, truncation=True,
padding='max_length', return_tensors=None,
)
tgt_enc = self.tokenizer(
de_text,
max_length=self.max_tgt_len, truncation=True,
padding='max_length', return_tensors=None,
)
src_ids = src_enc['input_ids']
tgt_ids = tgt_enc['input_ids']
segment_ids = [0] * len(src_ids) + [1] * len(tgt_ids)
full_ids = src_ids + tgt_ids
target_mask = [0] * len(src_ids) + tgt_enc['attention_mask']
return {
'input_ids': torch.tensor(full_ids, dtype=torch.long),
'segment_ids': torch.tensor(segment_ids, dtype=torch.long),
'target_ids': torch.tensor(full_ids, dtype=torch.long),
'target_mask': torch.tensor(target_mask, dtype=torch.bool),
}
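# Example layout for toy lengths max_src_len = max_tgt_len = 4:
#   input_ids   = [en0, en1, </s>, <pad>, de0, de1, </s>, <pad>]
#   segment_ids = [0,   0,   0,    0,     1,   1,   1,    1   ]
#   target_mask = [0,   0,   0,    0,     1,   1,   1,    0   ]
# Only positions with target_mask = 1 are ever noised or scored; the source
# half is always visible as clean conditioning.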
# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════
MODEL_CONFIG = dict(
vocab_size=None,
max_src_len=128,
max_tgt_len=128,
hidden_dim=512,
n_heads=8,
n_blocks=12,
dropout=0.1,
cond_dim=128,
mask_token_id=None,
pad_token_id=None,
)
TRAIN_CONFIG = dict(
learning_rate=3e-4,
weight_decay=0.01,
warmup_steps=4000,
max_steps=200_000,
batch_size=64,
gradient_accumulation_steps=4,
eval_every=5000,
save_every=10000,
log_every=100,
max_grad_norm=1.0,
num_gen_steps=50,
fp16=True,
seed=42,
)
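# Quick arithmetic on these settings: 64 sequences/micro-batch x 4
# accumulation steps = 256 sequences (about 65K tokens at 256 tokens each)
# per optimizer update, so 200K updates see ~51M sequence pairs, roughly 11
# passes over the ~4.5M-pair WMT14 training set.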
HUB_MODEL_ID = "vedkdev/text-diffusion-en-de"
TOKENIZER_NAME = "Helsinki-NLP/opus-mt-en-de"
# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════
def train():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(TRAIN_CONFIG['seed'])
print(f"Device: {device}")
print(f"Loading tokenizer: {TOKENIZER_NAME}")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
if tokenizer.mask_token is None:
tokenizer.add_special_tokens({'mask_token': '<mask>'})
mask_token_id = tokenizer.mask_token_id
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
print(f"Vocab size: {len(tokenizer)}")
print(f"Mask token ID: {mask_token_id}, Pad token ID: {pad_token_id}")
MODEL_CONFIG['vocab_size'] = len(tokenizer)
MODEL_CONFIG['mask_token_id'] = mask_token_id
MODEL_CONFIG['pad_token_id'] = pad_token_id
config = DiffusionTranslatorConfig(**MODEL_CONFIG)
model = DiffusionTranslator(config).to(device)
print(f"Model parameters: {model.count_parameters():,}")
print("Loading WMT14 EN-DE dataset...")
dataset = load_dataset("wmt/wmt14", "de-en", trust_remote_code=True)
train_data = dataset['train']
val_data = dataset['validation']
print(f"Train: {len(train_data):,} | Val: {len(val_data):,}")
train_dataset = WMT14EnDeDataset(train_data, tokenizer, config.max_src_len, config.max_tgt_len)
val_dataset = WMT14EnDeDataset(val_data, tokenizer, config.max_src_len, config.max_tgt_len)
train_loader = DataLoader(train_dataset, batch_size=TRAIN_CONFIG['batch_size'],
shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=TRAIN_CONFIG['batch_size'],
shuffle=False, num_workers=2, pin_memory=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=TRAIN_CONFIG['learning_rate'],
weight_decay=TRAIN_CONFIG['weight_decay'], betas=(0.9, 0.98), eps=1e-8)
scheduler = get_cosine_schedule_with_warmup(optimizer,
num_warmup_steps=TRAIN_CONFIG['warmup_steps'],
num_training_steps=TRAIN_CONFIG['max_steps'])
scaler = torch.amp.GradScaler('cuda') if (TRAIN_CONFIG['fp16'] and device.type == 'cuda') else None
trackio.init(project="text-diffusion-en-de", name="v1-wmt14-dit12-512d")
global_step = 0
best_val_loss = float('inf')
accum_loss = 0.0
accum_loss_uw = 0.0
accum_count = 0
eff_bs = TRAIN_CONFIG['batch_size'] * TRAIN_CONFIG['gradient_accumulation_steps']
print(f"\n=== Starting Training ===")
print(f"Effective batch size: {eff_bs} | Max steps: {TRAIN_CONFIG['max_steps']:,}")
print(f"Warmup: {TRAIN_CONFIG['warmup_steps']:,} | LR: {TRAIN_CONFIG['learning_rate']}")
model.train()
optimizer.zero_grad()
data_iter = iter(train_loader)
start_time = time.time()
total_micro_steps = TRAIN_CONFIG['max_steps'] * TRAIN_CONFIG['gradient_accumulation_steps']
for step in range(1, total_micro_steps + 1):
try:
batch = next(data_iter)
except StopIteration:
data_iter = iter(train_loader)
batch = next(data_iter)
input_ids = batch['input_ids'].to(device)
segment_ids = batch['segment_ids'].to(device)
target_ids = batch['target_ids'].to(device)
target_mask = batch['target_mask'].to(device)
if scaler is not None:
with torch.amp.autocast('cuda'):
wl, uwl = compute_diffusion_loss(model, input_ids, segment_ids, target_ids, target_mask, config)
loss = wl / TRAIN_CONFIG['gradient_accumulation_steps']
scaler.scale(loss).backward()
else:
wl, uwl = compute_diffusion_loss(model, input_ids, segment_ids, target_ids, target_mask, config)
loss = wl / TRAIN_CONFIG['gradient_accumulation_steps']
loss.backward()
accum_loss += wl.item()
accum_loss_uw += uwl.item()
accum_count += 1
if step % TRAIN_CONFIG['gradient_accumulation_steps'] == 0:
if scaler is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CONFIG['max_grad_norm'])
scaler.step(optimizer)
scaler.update()
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CONFIG['max_grad_norm'])
optimizer.step()
scheduler.step()
optimizer.zero_grad()
global_step += 1
# Log
if global_step % TRAIN_CONFIG['log_every'] == 0:
avg_l = accum_loss / accum_count
avg_uw = accum_loss_uw / accum_count
elapsed = time.time() - start_time
sps = global_step / elapsed
lr = scheduler.get_last_lr()[0]
print(f"step={global_step} | loss={avg_l:.4f} | ce_loss={avg_uw:.4f} | lr={lr:.2e} | steps/s={sps:.2f}")
trackio.log({"train/loss_weighted": avg_l, "train/loss_ce": avg_uw,
"train/learning_rate": lr, "train/steps_per_sec": sps}, step=global_step)
accum_loss = 0.0
accum_loss_uw = 0.0
accum_count = 0
# Eval
if global_step % TRAIN_CONFIG['eval_every'] == 0:
vl, vuw = evaluate(model, val_loader, config, device, scaler is not None)
print(f" [EVAL] step={global_step} | val_loss={vl:.4f} | val_ce={vuw:.4f}")
trackio.log({"eval/loss_weighted": vl, "eval/loss_ce": vuw}, step=global_step)
if global_step % (TRAIN_CONFIG['eval_every'] * 4) == 0:
bleu = evaluate_bleu(model, tokenizer, config, device, num_samples=100,
num_steps=TRAIN_CONFIG['num_gen_steps'])
trackio.log({"eval/sacrebleu": bleu}, step=global_step)
if vuw < best_val_loss:
best_val_loss = vuw
save_model(model, config, tokenizer, global_step, is_best=True)
model.train()
# Save + push
if global_step % TRAIN_CONFIG['save_every'] == 0:
save_model(model, config, tokenizer, global_step, push_to_hub=True)
# Final
save_model(model, config, tokenizer, global_step, push_to_hub=True)
print("\n=== Final BLEU Evaluation ===")
bleu = evaluate_bleu(model, tokenizer, config, device, num_samples=200,
num_steps=TRAIN_CONFIG['num_gen_steps'])
trackio.log({"eval/final_sacrebleu": bleu}, step=global_step)
print(f"\n=== Training Complete === Final BLEU: {bleu:.2f}")
def evaluate(model, val_loader, config, device, use_fp16=True):
model.eval()
total_l = total_uw = 0.0
count = 0
with torch.no_grad():
for i, batch in enumerate(val_loader):
if i >= 50:
break
ids = batch['input_ids'].to(device)
seg = batch['segment_ids'].to(device)
tgt = batch['target_ids'].to(device)
mask = batch['target_mask'].to(device)
if use_fp16:
with torch.amp.autocast('cuda'):
wl, uwl = compute_diffusion_loss(model, ids, seg, tgt, mask, config)
else:
wl, uwl = compute_diffusion_loss(model, ids, seg, tgt, mask, config)
total_l += wl.item()
total_uw += uwl.item()
count += 1
return total_l / max(count, 1), total_uw / max(count, 1)
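# Note: compute_diffusion_loss draws a fresh t ~ U(eps, 1) per batch, so this
# validation loss is a Monte Carlo estimate; the 50-batch cap keeps it cheap
# at the cost of some variance between evaluations.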
def evaluate_bleu(model, tokenizer, config, device, num_samples=100, num_steps=50):
import sacrebleu
model.eval()
ds = load_dataset("wmt/wmt14", "de-en", split="test", trust_remote_code=True)
refs, hyps = [], []
for i in range(min(num_samples, len(ds))):
en = ds[i]['translation']['en']
de_ref = ds[i]['translation']['de']
enc = tokenizer("translate English to German: " + en, max_length=config.max_src_len,
truncation=True, padding='max_length', return_tensors='pt')
src_ids = enc['input_ids'].to(device)
src_seg = torch.zeros_like(src_ids)
with torch.no_grad():
if device.type == 'cuda':
with torch.amp.autocast('cuda'):
gen = generate(model, src_ids, src_seg, config, num_steps=num_steps, device=device)
else:
gen = generate(model, src_ids, src_seg, config, num_steps=num_steps, device=device)
hyp = tokenizer.decode(gen[0], skip_special_tokens=True)
refs.append(de_ref)
hyps.append(hyp)
if i < 5:
print(f" EN: {en[:100]}")
print(f" REF: {de_ref[:100]}")
print(f" GEN: {hyp[:100]}")
print()
bleu = sacrebleu.corpus_bleu(hyps, [refs])
print(f"SacreBLEU: {bleu.score:.2f}")
return bleu.score
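# Note: generation above runs one sentence at a time for simplicity, and the
# test split is reloaded on every call; batching src_ids and caching the
# dataset would make this evaluation substantially faster.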
def save_model(model, config, tokenizer, step, is_best=False, push_to_hub=False):
save_dir = "checkpoints/best" if is_best else f"checkpoints/step-{step}"
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
config_dict = {k: getattr(config, k) for k in [
'vocab_size', 'max_src_len', 'max_tgt_len', 'hidden_dim',
'n_heads', 'n_blocks', 'dropout', 'cond_dim', 'mask_token_id', 'pad_token_id'
]}
with open(os.path.join(save_dir, "config.json"), "w") as f:
json.dump(config_dict, f, indent=2)
tokenizer.save_pretrained(save_dir)
if push_to_hub:
push_model_to_hub(save_dir, step, config)
print(f" Saved checkpoint to {save_dir}")
def push_model_to_hub(save_dir, step, config):
from huggingface_hub import HfApi, upload_folder
api = HfApi()
try:
api.create_repo(HUB_MODEL_ID, exist_ok=True, private=False)
except Exception as e:
print(f" Warning creating repo: {e}")
readme = f"""---
tags:
- text-diffusion
- machine-translation
- en-de
- masked-diffusion
language:
- en
- de
datasets:
- wmt/wmt14
---
# Text Diffusion Model for EN→DE Translation
A **masked discrete diffusion** model for English-to-German machine translation,
trained from scratch on WMT14 EN-DE.
## Architecture
- **Type**: Masked Discrete Diffusion (inspired by MDLM + LLaDA)
- **Backbone**: DiT (Diffusion Transformer) with adaptive LayerNorm (adaLN)
- **Parameters**: ~72M
- **Blocks**: {config.n_blocks} DiT blocks, hidden_dim={config.hidden_dim}, heads={config.n_heads}
- **Tokenizer**: {TOKENIZER_NAME} (~58K vocab)
- **Max sequence**: {config.max_src_len} src + {config.max_tgt_len} tgt tokens
## Training
- **Dataset**: WMT14 EN-DE (~4.5M pairs)
- **Method**: Masked discrete diffusion with ELBO weighting (1/t)
- **Optimizer**: AdamW, lr=3e-4, cosine with 4K warmup
- **Effective batch size**: {TRAIN_CONFIG['batch_size'] * TRAIN_CONFIG['gradient_accumulation_steps']}
- **Training steps**: {step:,}
## How It Works
1. Source (EN) + target (DE) tokens concatenated: `[source | target]`
2. Training: target tokens randomly masked with prob `t ~ U(0,1)`, predict masked tokens
3. Inference: start fully masked → iteratively unmask over {TRAIN_CONFIG['num_gen_steps']} steps
## References
- [MDLM](https://arxiv.org/abs/2406.07524) | [LLaDA](https://arxiv.org/abs/2502.09992) | [DiNoiSer](https://arxiv.org/abs/2302.10025)
"""
with open(os.path.join(save_dir, "README.md"), "w") as f:
f.write(readme)
upload_folder(repo_id=HUB_MODEL_ID, folder_path=save_dir,
commit_message=f"Checkpoint step {step}")
print(f" Pushed to hub: https://huggingface.co/{HUB_MODEL_ID}")
if __name__ == "__main__":
train()