#@title Geometric Autoregressive LM - Full Training with HF Upload + TensorBoard (generated valid Shakespeare)
"""
Prototype LM for geometric simplex structures.

Requires the geometricvocab's SimplexFactory for valid simplex representations,
or the simplex behavior will not learn.

Colab install cell (kept here for reference):
try:
    !pip uninstall -qy geometricvocab
except:
    pass
!pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git

License: MIT
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import math
from itertools import combinations
import time
import os
import json
from tqdm.auto import tqdm
from pathlib import Path

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

# Project-specific dependencies (installed by the pip commands quoted in the
# module docstring); not part of the standard library.
from geovocab2.shapes.factory.simplex_factory import SimplexFactory
from huggingface_hub import HfApi, create_repo, upload_folder
import tiktoken

# ============================================================================
# CONFIG
# ============================================================================

HF_REPO = "AbstractPhil/ksimplex-llm-prototype"  # target HuggingFace model repo
RUN_NAME = f"run_{int(time.time())}"             # unique per-launch run id
CHECKPOINT_DIR = Path(f"./checkpoints/{RUN_NAME}")
TENSORBOARD_DIR = Path(f"./runs/{RUN_NAME}")
# NOTE: directories are created at import time (script-style side effect).
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
TENSORBOARD_DIR.mkdir(parents=True, exist_ok=True)

# ============================================================================
# CAYLEY-MENGER VALIDATOR
# ============================================================================

class CMValidator(nn.Module):
    """Compute pairwise squared edge lengths and the Cayley-Menger squared
    volume of a k-simplex from its vertex coordinates.

    A positive squared volume means the k+1 vertices are affinely
    independent, i.e. the simplex is non-degenerate ("valid").
    """

    def __init__(self, k: int):
        super().__init__()
        self._k = k
        self._nv = k + 1  # a k-simplex has k+1 vertices
        # All unordered vertex pairs; _pi/_pj index the upper triangle of the
        # squared-distance matrix so edges can be gathered in one fancy-index.
        pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(pairs)
        self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))
        self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))
        # Cayley-Menger prefactor: vol^2 = (-1)^(k+1) / (2^k (k!)^2) * det(CM).
        sign = (-1.0) ** (k + 1)
        fact = math.factorial(k)
        self._prefactor = sign / ((2.0 ** k) * (fact ** 2))

    def forward(self, verts):
        """verts: (..., k+1, e) vertex coordinates.

        Returns:
            d2_pairs: (..., npairs) squared edge lengths.
            vol2:     (...,) Cayley-Menger squared volume (sign indicates validity).
        """
        # Squared distances from the Gram matrix:
        # |vi - vj|^2 = |vi|^2 + |vj|^2 - 2<vi, vj>.
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        norms = torch.diagonal(gram, dim1=-2, dim2=-1)
        d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram
        d2_mat = F.relu(d2_mat)  # clamp tiny negatives from float error
        d2_pairs = d2_mat[..., self._pi, self._pj]
        # Bordered Cayley-Menger matrix: zero corner, ones border, D^2 inside.
        shape = d2_mat.shape[:-2]
        V = d2_mat.shape[-1]
        cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)
        cm[..., 0, 1:] = 1.0
        cm[..., 1:, 0] = 1.0
        cm[..., 1:, 1:] = d2_mat
        vol2 = self._prefactor * torch.linalg.det(cm)
        return d2_pairs, vol2

# ============================================================================
# K-SIMPLEX CHANNEL ENCODER
# ============================================================================

class KSimplexChannel(nn.Module):
    """Encode a hidden vector as a deformed regular k-simplex plus per-vertex
    features, gated by the simplex's geometric validity."""

    # Scale of the learned deformation applied to the regular template simplex.
    BASE_DEFORM = 0.05

    def __init__(self, k: int, in_dim: int, edim: int, feat_dim: int):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._feat_dim = feat_dim
        self._cm = CMValidator(k)
        self._geo_dim = self._cm._npairs + 1  # all edge lengths + squared volume
        # Regular k-simplex template; learned coords deform around it.
        factory = SimplexFactory(k=k, embed_dim=edim, method="regular", scale=1.0)
        self.register_buffer('_template', factory.build_torch(dtype=torch.float32))
        self._to_coords = nn.Linear(in_dim, self._nv * edim)
        self._to_feats = nn.Linear(in_dim, self._nv * feat_dim)
        self._geo_gate = nn.Sequential(
            nn.Linear(self._geo_dim, feat_dim),
            nn.Sigmoid(),
        )
        self._out_dim = feat_dim + self._geo_dim

    @property
    def out_dim(self) -> int:
        # Width of the concatenated (gated features + geometry) output.
        return self._out_dim

    def forward(self, x):
        """x: (..., in_dim) hidden states.

        Returns:
            out:  (..., feat_dim + geo_dim) gated features + raw geometry.
            vol2: (...,) squared simplex volume (validity signal).
            d2:   (...,) mean squared edge length.
        """
        coords = self._to_coords(x).unflatten(-1, (self._nv, self._edim))
        verts = self._template + self.BASE_DEFORM * coords
        vert_feats = self._to_feats(x).unflatten(-1, (self._nv, self._feat_dim))
        d2, vol2 = self._cm(verts)
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)
        gate = self._geo_gate(geo)
        # Steep sigmoid ~ step function: ~1 when vol2 > 0, ~0 otherwise.
        # NOTE(review): the 1e6 scale saturates the sigmoid almost everywhere,
        # so gradient through `validity` is near-zero — likely intentional
        # (hard gate), but worth confirming.
        validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)
        feat_agg = vert_feats.mean(dim=-2) * gate * validity
        out = torch.cat([feat_agg, geo], dim=-1)
        return out, vol2, d2.mean(dim=-1)

