import os
import time
import math
from itertools import cycle
from typing import Optional

import torch
import torch.nn.functional as F
from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    quantize_dynamic,
    prepare_qat_fx,
    convert_qat_fx,
    hil_safe_inference,
    collapse_submodel,
    diffusion_inference,
    TelemetrySynthesizer,
    save_distilled_model,
)
from bit_transformer.training import train_loop as train
from bit_transformer.optimization import configure_optimizer, adjust_learning_rate
from bit_transformer.utils import save_model, load_model, set_dropout
from bit_transformer.torch_utils import cpu_autocast
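
# Integration schedule for BitTransformerLM: train a small model on WikiText-2
# bit streams, expand it whenever validation loss plateaus, then quantize the
# result, check it with hil_safe_inference, and optionally collapse it into a
# smaller distilled model.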


def lines_to_tensor(lines, max_len):
    """Encode text lines as bits, truncating or zero-padding each row to ``max_len``."""
    seqs = []
    for text in lines:
        bits = text_to_bits(text)[:max_len]
        if len(bits) < max_len:
            bits.extend([0] * (max_len - len(bits)))
        seqs.append(bits)
    return torch.tensor(seqs, dtype=torch.long)
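
# For example, lines_to_tensor(["hello"], max_len=32) returns a LongTensor of
# shape (1, 32): the string is bit-encoded by text_to_bits, truncated to 32
# bits, and right-padded with zeros if shorter.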


def load_wikitext(dataset_size=128, max_len=64):
    """Load WikiText-2 as bit tensors, falling back to random bits offline."""
    try:
        from datasets import load_dataset

        ds = load_dataset("wikitext", "wikitext-2-raw-v1")
        train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
        valid_split = max(1, dataset_size // 4)
        valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_split]
        train = lines_to_tensor(train_lines, max_len)
        valid = lines_to_tensor(valid_lines, max_len)
        return train, valid, train_lines
    except Exception as e:
        print("Dataset load failed, using random bits", e)
        # Fallback: random bits sized like the real train/validation splits.
        train = torch.randint(0, 2, (dataset_size, max_len), dtype=torch.long)
        valid = torch.randint(
            0, 2, (max(1, dataset_size // 4), max_len), dtype=torch.long
        )
        return train, valid, ["" for _ in range(len(train))]


def _warmup(
    model: BitTransformerLM,
    data: torch.Tensor,
    steps: int = 5,
    freeze_old: bool = False,
    old_layers: int = 0,
    *,
    diffusion: bool = False,
    curriculum: bool = False,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
    """Run a short warm-up loop after expansion."""
    model.train()
    set_dropout(model, 0.1)
    if freeze_old:
        # Freeze the pre-expansion layers so only the new ones adapt.
        for idx, layer in enumerate(model.layers):
            if idx < old_layers:
                for p in layer.parameters():
                    p.requires_grad_(False)
    if optimizer is None or scheduler is None:
        optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=steps)
    it = iter(data.split(8))
    for idx in range(steps):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(data.split(8))
            batch = next(it)
        if diffusion:
            # Denoising objective: flip bits with probability p (annealed from
            # 0.5 toward 0 under the curriculum) and predict the clean bits.
            p = 0.5 * (1 - idx / max(1, steps - 1)) if curriculum else 0.5
            noise = (torch.rand_like(batch.float()) < p).long()
            noisy = batch ^ noise
            logits, _ = model(noisy, causal=False)
            pred = logits.reshape(-1, 2)
            target = batch.reshape(-1)
        else:
            # Next-bit prediction: shift logits and targets by one position.
            logits, _ = model(batch)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = batch[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    # Unfreeze everything and return the model to eval mode.
    for p in model.parameters():
        p.requires_grad_(True)
    model.eval()
    set_dropout(model, 0.0)
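
# Whenever validation loss improves by less than improve_thresh, the schedule
# picks the next expansion strategy from a layers -> width -> context cycle,
# scales the learning rate by 1/sqrt(2), and runs a 100-step warm-up (with the
# pre-expansion layers frozen in the "layers" case).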


def integration_schedule(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    weights_path: str = "weights/model.pt.gz",
    plateau_steps: int = 0,
    collapsed_path: Optional[str] = None,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    improve_thresh: float = 0.01,
    qat: bool = False,
):
    start = time.time()
    train_bits, valid_bits, train_lines = load_wikitext(dataset_size, max_len)

    def fresh_model() -> BitTransformerLM:
        # Build a fresh single-layer model at the current sequence length.
        return BitTransformerLM(
            d_model=32,
            nhead=4,
            num_layers=1,
            dim_feedforward=64,
            max_seq_len=max_len,
            use_act=True,
            act_threshold=0.7,
            reversible=reversible,
            chunk_size=max_len,
            use_autocast=True,
            use_checkpoint=use_checkpoint,
        )

    model = None
    if os.path.exists(weights_path):
        try:
            model = load_model(weights_path)
            print(f"Loaded model from {weights_path}")
        except Exception as e:
            print("Failed to load weights, initializing new model", e)
    if model is None:
        model = fresh_model()
    if qat:
        model = prepare_qat_fx(model)
    results = []
    scale_cycle = cycle(["layers", "width", "context"])
    base_lr = 1e-3
    prev_val_loss: Optional[float] = None
    for step in range(steps):
        model.train()
        set_dropout(model, 0.1)
        opt, sched = configure_optimizer(
            model, lr=base_lr, total_steps=epochs_per_step
        )
        train(
            model,
            train_bits,
            epochs=epochs_per_step,
            extra_steps=extra_steps,
            compress_prob=0.0 if diffusion else 1.0,
            log=True,
            diffusion=diffusion,
            diffusion_curriculum=diffusion_curriculum,
            optimizer=opt,
            scheduler=sched,
        )

        model.eval()
        set_dropout(model, 0.0)
        with torch.no_grad():
            logits, telemetry = model(valid_bits, causal=not diffusion)
            if diffusion:
                pred = logits.reshape(-1, 2)
                target = valid_bits.reshape(-1)
            else:
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
            k = telemetry["negentropy_logits"].mean().item()
            c = telemetry["lz_complexity_logits"].mean().item()
            s = telemetry["symbiosis_score"].mean().item()
        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
        results.append((step, val_loss, k, c, s))

        # Expand the model when improvement over the previous step falls below
        # improve_thresh, cycling layers -> width -> context.
        if prev_val_loss is not None and prev_val_loss - val_loss < improve_thresh:
            strategy = next(scale_cycle)
            base_lr = adjust_learning_rate(opt, 1 / math.sqrt(2))
            if strategy == "layers":
                old_layers = model.num_layers
                model = model.double_layers()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    freeze_old=True,
                    old_layers=old_layers,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )
            elif strategy == "width":
                model = model.double_width()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )
            else:
                # "context": double the window and rebuild the data to match.
                max_len *= 2
                train_bits, valid_bits, train_lines = load_wikitext(
                    dataset_size, max_len
                )
                model = model.double_length()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )

        prev_val_loss = val_loss
        if time.time() - start > 8 * 60:
            print("Time limit reached")
            break
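    # Optional plateau phase: keep fine-tuning at the final model size without
    # triggering any further expansion.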
    for p in range(plateau_steps):
        model.train()
        set_dropout(model, 0.1)
        train(
            model,
            train_bits,
            epochs=epochs_per_step,
            extra_steps=extra_steps,
            compress_prob=0.0 if diffusion else 1.0,
            log=True,
            diffusion=diffusion,
            diffusion_curriculum=diffusion_curriculum,
        )
        model.eval()
        set_dropout(model, 0.0)
        with torch.no_grad():
            logits, telemetry = model(valid_bits, causal=not diffusion)
            if diffusion:
                pred = logits.reshape(-1, 2)
                target = valid_bits.reshape(-1)
            else:
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
            k = telemetry["negentropy_logits"].mean().item()
            c = telemetry["lz_complexity_logits"].mean().item()
            s = telemetry["symbiosis_score"].mean().item()
        idx = steps + p
        print(
            f"Plateau {p} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}"
        )
        results.append((idx, val_loss, k, c, s))
        if time.time() - start > 8 * 60:
            print("Time limit reached")
            break

    # Final evaluation on the validation set.
    model.eval()
    set_dropout(model, 0.0)
    with torch.no_grad():
        logits, telemetry = model(valid_bits, causal=not diffusion)
        if diffusion:
            pred = logits.reshape(-1, 2)
            target = valid_bits.reshape(-1)
        else:
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
        val_loss = F.cross_entropy(pred, target).item()
        k = telemetry["negentropy_logits"].mean().item()
        c = telemetry["lz_complexity_logits"].mean().item()
        s = telemetry["symbiosis_score"].mean().item()

    print(f"Final validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
    results.append((steps + plateau_steps, val_loss, k, c, s))

    save_model(model, weights_path)
    # Quantize the trained model: QAT conversion if enabled, otherwise
    # post-training dynamic quantization.
    input_bits = valid_bits[:1]
    if qat:
        qmodel = convert_qat_fx(model)
    else:
        with cpu_autocast():
            model(input_bits)
        qmodel = quantize_dynamic(model)
    qmodel.eval()
    try:
        hil_safe_inference(
            qmodel,
            input_bits,
            c_floor=0.3,
            s_floor=0.5,
            causal=not diffusion,
            strict=not diffusion,
        )
    except RuntimeError as e:
        print("Safety gate triggered", e)
    collapsed = None
    if collapse:
        # Distil into a smaller submodel trained on representative sequences
        # chosen by telemetry clustering, subject to the metric floors below.
        synth = TelemetrySynthesizer(n_clusters=8)
        reps = synth.cluster_sequences(model, train_bits[:64])
        floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}
        collapsed, metrics = collapse_submodel(
            reps,
            target_params=dict(
                d_model=16,
                nhead=4,
                num_layers=1,
                dim_feedforward=32,
                max_seq_len=max_len,
            ),
            floors=floors,
        )
        collapsed.eval()
        with torch.no_grad():
            logits, _ = collapsed(valid_bits)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
            c_loss = F.cross_entropy(pred, target).item()
        print("Collapsed model validation loss:", c_loss)
        if collapsed_path is not None:
            save_distilled_model(
                collapsed,
                collapsed_path,
                {**metrics, "val_loss": c_loss},
                floors=floors,
            )
    if diffusion:
        sample = diffusion_inference(
            model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
        )
        print("Diffusion sample:", sample[0].tolist())
    return results, collapsed
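
# Example invocations; every keyword shown exists in integration_schedule's
# signature above:
#     integration_schedule(steps=4, plateau_steps=2)                 # quick run
#     integration_schedule(diffusion=True, diffusion_curriculum=True)
#     integration_schedule(qat=True, collapse=False)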

if __name__ == "__main__":
    integration_schedule()