""" Text Diffusion Model for EN→DE Machine Translation Self-contained training script. Architecture: Masked Discrete Diffusion with DiT backbone Inspired by MDLM (kuleshov-group) + LLaDA conditional generation Dataset: WMT14 EN-DE Usage: pip install torch transformers datasets trackio sacrebleu sacremoses sentencepiece protobuf python train.py """ import os import math import typing import time import json import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from dataclasses import dataclass from datasets import load_dataset from transformers import AutoTokenizer, get_cosine_schedule_with_warmup import trackio # ═══════════════════════════════════════════════════════════════ # MODEL ARCHITECTURE # ═══════════════════════════════════════════════════════════════ @dataclass class DiffusionTranslatorConfig: vocab_size: int = 32128 max_src_len: int = 128 max_tgt_len: int = 128 hidden_dim: int = 512 n_heads: int = 8 n_blocks: int = 8 dropout: float = 0.1 cond_dim: int = 128 mask_token_id: int = 32100 pad_token_id: int = 0 class Rotary(nn.Module): def __init__(self, dim, base=10_000): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) self.seq_len_cached = None self.cos_cached = None self.sin_cached = None def forward(self, x, seq_dim=1): seq_len = x.shape[seq_dim] if seq_len != self.seq_len_cached: self.seq_len_cached = seq_len t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(x.device) self.cos_cached = emb.cos() self.sin_cached = emb.sin() return self.cos_cached, self.sin_cached def rotate_half(x): x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): cos = cos[:q.shape[1], :] sin = sin[:q.shape[1], :] cos = cos.unsqueeze(0).unsqueeze(2) sin = sin.unsqueeze(0).unsqueeze(2) q = (q * cos) + (rotate_half(q) * sin) k = (k * cos) + (rotate_half(k) * sin) return q, k class TimestepEmbedder(nn.Module): def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True), ) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half ) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) return self.mlp(t_freq) class LayerNorm(nn.Module): def __init__(self, dim): super().__init__() self.weight = nn.Parameter(torch.ones([dim])) self.dim = dim def forward(self, x): with torch.amp.autocast('cuda', enabled=False): x = F.layer_norm(x.float(), [self.dim]) return x * self.weight[None, None, :] class DiTBlock(nn.Module): """Diffusion Transformer block with adaptive layer norm (adaLN).""" def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1): super().__init__() self.n_heads = n_heads self.head_dim = dim // n_heads self.norm1 = LayerNorm(dim) self.q_proj = nn.Linear(dim, dim, 

class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(0, half, dtype=torch.float32, device=t.device)
            / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones([dim]))
        self.dim = dim

    def forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            x = F.layer_norm(x.float(), [self.dim])
        return x * self.weight[None, None, :]


class DiTBlock(nn.Module):
    """Diffusion Transformer block with adaptive layer norm (adaLN)."""

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        self.norm1 = LayerNorm(dim)
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        self.norm2 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim, bias=True),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * dim, dim, bias=True),
        )
        self.dropout2 = nn.Dropout(dropout)

        self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def forward(self, x, rotary_cos_sin, c, attention_mask=None):
        batch_size, seq_len, dim = x.shape

        mod = self.adaLN_modulation(c)[:, None, :].chunk(6, dim=2)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod

        x_skip = x
        x_norm = self.norm1(x) * (1 + scale_msa) + shift_msa

        q = self.q_proj(x_norm).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x_norm).view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = self.v_proj(x_norm).view(batch_size, seq_len, self.n_heads, self.head_dim)

        cos, sin = rotary_cos_sin
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Bidirectional attention (no causal mask)
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=self.dropout1.p if self.training else 0.0
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, dim)
        x = x_skip + gate_msa * self.dropout1(self.attn_out(attn_output))

        x_skip = x
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x_skip + gate_mlp * self.dropout2(self.mlp(x_norm))
        return x


class DiffusionTranslator(nn.Module):
    """
    Masked Discrete Diffusion model for EN→DE translation.

    Input: [source_tokens | target_tokens] where target tokens are partially masked
    Bidirectional transformer (DiT blocks with adaLN for timestep conditioning)
    """

    def __init__(self, config: DiffusionTranslatorConfig):
        super().__init__()
        self.config = config

        self.vocab_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.sigma_map = TimestepEmbedder(config.cond_dim)
        self.rotary_emb = Rotary(config.hidden_dim // config.n_heads)
        self.segment_embed = nn.Embedding(2, config.hidden_dim)

        self.blocks = nn.ModuleList([
            DiTBlock(config.hidden_dim, config.n_heads, config.cond_dim, dropout=config.dropout)
            for _ in range(config.n_blocks)
        ])

        self.final_norm = LayerNorm(config.hidden_dim)
        self.final_adaLN = nn.Linear(config.cond_dim, 2 * config.hidden_dim, bias=True)
        nn.init.zeros_(self.final_adaLN.weight)
        nn.init.zeros_(self.final_adaLN.bias)

        self.output_proj = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        self.output_proj.weight = self.vocab_embed.weight  # Weight tying

        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, segment_ids, timesteps):
        x = self.vocab_embed(input_ids) + self.segment_embed(segment_ids)
        c = F.silu(self.sigma_map(timesteps))
        rotary_cos_sin = self.rotary_emb(x)

        for block in self.blocks:
            x = block(x, rotary_cos_sin, c)

        shift, scale = self.final_adaLN(c)[:, None, :].chunk(2, dim=2)
        x = self.final_norm(x) * (1 + scale) + shift
        logits = self.output_proj(x)
        return logits

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

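
# Minimal smoke test for the backbone (a hedged sketch, never invoked by train()):
# builds a deliberately tiny config and checks that a forward pass over a
# [source | target] concatenation returns logits of shape (batch, seq_len, vocab).
def _smoke_test_translator():
    cfg = DiffusionTranslatorConfig(
        vocab_size=1000, max_src_len=16, max_tgt_len=16,
        hidden_dim=64, n_heads=4, n_blocks=2, cond_dim=32,
        mask_token_id=999, pad_token_id=0,
    )
    model = DiffusionTranslator(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (2, cfg.max_src_len + cfg.max_tgt_len))
    segment_ids = torch.cat([torch.zeros(2, cfg.max_src_len, dtype=torch.long),
                             torch.ones(2, cfg.max_tgt_len, dtype=torch.long)], dim=1)
    timesteps = torch.rand(2)
    logits = model(input_ids, segment_ids, timesteps)
    assert logits.shape == (2, cfg.max_src_len + cfg.max_tgt_len, cfg.vocab_size)
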

def compute_diffusion_loss(model, input_ids, segment_ids, target_ids, target_mask, config):
    """Compute masked diffusion training loss (LLaDA-style ELBO)."""
    batch_size = input_ids.shape[0]
    device = input_ids.device
    eps = 1e-5

    # Sample a corruption level t ~ U(eps, 1] per example and mask each target
    # position independently with probability t; source positions stay intact.
    t = torch.rand(batch_size, device=device) * (1 - eps) + eps
    mask_prob = t[:, None].expand_as(target_mask)
    random_mask = torch.rand_like(mask_prob) < mask_prob
    diffusion_mask = random_mask & target_mask

    noised_input = input_ids.clone()
    noised_input[diffusion_mask] = config.mask_token_id

    logits = model(noised_input, segment_ids, t)
    logits_flat = logits.view(-1, config.vocab_size)
    targets_flat = target_ids.view(-1)

    if diffusion_mask.sum() == 0:
        zero = torch.tensor(0.0, device=device, requires_grad=True)
        return zero, zero

    ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
    masked_loss_2d = ce_loss.view(batch_size, -1) * diffusion_mask.float()
    per_example_counts = diffusion_mask.float().sum(dim=1).clamp(min=1.0)
    per_example_loss = masked_loss_2d.sum(dim=1) / per_example_counts

    # The 1/t weighting gives the ELBO of the masked diffusion process;
    # the unweighted cross-entropy is returned as well for logging.
    weighted_loss = (per_example_loss / t).mean()
    unweighted_loss = per_example_loss.mean()
    return weighted_loss, unweighted_loss


@torch.no_grad()
def generate(model, src_ids, src_segment_ids, config, num_steps=50, device='cuda'):
    """Generate translation using iterative unmasking."""
    model.eval()
    batch_size = src_ids.shape[0]
    tgt_len = config.max_tgt_len

    tgt_ids = torch.full((batch_size, tgt_len), config.mask_token_id, device=device)
    tgt_segment_ids = torch.ones(batch_size, tgt_len, dtype=torch.long, device=device)

    input_ids = torch.cat([src_ids, tgt_ids], dim=1)
    segment_ids = torch.cat([src_segment_ids, tgt_segment_ids], dim=1)
    src_len = src_ids.shape[1]

    for step in range(num_steps, 0, -1):
        t = torch.tensor([step / num_steps], device=device).expand(batch_size)
        s = torch.tensor([(step - 1) / num_steps], device=device).expand(batch_size)

        logits = model(input_ids, segment_ids, t)
        tgt_logits = logits[:, src_len:, :]
        predicted_tokens = tgt_logits.argmax(dim=-1)

        current_tgt = input_ids[:, src_len:]
        still_masked = (current_tgt == config.mask_token_id)

        if step > 1:
            # Keep a fraction s/t of the still-masked positions masked so that the
            # expected mask ratio follows the t -> s schedule (LLaDA-style remasking).
            remask_prob = s[0].item() / t[0].item() if t[0].item() > 0 else 0.0
            remask = torch.rand_like(predicted_tokens.float()) < remask_prob
            new_tgt = current_tgt.clone()
            unmask_positions = still_masked & ~remask
            new_tgt[unmask_positions] = predicted_tokens[unmask_positions]
        else:
            # Final step: commit predictions for every remaining masked position.
            new_tgt = current_tgt.clone()
            new_tgt[still_masked] = predicted_tokens[still_masked]

        input_ids = torch.cat([src_ids, new_tgt], dim=1)

    return input_ids[:, src_len:]

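
# Illustrative sketch of the forward corruption that compute_diffusion_loss applies
# (toy token ids, no model involved): with t ~ U(eps, 1], every *target* position is
# independently replaced by the mask token with probability t, while the source half
# of the sequence is left untouched.
def _demo_forward_masking(mask_token_id=3, eps=1e-5):
    torch.manual_seed(0)
    input_ids = torch.tensor([[11, 12, 13, 21, 22, 23, 24, 25]])              # [src | tgt]
    target_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.bool)  # 1 = target
    t = torch.rand(1) * (1 - eps) + eps
    noise = torch.rand(input_ids.shape) < t[:, None]
    noised = input_ids.clone()
    noised[noise & target_mask] = mask_token_id
    return t, noised  # source ids unchanged; some target ids replaced by mask_token_id
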

# ═══════════════════════════════════════════════════════════════
# DATASET
# ═══════════════════════════════════════════════════════════════

class WMT14EnDeDataset(Dataset):
    def __init__(self, data, tokenizer, max_src_len=128, max_tgt_len=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_src_len = max_src_len
        self.max_tgt_len = max_tgt_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        en_text = item['translation']['en']
        de_text = item['translation']['de']

        src_enc = self.tokenizer(
            "translate English to German: " + en_text,
            max_length=self.max_src_len,
            truncation=True,
            padding='max_length',
            return_tensors=None,
        )
        tgt_enc = self.tokenizer(
            de_text,
            max_length=self.max_tgt_len,
            truncation=True,
            padding='max_length',
            return_tensors=None,
        )

        src_ids = src_enc['input_ids']
        tgt_ids = tgt_enc['input_ids']

        segment_ids = [0] * len(src_ids) + [1] * len(tgt_ids)
        full_ids = src_ids + tgt_ids
        target_mask = [0] * len(src_ids) + tgt_enc['attention_mask']

        return {
            'input_ids': torch.tensor(full_ids, dtype=torch.long),
            'segment_ids': torch.tensor(segment_ids, dtype=torch.long),
            'target_ids': torch.tensor(full_ids, dtype=torch.long),
            'target_mask': torch.tensor(target_mask, dtype=torch.bool),
        }


# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════

MODEL_CONFIG = dict(
    vocab_size=None,
    max_src_len=128,
    max_tgt_len=128,
    hidden_dim=512,
    n_heads=8,
    n_blocks=12,
    dropout=0.1,
    cond_dim=128,
    mask_token_id=None,
    pad_token_id=None,
)

TRAIN_CONFIG = dict(
    learning_rate=3e-4,
    weight_decay=0.01,
    warmup_steps=4000,
    max_steps=200_000,
    batch_size=64,
    gradient_accumulation_steps=4,
    eval_every=5000,
    save_every=10000,
    log_every=100,
    max_grad_norm=1.0,
    num_gen_steps=50,
    fp16=True,
    seed=42,
)

HUB_MODEL_ID = "vedkdev/text-diffusion-en-de"
TOKENIZER_NAME = "Helsinki-NLP/opus-mt-en-de"

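
# Optional pre-flight check (a hedged sketch; it downloads the Marian tokenizer and
# is not called by train()): verifies that one dataset item concatenates to
# max_src_len + max_tgt_len positions and that target_mask only covers the DE half.
def _inspect_one_example():
    tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    if tok.mask_token is None:
        tok.add_special_tokens({'mask_token': '<mask>'})
    toy = [{'translation': {'en': 'The cat sits on the mat.',
                            'de': 'Die Katze sitzt auf der Matte.'}}]
    ds = WMT14EnDeDataset(toy, tok, max_src_len=128, max_tgt_len=128)
    item = ds[0]
    assert item['input_ids'].shape == (256,)
    assert not item['target_mask'][:128].any()  # source half is never a prediction target
    assert item['target_mask'][128:].sum() > 0  # the real DE tokens are prediction targets
    return item
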

# ═══════════════════════════════════════════════════════════════
# TRAINING
# ═══════════════════════════════════════════════════════════════

def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(TRAIN_CONFIG['seed'])
    print(f"Device: {device}")

    print(f"Loading tokenizer: {TOKENIZER_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    if tokenizer.mask_token is None:
        tokenizer.add_special_tokens({'mask_token': '<mask>'})
    mask_token_id = tokenizer.mask_token_id
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    print(f"Vocab size: {len(tokenizer)}")
    print(f"Mask token ID: {mask_token_id}, Pad token ID: {pad_token_id}")

    MODEL_CONFIG['vocab_size'] = len(tokenizer)
    MODEL_CONFIG['mask_token_id'] = mask_token_id
    MODEL_CONFIG['pad_token_id'] = pad_token_id
    config = DiffusionTranslatorConfig(**MODEL_CONFIG)

    model = DiffusionTranslator(config).to(device)
    print(f"Model parameters: {model.count_parameters():,}")

    print("Loading WMT14 EN-DE dataset...")
    dataset = load_dataset("wmt/wmt14", "de-en", trust_remote_code=True)
    train_data = dataset['train']
    val_data = dataset['validation']
    print(f"Train: {len(train_data):,} | Val: {len(val_data):,}")

    train_dataset = WMT14EnDeDataset(train_data, tokenizer, config.max_src_len, config.max_tgt_len)
    val_dataset = WMT14EnDeDataset(val_data, tokenizer, config.max_src_len, config.max_tgt_len)

    train_loader = DataLoader(train_dataset, batch_size=TRAIN_CONFIG['batch_size'],
                              shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=TRAIN_CONFIG['batch_size'],
                            shuffle=False, num_workers=2, pin_memory=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=TRAIN_CONFIG['learning_rate'],
                                  weight_decay=TRAIN_CONFIG['weight_decay'],
                                  betas=(0.9, 0.98), eps=1e-8)
    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                num_warmup_steps=TRAIN_CONFIG['warmup_steps'],
                                                num_training_steps=TRAIN_CONFIG['max_steps'])
    scaler = torch.amp.GradScaler('cuda') if (TRAIN_CONFIG['fp16'] and device.type == 'cuda') else None

    trackio.init(project="text-diffusion-en-de", name="v1-wmt14-dit12-512d")

    global_step = 0
    best_val_loss = float('inf')
    accum_loss = 0.0
    accum_loss_uw = 0.0
    accum_count = 0
    eff_bs = TRAIN_CONFIG['batch_size'] * TRAIN_CONFIG['gradient_accumulation_steps']

    print(f"\n=== Starting Training ===")
    print(f"Effective batch size: {eff_bs} | Max steps: {TRAIN_CONFIG['max_steps']:,}")
    print(f"Warmup: {TRAIN_CONFIG['warmup_steps']:,} | LR: {TRAIN_CONFIG['learning_rate']}")

    model.train()
    optimizer.zero_grad()
    data_iter = iter(train_loader)
    start_time = time.time()

    total_micro_steps = TRAIN_CONFIG['max_steps'] * TRAIN_CONFIG['gradient_accumulation_steps']
    for step in range(1, total_micro_steps + 1):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            batch = next(data_iter)

        input_ids = batch['input_ids'].to(device)
        segment_ids = batch['segment_ids'].to(device)
        target_ids = batch['target_ids'].to(device)
        target_mask = batch['target_mask'].to(device)

        if scaler is not None:
            with torch.amp.autocast('cuda'):
                wl, uwl = compute_diffusion_loss(model, input_ids, segment_ids,
                                                 target_ids, target_mask, config)
            loss = wl / TRAIN_CONFIG['gradient_accumulation_steps']
            scaler.scale(loss).backward()
        else:
            wl, uwl = compute_diffusion_loss(model, input_ids, segment_ids,
                                             target_ids, target_mask, config)
            loss = wl / TRAIN_CONFIG['gradient_accumulation_steps']
            loss.backward()

        accum_loss += wl.item()
        accum_loss_uw += uwl.item()
        accum_count += 1

        if step % TRAIN_CONFIG['gradient_accumulation_steps'] == 0:
            if scaler is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CONFIG['max_grad_norm'])
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CONFIG['max_grad_norm'])
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            # Log
            if global_step % TRAIN_CONFIG['log_every'] == 0:
                avg_l = accum_loss / accum_count
                avg_uw = accum_loss_uw / accum_count
                elapsed = time.time() - start_time
                sps = global_step / elapsed
                lr = scheduler.get_last_lr()[0]
                print(f"step={global_step} | loss={avg_l:.4f} | ce_loss={avg_uw:.4f} | "
                      f"lr={lr:.2e} | steps/s={sps:.2f}")
                trackio.log({"train/loss_weighted": avg_l,
                             "train/loss_ce": avg_uw,
                             "train/learning_rate": lr,
                             "train/steps_per_sec": sps}, step=global_step)
                accum_loss = 0.0
                accum_loss_uw = 0.0
                accum_count = 0

            # Eval
            if global_step % TRAIN_CONFIG['eval_every'] == 0:
                vl, vuw = evaluate(model, val_loader, config, device, scaler is not None)
                print(f" [EVAL] step={global_step} | val_loss={vl:.4f} | val_ce={vuw:.4f}")
                trackio.log({"eval/loss_weighted": vl, "eval/loss_ce": vuw}, step=global_step)

                if global_step % (TRAIN_CONFIG['eval_every'] * 4) == 0:
                    bleu = evaluate_bleu(model, tokenizer, config, device,
                                         num_samples=100, num_steps=TRAIN_CONFIG['num_gen_steps'])
                    trackio.log({"eval/sacrebleu": bleu}, step=global_step)

                if vuw < best_val_loss:
                    best_val_loss = vuw
                    save_model(model, config, tokenizer, global_step, is_best=True)
                model.train()

            # Save + push
            if global_step % TRAIN_CONFIG['save_every'] == 0:
                save_model(model, config, tokenizer, global_step, push_to_hub=True)

    # Final
    save_model(model, config, tokenizer, global_step, push_to_hub=True)
    print("\n=== Final BLEU Evaluation ===")
    bleu = evaluate_bleu(model, tokenizer, config, device,
                         num_samples=200, num_steps=TRAIN_CONFIG['num_gen_steps'])
    trackio.log({"eval/final_sacrebleu": bleu}, step=global_step)
    print(f"\n=== Training Complete === Final BLEU: {bleu:.2f}")

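
# Optional preview of the LR schedule configured above (a hedged sketch on a dummy
# parameter; it exercises only the optimizer/scheduler, not the model, and is not
# called by train()).
def _preview_lr_schedule(probe_steps=(0, 2000, 4000, 50_000, 100_000, 200_000)):
    dummy = torch.optim.AdamW([nn.Parameter(torch.zeros(1))], lr=TRAIN_CONFIG['learning_rate'])
    sched = get_cosine_schedule_with_warmup(dummy,
                                            num_warmup_steps=TRAIN_CONFIG['warmup_steps'],
                                            num_training_steps=TRAIN_CONFIG['max_steps'])
    lrs = {}
    for step in range(TRAIN_CONFIG['max_steps'] + 1):
        if step in probe_steps:
            lrs[step] = sched.get_last_lr()[0]
        dummy.step()
        sched.step()
    return lrs  # linear warmup to 3e-4 over 4K steps, then cosine decay toward 0
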

def evaluate(model, val_loader, config, device, use_fp16=True):
    model.eval()
    total_l = total_uw = 0.0
    count = 0
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            if i >= 50:
                break
            ids = batch['input_ids'].to(device)
            seg = batch['segment_ids'].to(device)
            tgt = batch['target_ids'].to(device)
            mask = batch['target_mask'].to(device)
            if use_fp16:
                with torch.amp.autocast('cuda'):
                    wl, uwl = compute_diffusion_loss(model, ids, seg, tgt, mask, config)
            else:
                wl, uwl = compute_diffusion_loss(model, ids, seg, tgt, mask, config)
            total_l += wl.item()
            total_uw += uwl.item()
            count += 1
    return total_l / max(count, 1), total_uw / max(count, 1)


def evaluate_bleu(model, tokenizer, config, device, num_samples=100, num_steps=50):
    import sacrebleu
    model.eval()

    ds = load_dataset("wmt/wmt14", "de-en", split="test", trust_remote_code=True)
    refs, hyps = [], []

    for i in range(min(num_samples, len(ds))):
        en = ds[i]['translation']['en']
        de_ref = ds[i]['translation']['de']

        enc = tokenizer("translate English to German: " + en,
                        max_length=config.max_src_len, truncation=True,
                        padding='max_length', return_tensors='pt')
        src_ids = enc['input_ids'].to(device)
        src_seg = torch.zeros_like(src_ids)

        with torch.no_grad():
            if device.type == 'cuda':
                with torch.amp.autocast('cuda'):
                    gen = generate(model, src_ids, src_seg, config, num_steps=num_steps, device=device)
            else:
                gen = generate(model, src_ids, src_seg, config, num_steps=num_steps, device=device)

        hyp = tokenizer.decode(gen[0], skip_special_tokens=True)
        refs.append(de_ref)
        hyps.append(hyp)

        if i < 5:
            print(f" EN: {en[:100]}")
            print(f" REF: {de_ref[:100]}")
            print(f" GEN: {hyp[:100]}")
            print()

    bleu = sacrebleu.corpus_bleu(hyps, [refs])
    print(f"SacreBLEU: {bleu.score:.2f}")
    return bleu.score


def save_model(model, config, tokenizer, step, is_best=False, push_to_hub=False):
    save_dir = "checkpoints/best" if is_best else f"checkpoints/step-{step}"
    os.makedirs(save_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))

    config_dict = {k: getattr(config, k) for k in [
        'vocab_size', 'max_src_len', 'max_tgt_len', 'hidden_dim', 'n_heads',
        'n_blocks', 'dropout', 'cond_dim', 'mask_token_id', 'pad_token_id'
    ]}
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    tokenizer.save_pretrained(save_dir)

    if push_to_hub:
        push_model_to_hub(save_dir, step, config)
    print(f" Saved checkpoint to {save_dir}")

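
# Hedged sketch of how a checkpoint written by save_model() could be restored for
# offline inference; the file names and config keys follow the format used above.
def load_checkpoint(save_dir, device='cpu'):
    with open(os.path.join(save_dir, "config.json")) as f:
        cfg = DiffusionTranslatorConfig(**json.load(f))
    model = DiffusionTranslator(cfg)
    state_dict = torch.load(os.path.join(save_dir, "model.pt"), map_location=device)
    model.load_state_dict(state_dict)
    tok = AutoTokenizer.from_pretrained(save_dir)
    return model.to(device).eval(), tok, cfg
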

def push_model_to_hub(save_dir, step, config):
    from huggingface_hub import HfApi, upload_folder

    api = HfApi()
    try:
        api.create_repo(HUB_MODEL_ID, exist_ok=True, private=False)
    except Exception as e:
        print(f" Warning creating repo: {e}")

    readme = f"""---
tags:
- text-diffusion
- machine-translation
- en-de
- masked-diffusion
language:
- en
- de
datasets:
- wmt/wmt14
---

# Text Diffusion Model for EN→DE Translation

A **masked discrete diffusion** model for English-to-German machine translation, trained from scratch on WMT14 EN-DE.

## Architecture

- **Type**: Masked Discrete Diffusion (inspired by MDLM + LLaDA)
- **Backbone**: DiT (Diffusion Transformer) with adaptive LayerNorm (adaLN)
- **Parameters**: ~72M
- **Blocks**: {config.n_blocks} DiT blocks, hidden_dim={config.hidden_dim}, heads={config.n_heads}
- **Tokenizer**: {TOKENIZER_NAME} (~58K vocab)
- **Max sequence**: {config.max_src_len} src + {config.max_tgt_len} tgt tokens

## Training

- **Dataset**: WMT14 EN-DE (~4.5M pairs)
- **Method**: Masked discrete diffusion with ELBO weighting (1/t)
- **Optimizer**: AdamW, lr=3e-4, cosine with 4K warmup
- **Effective batch size**: {TRAIN_CONFIG['batch_size'] * TRAIN_CONFIG['gradient_accumulation_steps']}
- **Training steps**: {step:,}

## How It Works

1. Source (EN) + target (DE) tokens concatenated: `[source | target]`
2. Training: target tokens randomly masked with prob `t ~ U(0,1)`, predict masked tokens
3. Inference: start fully masked → iteratively unmask over {TRAIN_CONFIG['num_gen_steps']} steps

## References

- [MDLM](https://arxiv.org/abs/2406.07524) | [LLaDA](https://arxiv.org/abs/2502.09992) | [DiNoiSer](https://arxiv.org/abs/2302.10025)
"""
    with open(os.path.join(save_dir, "README.md"), "w") as f:
        f.write(readme)

    upload_folder(repo_id=HUB_MODEL_ID, folder_path=save_dir,
                  commit_message=f"Checkpoint step {step}")
    print(f" Pushed to hub: https://huggingface.co/{HUB_MODEL_ID}")


if __name__ == "__main__":
    train()