# ============================================================================
# TOKEN TO K-SIMPLEX CHANNELS
# ============================================================================

class TokenToKChannels(nn.Module):
    """Project token embeddings into `depth` parallel k-simplex channels
    (k = 1..depth), right-padding each channel's output to a common width.

    Returns the channels stacked on a new axis plus per-channel squared
    volumes and mean squared edge lengths for monitoring/auxiliary losses.
    """

    def __init__(self, embed_dim, depth, edim, feat_dim, hidden=256):
        super().__init__()
        self._depth = depth
        self._proj = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
        )
        # One encoder per simplex order: k = 1 (edge) .. depth.
        self._k_encoders = nn.ModuleList([
            KSimplexChannel(k=k+1, in_dim=hidden, edim=edim, feat_dim=feat_dim)
            for k in range(depth)
        ])
        # Channels emit different widths (geo_dim grows with k); pad to max.
        self._k_out_dims = [enc.out_dim for enc in self._k_encoders]
        self._max_out_dim = max(self._k_out_dims)

    def forward(self, x):
        """x: (B, T, embed_dim) -> (channels (B, T, depth, max_out_dim),
        vol2 (B, T, depth), d2_mean (B, T, depth))."""
        h = self._proj(x)
        out_list, vol2_list, d2_list = [], [], []
        for enc in self._k_encoders:
            out, vol2, d2_mean = enc(h)
            pad_size = self._max_out_dim - out.shape[-1]
            if pad_size > 0:
                out = F.pad(out, (0, pad_size))  # zero-pad narrow channels on the right
            out_list.append(out)
            vol2_list.append(vol2)
            d2_list.append(d2_mean)
        k_channels = torch.stack(out_list, dim=-2)
        vol2 = torch.stack(vol2_list, dim=-1)
        d2_mean = torch.stack(d2_list, dim=-1)
        return k_channels, vol2, d2_mean

# ============================================================================
# K-CHANNEL CROSS-ATTENTION
# ============================================================================

