| | """Common training utilities for BitTransformer models.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import Callable, Dict, List, Optional |
| | import contextlib |
| | import sys |
| | import warnings |
| | import math |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| |
|
| | from .compression import compress_bits, pack_bits, unpack_bits |
| | from .optimization import configure_optimizer |
| | from .model import BitTransformerLM |
| | from .utils import set_dropout |
| | from .torch_utils import cpu_autocast |
| |
|
| |
|
def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
    """Cosine ramp from ``start`` to ``end`` over ``total_steps``."""
    if total_steps <= 0 or step >= total_steps:
        return end
    cos_inner = math.pi * step / total_steps
    return start + (end - start) * (1 - math.cos(cos_inner)) / 2


def train_loop(
    model: BitTransformerLM,
    data: torch.Tensor,
    *,
    epochs: int = 1,
    extra_steps: int = 0,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    batch_size: int = 8,
    num_workers: int = 0,
    accum_steps: int = 1,
    amp: bool = False,
    compile_model: bool = False,
    log: bool = False,
    forward_kwargs: Optional[Dict] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    diffusion: bool = False,
    noise_fn: Optional[Callable[[], float]] = None,
    diffusion_curriculum: bool = False,
    compress_warmup: int = 0,
) -> List[Dict[str, float]]:
| | """Generic training loop supporting optional compression and diffusion. |
| | |
| | ``compress_prob`` controls the fraction of batches that are run through |
| | ``forward_compressed``. ``direct_prob`` instead feeds the model with the |
| | bit-packed result of ``compress_bits`` after converting back to a bit |
| | tensor. When enabled, metrics for direct-compressed batches are tracked |
| | separately. |
| | |
| | When ``diffusion`` is ``True`` the loop performs denoising training. Batches |
| | are noised by randomly flipping bits with a probability given by |
| | ``noise_fn`` (defaulting to a uniform draw in ``[0, 0.5]``). When |
| | ``diffusion_curriculum`` is ``True`` the noise probability decreases |
| | linearly from ``0.5`` to ``0.0`` over the training epochs. The model is |
| | then trained to recover the clean sequence using full-context attention |
| | (``causal=False``). |
| | |
| | Existing ``optimizer`` and ``scheduler`` instances may be supplied to allow |
| | integration with long-running training sessions, otherwise new ones are |
| | created automatically. |
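
    Example (illustrative sketch only; the ``BitTransformerLM`` constructor
    arguments are project-specific and omitted here)::

        bits = torch.randint(0, 2, (256, 64))  # 256 sequences of 64 bits
        model = BitTransformerLM(...)           # construct with your configuration
        metrics = train_loop(model, bits, epochs=2, batch_size=16, log=True)
        final_loss = metrics[-1]["raw_loss"]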
| | """ |
| | if compile_model and sys.version_info < (3, 12) and torch.__version__ >= "2.1": |
| | model = torch.compile(model) |
| | elif compile_model: |
| | warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12") |
| |
|
    model.train()
    set_dropout(model, 0.1)

    device = next(model.parameters()).device
    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )
    steps_per_epoch = max(1, len(loader))
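    # One optimizer update is taken every ``accum_steps`` micro-batches, so the
    # scheduler horizon is ceil(epochs * (steps_per_epoch + extra_steps) / accum_steps).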
    total_updates = math.ceil(epochs * (steps_per_epoch + extra_steps) / accum_steps)
    if optimizer is None or scheduler is None:
        optimizer, scheduler = configure_optimizer(
            model, lr=1e-3, total_steps=total_updates
        )
    metrics: List[Dict[str, float]] = []

    global_step = 0
    for epoch in range(epochs):
        raw_losses: List[float] = []
        raw_accs: List[float] = []
        comp_losses: List[float] = []
        comp_accs: List[float] = []
        comp_ratios: List[float] = []
        direct_losses: List[float] = []

        last_batch = None
        for step, batch in enumerate(loader):
            batch = batch.to(device)
            # Keep a device-resident reference to the most recent batch for the
            # optional extra steps at the end of the epoch.
            last_batch = batch
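            # Ramp the compression probability from 0 to ``compress_prob`` over the
            # first ``compress_warmup`` steps; diffusion runs skip the ramp.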
            cur_compress = (
                cosine_ramp(global_step, 0.0, compress_prob, compress_warmup)
                if not diffusion
                else compress_prob
            )
            if diffusion:
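                # Noise schedule: either a linear curriculum from 0.5 down to 0.0
                # across epochs, or a per-batch draw from ``noise_fn`` (by default
                # uniform in [0, 0.5]).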
                if diffusion_curriculum:
                    p = 0.5 * (1 - epoch / max(1, epochs - 1))
                else:
                    p = noise_fn() if noise_fn is not None else float(torch.rand(()) * 0.5)
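                # Denoising objective: flip each bit independently with probability ``p``
                # (XOR with a Bernoulli mask) and train the model to reconstruct the
                # clean sequence.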
                noise = (torch.rand_like(batch.float()) < p).long()
                noisy = batch ^ noise
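                # Mixed precision: bfloat16 autocast on CUDA, a CPU autocast shim when
                # CUDA is unavailable, or a no-op context when ``amp`` is disabled.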
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model(noisy, causal=False)
                pred = logits.reshape(-1, 2)
                target = batch.reshape(-1)
                loss = F.cross_entropy(pred, target) / accum_steps
                acc = (pred.argmax(dim=-1) == target).float().mean().item()
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                loss.backward()
                if (step + 1) % accum_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                global_step += 1
                continue

            r = torch.rand(())
            key = "raw"
            ratio = 1.0
            target = batch[:, 1:].reshape(-1)
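
            # Choose one of three modes for this batch based on ``r``:
            #   r < direct_prob                -> "direct": bit-packed/unpacked input
            #   r < direct_prob + cur_compress -> "compressed": via forward_compressed
            #   otherwise                      -> "raw": plain bit sequences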
            if r < direct_prob:
                packed = [pack_bits(row.to(torch.uint8)) for row in batch]
                unpacked = [unpack_bits(p, n_bits=batch.size(1)) for p in packed]
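                # Truncate or pad every unpacked row to a common length, capped by the
                # positional-encoding table so the model never sees positions it cannot embed.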
                max_len = min(
                    max(u.numel() for u in unpacked),
                    model.pos_enc.pe.size(0),
                )
                padded = [
                    F.pad(u[:max_len], (0, max_len - min(u.numel(), max_len)))
                    for u in unpacked
                ]
                dc_batch = torch.stack(padded).long()
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model(dc_batch, **(forward_kwargs or {}))
                ratio = sum(p.numel() for p in packed) / batch.numel()
                target = dc_batch[:, 1:].reshape(-1)
                key = "direct"
            elif r < direct_prob + cur_compress:
                comp_batch = [compress_bits(row.to(torch.uint8)) for row in batch]
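                # ``forward_compressed`` consumes the per-row compressed sequences;
                # the loss is still computed against the uncompressed next-bit targets.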
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model.forward_compressed(comp_batch, **(forward_kwargs or {}))
                ratio = sum(c.numel() for c in comp_batch) / batch.numel()
                target = batch[:, 1:].reshape(-1)
                key = "compressed"
            else:
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model(batch, **(forward_kwargs or {}))
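            # Next-bit prediction: logits at position t score the bit at t + 1, so the
            # last position is dropped from the predictions and the first from the targets.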
            pred = logits[:, :-1, :].reshape(-1, 2)
            loss = F.cross_entropy(pred, target) / accum_steps
            acc = (pred.argmax(dim=-1) == target).float().mean().item()
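            # Gradient accumulation: each micro-batch contributes ``loss / accum_steps``;
            # clipping and the optimizer/scheduler step run once every ``accum_steps``
            # micro-batches.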
            loss.backward()
            if (step + 1) % accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            global_step += 1

            if key == "compressed":
                comp_losses.append(loss.item() * accum_steps)
                comp_accs.append(acc)
                comp_ratios.append(ratio)
            elif key == "direct":
                direct_losses.append(loss.item() * accum_steps)
                comp_ratios.append(ratio)
            else:
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
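        # Optional extra optimization steps replay the final batch of the epoch;
        # they are skipped in diffusion mode.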
        if extra_steps > 0 and last_batch is not None and not diffusion:
            for step in range(extra_steps):
                with (
                    torch.cuda.amp.autocast(dtype=torch.bfloat16)
                    if amp and torch.cuda.is_available()
                    else cpu_autocast() if amp else contextlib.nullcontext()
                ):
                    logits, _ = model(last_batch, **(forward_kwargs or {}))
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = last_batch[:, 1:].reshape(-1)
                loss = F.cross_entropy(pred, target) / accum_steps
                acc = (pred.argmax(dim=-1) == target).float().mean().item()
                loss.backward()
                if (step + 1) % accum_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                global_step += 1
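        # Per-epoch averages; categories that saw no batches report 0.0.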
        m = {
            "raw_loss": float(sum(raw_losses) / len(raw_losses)) if raw_losses else 0.0,
            "raw_acc": float(sum(raw_accs) / len(raw_accs)) if raw_accs else 0.0,
            "compressed_loss": float(sum(comp_losses) / len(comp_losses)) if comp_losses else 0.0,
            "compressed_acc": float(sum(comp_accs) / len(comp_accs)) if comp_accs else 0.0,
            "direct_loss": float(sum(direct_losses) / len(direct_losses)) if direct_losses else 0.0,
            "compression_ratio": float(sum(comp_ratios) / len(comp_ratios)) if comp_ratios else 0.0,
        }
        metrics.append(m)

        if log:
            print(
                f"Epoch {epoch} "
                f"raw_loss={m['raw_loss']:.4f} acc={m['raw_acc']:.3f} | "
                f"compressed_loss={m['compressed_loss']:.4f} acc={m['compressed_acc']:.3f} "
                f"direct_loss={m['direct_loss']:.4f} ratio={m['compression_ratio']:.2f}"
            )

    return metrics


__all__ = ["train_loop"]