class KChannelCrossAttention(nn.Module):
    """Multi-head self-attention across the K (channel) axis, per token.

    Each token's `depth` channel vectors attend to each other; the sequence
    axis is untouched. Pre-norm, with a residual connection.
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self._depth = depth
        self._feat_dim = feat_dim
        self._num_heads = num_heads
        self._head_dim = feat_dim // num_heads  # assumes feat_dim % num_heads == 0
        self._norm_q = nn.LayerNorm(feat_dim)
        self._norm_kv = nn.LayerNorm(feat_dim)
        self._to_q = nn.Linear(feat_dim, feat_dim)
        self._to_k = nn.Linear(feat_dim, feat_dim)
        self._to_v = nn.Linear(feat_dim, feat_dim)
        self._out = nn.Linear(feat_dim, feat_dim)
        self._drop = nn.Dropout(dropout)
        self._scale = self._head_dim ** -0.5

    def forward(self, x):
        """x: (B, T, K, feat_dim) -> same shape, with cross-channel mixing."""
        # FIX: the feature width is unpacked as `fdim` (was `F`), which
        # shadowed torch.nn.functional imported as F at module level.
        B, T, K, fdim = x.shape
        x_flat = x.view(B * T, K, fdim)  # fold batch+sequence: attention is per token
        q = self._to_q(self._norm_q(x_flat))
        k = self._to_k(self._norm_kv(x_flat))
        v = self._to_v(self._norm_kv(x_flat))
        q = q.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)
        k = k.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)
        v = v.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) * self._scale
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B * T, K, fdim)
        out = self._out(out)
        out = self._drop(out)
        return x + out.view(B, T, K, fdim)

# ============================================================================
# CAUSAL SEQUENCE ATTENTION
# ============================================================================

class CausalSequenceAttention(nn.Module):
    """Causal multi-head self-attention over the sequence axis.

    The K channels of each token are flattened into one vector of width
    depth * feat_dim before attending. Pre-norm, residual connection.
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1, max_seq_len=2048):
        super().__init__()
        self._num_heads = num_heads
        total_dim = depth * feat_dim
        self._head_dim = total_dim // num_heads  # assumes total_dim % num_heads == 0
        self._norm = nn.LayerNorm(total_dim)
        self._to_qkv = nn.Linear(total_dim, 3 * total_dim)
        self._out = nn.Linear(total_dim, total_dim)
        self._drop = nn.Dropout(dropout)
        self._scale = self._head_dim ** -0.5
        # Lower-triangular mask precomputed once for the maximum length.
        self.register_buffer(
            '_causal_mask',
            torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        )

    def forward(self, x):
        """x: (B, T, K, feat_dim) -> same shape; position t attends only <= t."""
        B, T, K, fdim = x.shape  # fdim, not F: avoid shadowing torch.nn.functional
        x_flat = x.view(B, T, K * fdim)
        x_norm = self._norm(x_flat)
        qkv = self._to_qkv(x_norm).chunk(3, dim=-1)
        q, k, v = [t.view(B, T, self._num_heads, self._head_dim).transpose(1, 2)
                   for t in qkv]
        attn = (q @ k.transpose(-2, -1)) * self._scale
        mask = self._causal_mask[:T, :T]
        attn = attn.masked_fill(~mask, float('-inf'))  # block future positions
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, T, K * fdim)
        out = self._out(out)
        out = self._drop(out)
        return x + out.view(B, T, K, fdim)

# ============================================================================
# TRANSFORMER BLOCK
# ============================================================================

class GeoBlock(nn.Module):
    """One transformer block: cross-channel attention, causal sequence
    attention, then a pre-norm residual MLP over the flattened channels."""

    def __init__(self, depth, feat_dim, num_heads, mlp_ratio=4.0,
                 dropout=0.1, max_seq_len=2048):
        super().__init__()
        self._k_attn = KChannelCrossAttention(depth, feat_dim, num_heads, dropout)
        self._seq_attn = CausalSequenceAttention(depth, feat_dim, num_heads, dropout, max_seq_len)
        total_dim = depth * feat_dim
        self._norm = nn.LayerNorm(total_dim)
        self._mlp = nn.Sequential(
            nn.Linear(total_dim, int(total_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(total_dim * mlp_ratio), total_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """x: (B, T, K, feat_dim) -> same shape."""
        B, T, K, fdim = x.shape  # fdim, not F: avoid shadowing torch.nn.functional
        x = self._k_attn(x)
        x = self._seq_attn(x)
        x_flat = x.view(B, T, K * fdim)
        x_flat = x_flat + self._mlp(self._norm(x_flat))
        x = x_flat.view(B, T, K, fdim)
        return x

# ============================================================================
# GEOMETRIC LM
# ============================================================================

class GeometricLM(nn.Module):
    """Autoregressive LM whose per-token states are bundles of k-simplex
    channels mixed by GeoBlocks; standard LM head on the flattened state."""

    def __init__(
        self,
        vocab_size,
        max_seq_len=512,
        embed_dim=256,
        depth=4,
        edim=16,
        feat_dim=64,
        hidden=256,
        num_heads=8,
        num_blocks=8,
        dropout=0.1,
    ):
        super().__init__()
        self._vocab_size = vocab_size
        self._max_seq_len = max_seq_len
        self._depth = depth
        self._feat_dim = feat_dim
        self._tok_embed = nn.Embedding(vocab_size, embed_dim)
        self._pos_embed = nn.Embedding(max_seq_len, embed_dim)  # learned absolute positions
        self._tok_to_k = TokenToKChannels(embed_dim, depth, edim, feat_dim, hidden)
        self._max_out_dim = self._tok_to_k._max_out_dim
        # Project each padded channel down to the uniform feat_dim the blocks expect.
        self._proj = nn.Linear(self._max_out_dim, feat_dim)
        self._blocks = nn.ModuleList([
            GeoBlock(depth, feat_dim, num_heads, dropout=dropout, max_seq_len=max_seq_len)
            for _ in range(num_blocks)
        ])
        total_dim = depth * feat_dim
        self._norm = nn.LayerNorm(total_dim)
        self._lm_head = nn.Linear(total_dim, vocab_size, bias=False)
        # Serializable hyperparameter record (saved into checkpoints).
        self._config = {
            'vocab_size': vocab_size,
            'max_seq_len': max_seq_len,
            'embed_dim': embed_dim,
            'depth': depth,
            'edim': edim,
            'feat_dim': feat_dim,
            'hidden': hidden,
            'num_heads': num_heads,
            'num_blocks': num_blocks,
            'dropout': dropout,
            'total_dim': total_dim,
        }

    def forward(self, tokens):
        """tokens: (B, T) int64, T <= max_seq_len.

        Returns:
            logits: (B, T, vocab_size) next-token logits.
            info:   dict with 'vol2' and 'd2_mean' geometry diagnostics,
                    each (B, T, depth).
        """
        B, T = tokens.shape
        pos = torch.arange(T, device=tokens.device)
        x = self._tok_embed(tokens) + self._pos_embed(pos)
        k_channels, vol2, d2_mean = self._tok_to_k(x)
        k_channels = self._proj(k_channels)
        for blk in self._blocks:
            k_channels = blk(k_channels)
        out = k_channels.flatten(-2)  # (B, T, depth * feat_dim)
        logits = self._lm_head(self._norm(out))
        return logits, {'vol2': vol2, 'd2_mean': d2_mean}

    @torch.no_grad()
    def generate(self, prompt_tokens, max_new_tokens=100, temperature=1.0, top_k=50):
        """Sample tokens autoregressively with temperature + top-k filtering,
        re-feeding a window of the last max_seq_len tokens each step.

        NOTE: leaves the module in eval mode; callers restore train() themselves.
        """
        self.eval()
        tokens = prompt_tokens.clone()
        for _ in range(max_new_tokens):
            ctx = tokens[:, -self._max_seq_len:]
            logits, _ = self(ctx)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = float('-inf')  # keep only the top_k logits
            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_tok], dim=1)
        return tokens

# ============================================================================
# DATASET
# ============================================================================

class TokenizedDataset(Dataset):
    """Sliding windows of seq_len + 1 tokens over a flat token list.

    x is the window without its last token; y is the window without its first
    token (next-token targets, shifted by exactly one position).
    """

    def __init__(self, tokens, seq_len, stride=None):
        self._tokens = tokens
        self._seq_len = seq_len
        self._stride = stride if stride else seq_len // 2  # 50% overlap max

    def __len__(self):
        # Number of start offsets with a full (seq_len + 1)-token window.
        # FIX: the original computed `last_start // stride` without the +1,
        # silently dropping the final valid window.
        last_start = len(self._tokens) - self._seq_len - 1
        if last_start < 0:
            return 0
        return last_start // self._stride + 1

    def __getitem__(self, idx):
        start = idx * self._stride
        chunk = self._tokens[start:start + self._seq_len + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

# ============================================================================
# LOSS & METRICS
# ============================================================================

def lm_loss(logits, targets, info, ce_weight=1.0, validity_weight=0.1):
    """Next-token cross-entropy plus a hinge penalty on negative squared
    simplex volumes (pushes every channel toward geometric validity).

    Returns (total, ce, validity) scalar tensors.
    """
    B, T, V = logits.shape
    ce = F.cross_entropy(logits.view(B * T, V), targets.view(B * T))
    validity = F.relu(-info['vol2']).mean()  # penalize only degenerate (vol2 < 0) simplices
    total = ce_weight * ce + validity_weight * validity
    return total, ce, validity
@torch.no_grad()
def compute_metrics(info, depth):
    """Flatten per-channel geometry stats into a dict of python floats."""
    vol2 = info['vol2']
    d2_mean = info['d2_mean']
    stats = {'valid_rate': (vol2 > 0).float().mean().item()}
    for order in range(1, depth + 1):
        ch = order - 1
        stats[f'k{order}_valid'] = (vol2[..., ch] > 0).float().mean().item()
        stats[f'k{order}_vol2'] = vol2[..., ch].mean().item()
        stats[f'k{order}_d2'] = d2_mean[..., ch].mean().item()
    return stats

# ============================================================================
# SANITY CHECK
# ============================================================================

@torch.no_grad()
def sanity_check(model, enc, device):
    """Verify no information leak."""
    bar = "=" * 60
    print("\n" + bar)
    print("SANITY CHECK")
    print(bar)
    model.eval()

    # Test 1: with random inputs and targets, CE should sit near ln(vocab).
    rand_in = torch.randint(0, 1000, (4, 256), device=device)
    logits, _ = model(rand_in)
    rand_tgt = torch.randint(0, enc.n_vocab, (4, 256), device=device)
    ce = F.cross_entropy(logits.view(-1, enc.n_vocab), rand_tgt.view(-1))
    expected_ce = math.log(enc.n_vocab)
    print(f"Test 1 - Random input:")
    print(f" CE: {ce.item():.2f} (expected ~{expected_ce:.2f})")
    print(f" PPL: {math.exp(min(ce.item(), 20)):.0f} (expected ~{enc.n_vocab})")
    test1_pass = ce.item() > 8.0  # should be close to ln(50257) ≈ 10.8
    print(f" Status: {'✓ PASS' if test1_pass else '✗ FAIL'}")

    # Test 2: early logits of a causal model must ignore later tokens.
    base = torch.zeros(1, 256, dtype=torch.long, device=device)
    altered = base.clone()
    altered[0, 128:] = 999  # perturb only the second half
    logits1, _ = model(base)
    logits2, _ = model(altered)
    diff_early = (logits1[0, :128] - logits2[0, :128]).abs().max().item()
    diff_late = (logits1[0, 128:] - logits2[0, 128:]).abs().max().item()
    print(f"\nTest 2 - Causal mask:")
    print(f" Early positions diff: {diff_early:.6f} (should be ~0)")
    print(f" Late positions diff: {diff_late:.6f} (should be >0)")
    test2_pass = diff_early < 1e-5 and diff_late > 1e-3
    print(f" Status: {'✓ PASS' if test2_pass else '✗ FAIL'}")

    # Test 3: dataset targets must be the inputs shifted by exactly one token.
    print(f"\nTest 3 - Dataset offset:")
    ds = TokenizedDataset(list(range(100)), seq_len=10)
    x, y = ds[0]
    offset_correct = bool((x + 1 == y).all())
    print(f" x: {x[:5].tolist()}...")
    print(f" y: {y[:5].tolist()}...")
    print(f" Offset correct: {'✓ PASS' if offset_correct else '✗ FAIL'}")
    print(bar)

    all_pass = test1_pass and test2_pass and offset_correct
    print("✓ All sanity checks passed!" if all_pass else "⚠️ WARNING: Some sanity checks failed!")
    print(bar + "\n")
    model.train()
    return all_pass

# ============================================================================
# GENERATION SAMPLING
# ============================================================================

PROMPTS = [
    "ROMEO: ",
    "JULIET: ",
    "To be or not to be",
    "The king ",
    "Once upon a time",
    "First Citizen:\n",
    "What light through yonder",
    "Friends, Romans, countrymen",
    "Now is the winter of",
    "All the world's a stage",
]

@torch.no_grad()
def generate_samples(model, enc, device, epoch, writer=None):
    """Sample a continuation for every prompt in PROMPTS, printing each and
    optionally logging the batch to TensorBoard."""
    model.eval()
    header = "=" * 60
    print(f"\n{header}")
    print(f"GENERATION SAMPLES - Epoch {epoch}")
    print(f"{header}")
    samples = []
    for idx, prompt in enumerate(PROMPTS, start=1):
        encoded = torch.tensor([enc.encode(prompt)], device=device)
        completed = model.generate(
            encoded, max_new_tokens=100, temperature=0.8, top_k=50
        )
        text = enc.decode(completed[0].tolist())
        samples.append({'prompt': prompt, 'generated': text})
        print(f"\n--- Prompt {idx}: '{prompt.strip()}' ---")
        print(text[:300])
        if len(text) > 300:
            print("...")
    print(f"{header}\n")
    # Mirror the samples into TensorBoard when a writer is supplied.
    if writer:
        sample_text = "\n\n".join(
            f"**Prompt:** {s['prompt']}\n**Generated:**\n{s['generated'][:500]}"
            for s in samples
        )
        writer.add_text("samples/generated", sample_text, epoch)
    model.train()
    return samples

# ============================================================================
# CHECKPOINTING & HF UPLOAD
# ============================================================================

def save_checkpoint(model, optimizer, scheduler, epoch, config, metrics, checkpoint_dir):
    """Save checkpoint locally.

    Writes an epoch-numbered checkpoint, a rolling `checkpoint_latest.pt`,
    and the run config as JSON. Returns the epoch checkpoint path.
    """
    checkpoint = {
        'epoch': epoch,
        # torch.compile wraps the module; unwrap via _orig_mod so the saved
        # state_dict keys match the uncompiled model.
        'model_state_dict': model._orig_mod.state_dict() if hasattr(model, '_orig_mod') else model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config,
        'metrics': metrics,
    }
    path = checkpoint_dir / f"checkpoint_epoch_{epoch:03d}.pt"
    torch.save(checkpoint, path)
    # Also save latest
    torch.save(checkpoint, checkpoint_dir / "checkpoint_latest.pt")
    # Save config as JSON
    with open(checkpoint_dir / "config.json", 'w') as f:
        json.dump(config, f, indent=2)
    print(f"Saved checkpoint: {path}")
    return path


def upload_to_hf(checkpoint_dir, repo_id, epoch):
    """Upload checkpoint directory to HuggingFace.

    Best-effort: any failure is printed and reported via the boolean return
    value rather than raised, so training is never aborted by upload issues.
    """
    try:
        api = HfApi()
        # Create repo if doesn't exist
        try:
            create_repo(repo_id, exist_ok=True, repo_type="model")
        except Exception as e:
            print(f"Repo creation note: {e}")
        # Upload folder
        api.upload_folder(
            folder_path=str(checkpoint_dir),
            repo_id=repo_id,
            commit_message=f"Epoch {epoch} checkpoint",
        )
        print(f"Uploaded to HuggingFace: {repo_id}")
        return True
    except Exception as e:
        print(f"HuggingFace upload failed: {e}")
        return False

# ============================================================================
# TRAIN
# ============================================================================

def train():
    """Full training entry point: download data, tokenize, build the model,
    run the train/val loop with TensorBoard logging, checkpoint locally and
    upload the final checkpoint to HuggingFace. Returns (model, enc)."""
    import urllib.request

    # TensorBoard
    writer = SummaryWriter(log_dir=str(TENSORBOARD_DIR))
    print(f"TensorBoard logs: {TENSORBOARD_DIR}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"HuggingFace repo: {HF_REPO}")

    # Data: tiny-shakespeare, downloaded once and cached on disk.
    data_path = './data/shakespeare.txt'
    if not os.path.exists(data_path):
        os.makedirs('./data', exist_ok=True)
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        print("Downloading Shakespeare...")
        urllib.request.urlretrieve(url, data_path)
    with open(data_path, 'r') as f:
        text = f.read()
    print(f"Text length: {len(text):,} chars")

    # Tokenizer: GPT-2 BPE via tiktoken (vocab 50257).
    print("Loading tokenizer...")
    enc = tiktoken.get_encoding("gpt2")
    print("Tokenizing...")
    tokens = enc.encode(text)
    print(f"Token count: {len(tokens):,}")
    print(f"Vocab size: {enc.n_vocab:,}")
    print(f"Compression ratio: {len(text) / len(tokens):.2f}x")

    # Split: chronological 90/10 train/val (no shuffle before splitting).
    seq_len = 256
    split_idx = int(len(tokens) * 0.9)
    train_tokens = tokens[:split_idx]
    val_tokens = tokens[split_idx:]
    train_ds = TokenizedDataset(train_tokens, seq_len)
    val_ds = TokenizedDataset(val_tokens, seq_len)
    batch_size = 12
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=8, pin_memory=True, persistent_workers=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                        num_workers=4, pin_memory=True, persistent_workers=True)
    print(f"Train sequences: {len(train_ds):,} ({len(train_dl)} batches)")
    print(f"Val sequences: {len(val_ds):,} ({len(val_dl)} batches)")

    # Model config
    model_config = {
        'vocab_size': enc.n_vocab,
        'max_seq_len': seq_len,
        'embed_dim': 384,
        'depth': 4,
        'edim': 16,
        'feat_dim': 96,
        'hidden': 384,
        'num_heads': 8,
        'num_blocks': 8,
        'dropout': 0.1,
    }
    # Training config
    train_config = {
        'batch_size': batch_size,
        'seq_len': seq_len,
        'lr': 3e-4,
        'weight_decay': 0.1,
        'num_epochs': 14,
        'grad_clip': 1.0,
        'ce_weight': 1.0,
        'validity_weight': 0.1,
    }
    full_config = {
        'model': model_config,
        'training': train_config,
        'data': {
            'train_tokens': len(train_tokens),
            'val_tokens': len(val_tokens),
            'vocab_size': enc.n_vocab,
        },
        'run_name': RUN_NAME,
    }
    # Save config (NOTE: written again — including 'params' — by every
    # save_checkpoint call below).
    with open(CHECKPOINT_DIR / "config.json", 'w') as f:
        json.dump(full_config, f, indent=2)

    # Model
    print("\nBuilding model...")
    model = GeometricLM(**model_config).to(device)
    print(f"\nConfig:")
    for k, v in model._config.items():
        print(f" {k}: {v}")
    params = sum(p.numel() for p in model.parameters())
    print(f" params: {params:,}")
    full_config['model']['params'] = params

    # Sanity check BEFORE compile (leak + causal-mask + dataset-offset tests).
    sanity_check(model, enc, device)

    print("\nCompiling...")
    #model = torch.compile(model, mode="reduce-overhead")

    # Optimizer: AdamW + cosine decay over the full run.
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=train_config['lr'],
        weight_decay=train_config['weight_decay']
    )
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=train_config['num_epochs'])

    # Log model graph
    # writer.add_graph(model, torch.zeros(1, seq_len, dtype=torch.long, device=device))

    best_val = float('inf')
    best_ppl = float('inf')
    global_step = 0
    print("\nTraining...")
    print("=" * 120)
    epoch_pbar = tqdm(range(train_config['num_epochs']), desc="Epochs", position=0)
    for ep in epoch_pbar:
        epoch_start = time.time()

        # ==================== TRAIN ====================
        model.train()
        ce_sum, val_sum, n = 0, 0, 0
        train_pbar = tqdm(train_dl, desc=f"Train {ep+1}", leave=False, position=1)
        for batch_idx, (x, y) in enumerate(train_pbar):
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits, info = model(x)
            loss, ce, val = lm_loss(
                logits, y, info,
                ce_weight=train_config['ce_weight'],
                validity_weight=train_config['validity_weight']
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config['grad_clip'])
            opt.step()
            # Running sums are sample-weighted so partial final batches don't skew means.
            ce_sum += ce.item() * x.size(0)
            val_sum += val.item() * x.size(0)
            n += x.size(0)
            # TensorBoard - batch level (every 100 optimizer steps)
            if global_step % 100 == 0:
                writer.add_scalar("train/ce_batch", ce.item(), global_step)
                # CE capped at 10 before exp to keep PPL plots finite early on.
                writer.add_scalar("train/ppl_batch", math.exp(min(ce.item(), 10)), global_step)
                writer.add_scalar("train/validity_batch", val.item(), global_step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], global_step)
            global_step += 1
            train_pbar.set_postfix({
                'CE': f'{ce.item():.3f}',
                'PPL': f'{math.exp(min(ce.item(), 10)):.1f}'
            })
        tr_ce = ce_sum / n
        tr_ppl = math.exp(min(tr_ce, 10))
        tr_val = val_sum / n

        # ==================== VAL ====================
        model.eval()
        ce_sum, n = 0, 0
        metrics_agg = []
        val_pbar = tqdm(val_dl, desc=f"Val {ep+1}", leave=False, position=1)
        with torch.no_grad():
            for x, y in val_pbar:
                x, y = x.to(device), y.to(device)
                logits, info = model(x)
                _, ce, _ = lm_loss(logits, y, info)  # only CE is used for val
                ce_sum += ce.item() * x.size(0)
                n += x.size(0)
                metrics_agg.append(compute_metrics(info, model._config['depth']))
                val_pbar.set_postfix({
                    'CE': f'{ce.item():.3f}',
                    'PPL': f'{math.exp(min(ce.item(), 10)):.1f}'
                })
        # NOTE(review): divides by n / len(metrics_agg) below — assumes val_dl
        # is non-empty; an empty validation split would raise ZeroDivisionError.
        va_ce = ce_sum / n
        va_ppl = math.exp(min(va_ce, 10))
        sched.step()  # cosine schedule advances once per epoch
        if va_ce < best_val:
            best_val = va_ce
            best_ppl = va_ppl

        # Aggregate metrics (mean of each per-batch geometry stat)
        m = {k: sum(d[k] for d in metrics_agg) / len(metrics_agg) for k in metrics_agg[0]}
        epoch_time = time.time() - epoch_start

        # ==================== TENSORBOARD - EPOCH ====================
        writer.add_scalar("epoch/train_ce", tr_ce, ep)
        writer.add_scalar("epoch/train_ppl", tr_ppl, ep)
        writer.add_scalar("epoch/val_ce", va_ce, ep)
        writer.add_scalar("epoch/val_ppl", va_ppl, ep)
        writer.add_scalar("epoch/best_ppl", best_ppl, ep)
        writer.add_scalar("epoch/validity_loss", tr_val, ep)
        writer.add_scalar("epoch/time", epoch_time, ep)
        for k in range(model._config['depth']):
            writer.add_scalar(f"geometry/k{k+1}_valid", m[f'k{k+1}_valid'], ep)
            writer.add_scalar(f"geometry/k{k+1}_vol2", m[f'k{k+1}_vol2'], ep)
            writer.add_scalar(f"geometry/k{k+1}_d2", m[f'k{k+1}_d2'], ep)
        writer.add_scalar("geometry/valid_rate", m['valid_rate'], ep)

        # ==================== LOGGING ====================
        epoch_pbar.set_postfix({
            'TrPPL': f'{tr_ppl:.1f}',
            'VaPPL': f'{va_ppl:.1f}',
            'Best': f'{best_ppl:.1f}',
            'Valid': f"{m['valid_rate']:.0%}"
        })
        tqdm.write(
            f"\nEp {ep+1:3d} | TrCE {tr_ce:.4f} | VaCE {va_ce:.4f} | "
            f"TrPPL {tr_ppl:7.2f} | VaPPL {va_ppl:7.2f} | BestPPL {best_ppl:.2f} | "
            f"Time {epoch_time:.1f}s"
        )
        # NOTE(review): hard-codes channels k1..k4 — assumes depth == 4 (true
        # for the model_config above); would KeyError for other depths.
        tqdm.write(
            f" | k1 {m['k1_valid']:5.1%} vol²={m['k1_vol2']:.2e} | "
            f"k2 {m['k2_valid']:5.1%} vol²={m['k2_vol2']:.2e} | "
            f"k3 {m['k3_valid']:5.1%} vol²={m['k3_vol2']:.2e} | "
            f"k4 {m['k4_valid']:5.1%} vol²={m['k4_vol2']:.2e}"
        )

        # ==================== GENERATE SAMPLES ====================
        # With num_epochs=14 this fires only on the first and last epoch.
        if ep % 25 == 0 or ep == train_config['num_epochs'] - 1:
            samples = generate_samples(model, enc, device, ep + 1, writer)
            # Save samples to file
            with open(CHECKPOINT_DIR / f"samples_epoch_{ep+1:03d}.json", 'w') as f:
                json.dump(samples, f, indent=2)

        # ==================== CHECKPOINT ====================
        metrics = {
            'epoch': ep + 1,
            'train_ce': tr_ce,
            'train_ppl': tr_ppl,
            'val_ce': va_ce,
            'val_ppl': va_ppl,
            'best_ppl': best_ppl,
            'geometry': m,
        }
        # Checkpoint every other epoch and always on the final epoch.
        if ep % 2 == 0 or ep == train_config['num_epochs'] - 1:
            save_checkpoint(model, opt, sched, ep + 1, full_config, metrics, CHECKPOINT_DIR)

        # ==================== HF UPLOAD ====================
        # Upload only once, after the final epoch.
        if train_config['num_epochs'] - 1 == ep:
            upload_to_hf(CHECKPOINT_DIR, HF_REPO, ep + 1)

    # ==================== FINAL ====================
    writer.close()
    print("\n" + "=" * 120)
    print(f"Training complete!")
    print(f"Best val CE: {best_val:.4f}, PPL: {best_ppl:.2f}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"TensorBoard: {TENSORBOARD_DIR}")
    print(f"HuggingFace: https://huggingface.co/{HF_REPO}")
    print("=" * 120)
    return model, enc


if __name__ == "__main__":
    model, tokenizer = train()