| | """ |
| | BERT-Thetis Colab Training Script |
| | ---------------------------------- |
| | Pretrain BERT-Thetis on WikiText-103 with Masked Language Modeling. |
| | |
| | Before running this script in Colab, execute the following install commands in a separate cell, then start training: |
| | |
| | try: |
| | !pip uninstall -qy geometricvocab |
| | except: |
| | pass |
| | |
| | !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git |
| | |
| | |
| | Designed for Google Colab with: |
| | - Easy setup and installation |
| | - HuggingFace Hub integration |
| | - Memory-efficient training |
| | - Progress tracking and logging |
| | - Automatic checkpointing |
| | |
| | Author: AbstractPhil + Claude Sonnet 4.5 |
| | License: MIT |
| | """ |
| |
|
| | import os |
| | import math |
| | import time |
| | from pathlib import Path |
| | from typing import Optional, Dict, Any |
| | from dataclasses import dataclass, field |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import DataLoader, Dataset |
| | from torch.optim import AdamW |
| | from torch.optim.lr_scheduler import OneCycleLR |
| |
|
| | from datasets import load_dataset |
| | from transformers import AutoTokenizer |
| | from tqdm.auto import tqdm |
| |
|
| | |
| | from geovocab2.train.model.core.bert_thetis import ( |
| | ThetisConfig, |
| | ThetisForMaskedLM |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class TrainingConfig:
    """Training configuration for Colab.

    All hyperparameters for pretraining BERT-Thetis on WikiText-103 with
    masked language modeling, plus environment, checkpointing, and Hub
    settings. Instantiating it creates the output directories and prints
    a summary (see ``__post_init__``).
    """

    # --- Model architecture ---
    model_name: str = "bert-thetis-tiny-wikitext103"
    crystal_dim: int = 256                # hidden size passed to ThetisConfig
    num_layers: int = 4
    num_attention_heads: int = 4
    intermediate_size: int = 1024
    vocab_size: int = 30522               # matches the bert-base-uncased tokenizer
    beatrix_levels: int = 16              # Thetis-specific — forwarded to ThetisConfig
    max_position_embeddings: int = 512

    # --- Data ---
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "bert-base-uncased"
    max_length: int = 128                 # training sequence length (< max_position_embeddings)
    mlm_probability: float = 0.15         # fraction of tokens selected for masking

    # --- Optimization ---
    num_epochs: int = 10
    batch_size: int = 64
    gradient_accumulation_steps: int = 2  # effective batch = batch_size * this
    learning_rate: float = 5e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1             # fraction of total steps spent warming up the LR
    max_grad_norm: float = 1.0            # gradient clipping threshold

    # --- Runtime / hardware ---
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # resolved at class-definition time
    num_workers: int = 2
    pin_memory: bool = True
    mixed_precision: bool = True          # only takes effect when device == "cuda"

    # --- Checkpointing / logging cadence (measured in optimizer steps) ---
    save_steps: int = 1000
    eval_steps: int = 500
    logging_steps: int = 100
    save_total_limit: int = 3             # NOTE(review): declared but not enforced anywhere in this file

    # --- HuggingFace Hub upload ---
    push_to_hub: bool = True
    hub_model_id: str = "AbstractPhil/bert-thetis-tiny-wikitext103"
    hub_token: Optional[str] = None       # falls back to the HF_TOKEN environment variable

    # --- Local paths ---
    output_dir: str = "./thetis-outputs"
    cache_dir: str = "./cache"

    def __post_init__(self):
        """Setup paths and device."""
        # Ensure output locations exist before any checkpoint/cache writes.
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.cache_dir, exist_ok=True)

        # Prefer an explicitly provided token; otherwise read the environment.
        if self.hub_token is None:
            self.hub_token = os.environ.get("HF_TOKEN")

        # One-time human-readable summary of the run configuration.
        print(f"π’ BERT-Thetis Training Configuration")
        print(f" Device: {self.device}")
        print(f" Mixed Precision: {self.mixed_precision}")
        print(f" Model: {self.model_name}")
        print(f" Dataset: {self.dataset_name}/{self.dataset_config}")
        print(f" Output: {self.output_dir}")
        print(f" Push to Hub: {self.push_to_hub}")
        if self.push_to_hub:
            print(f" Hub Repo: {self.hub_model_id}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
class MaskedLMDataset(Dataset):
    """Dataset applying BERT-style dynamic masking to raw text.

    Each item is tokenized to a fixed length, then a fraction of the
    non-special tokens is selected for prediction: 80% of the selected
    positions are replaced with [MASK], half of the remainder (10% overall)
    with a random token, and the rest are left unchanged. Labels are -100
    everywhere except the selected positions.
    """

    def __init__(
        self,
        texts,
        tokenizer,
        max_length: int = 128,
        mlm_probability: float = 0.15
    ):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlm_probability = mlm_probability

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize to a fixed-length, padded single sequence.
        encoded = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        labels = input_ids.clone()

        # Per-position selection probability; special tokens are never selected.
        selection_probs = torch.full(labels.shape, self.mlm_probability)
        special = self.tokenizer.get_special_tokens_mask(
            labels.tolist(), already_has_special_tokens=True
        )
        selection_probs.masked_fill_(
            torch.tensor(special, dtype=torch.bool), value=0.0
        )

        selected = torch.bernoulli(selection_probs).bool()
        labels[~selected] = -100  # loss is computed only on selected positions

        # 80% of selected positions become [MASK].
        masked = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & selected
        input_ids[masked] = self.tokenizer.mask_token_id

        # Half of the remaining selected positions (10% overall) get a random
        # token; the final 10% keep their original token.
        randomized = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & selected
            & ~masked
        )
        random_ids = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        input_ids[randomized] = random_ids[randomized]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
| |
|
| |
|
def prepare_datasets(config: TrainingConfig):
    """Load WikiText-103, drop blank lines, and build the MLM datasets.

    Returns a ``(train_dataset, val_dataset, tokenizer)`` triple; both
    datasets share the same tokenizer, max length, and masking probability.
    """
    print(f"\nπ Loading {config.dataset_name}...")

    raw = load_dataset(
        config.dataset_name,
        config.dataset_config,
        cache_dir=config.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer_name,
        cache_dir=config.cache_dir,
    )

    # WikiText contains many empty / whitespace-only lines; skip them.
    train_texts = [row["text"] for row in raw["train"] if row["text"].strip()]
    val_texts = [row["text"] for row in raw["validation"] if row["text"].strip()]

    print(f" Train samples: {len(train_texts):,}")
    print(f" Val samples: {len(val_texts):,}")

    def _wrap(texts):
        # Both splits get identical masking configuration.
        return MaskedLMDataset(
            texts,
            tokenizer,
            config.max_length,
            config.mlm_probability,
        )

    return _wrap(train_texts), _wrap(val_texts), tokenizer
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ThetisTrainer:
    """Trainer for BERT-Thetis with MLM.

    Owns the data loaders, AdamW optimizer with decoupled weight-decay
    groups, a OneCycle LR schedule, optional CUDA mixed precision via
    GradScaler, periodic evaluation/checkpointing, and an optional upload
    of the best/final checkpoints to the HuggingFace Hub.
    """

    def __init__(
        self,
        model: ThetisForMaskedLM,
        train_dataset: Dataset,
        val_dataset: Dataset,
        config: TrainingConfig
    ):
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config

        # Move the model to the configured device before building the optimizer.
        self.model.to(config.device)

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )

        # Validation runs without gradients, so a larger batch fits in memory.
        self.val_loader = DataLoader(
            val_dataset,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=config.pin_memory
        )

        # Standard BERT practice: no weight decay on biases or LayerNorm weights.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": config.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

        # One optimizer step per `gradient_accumulation_steps` micro-batches.
        total_steps = len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps
        # NOTE(review): warmup_steps is only printed below; the scheduler derives
        # its warmup from pct_start instead.
        warmup_steps = int(total_steps * config.warmup_ratio)

        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_ratio,
            anneal_strategy="cos"
        )

        # GradScaler is only meaningful for fp16 autocast on CUDA.
        self.scaler = torch.amp.GradScaler('cuda') if config.mixed_precision and config.device == 'cuda' else None

        # Training progress state (persisted alongside checkpoints).
        self.global_step = 0
        self.epoch = 0
        self.best_val_loss = float("inf")

        print(f"\nπ― Training Setup")
        print(f" Total steps: {total_steps:,}")
        print(f" Warmup steps: {warmup_steps:,}")
        print(f" Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")

    def train_epoch(self):
        """Train for one epoch.

        Scales the loss down per micro-batch and steps the optimizer every
        `gradient_accumulation_steps` batches; logging, evaluation, and
        checkpointing are keyed off `global_step` (optimizer steps).
        """
        self.model.train()
        total_loss = 0

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch + 1}")

        for step, batch in enumerate(progress_bar):
            # Move the entire batch to the training device.
            batch = {k: v.to(self.config.device) for k, v in batch.items()}

            # Forward pass under autocast when CUDA mixed precision is enabled.
            with torch.amp.autocast('cuda', enabled=self.config.mixed_precision and self.config.device == 'cuda'):
                loss, _ = self.model(
                    token_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )
                # Scale so accumulated gradients average over the micro-batches.
                loss = loss / self.config.gradient_accumulation_steps

            # Backward pass (scaled when the fp16 GradScaler is active).
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            total_loss += loss.item()

            # Optimizer step every `gradient_accumulation_steps` micro-batches.
            # NOTE(review): a trailing partial accumulation window at the end of
            # the epoch never triggers a step, so its gradients carry over.
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                if self.scaler is not None:
                    # Unscale before clipping so the norm is on true gradients.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                    self.optimizer.step()

                self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1

                # Live metrics on the progress bar (loss restored to full scale).
                progress_bar.set_postfix({
                    "loss": f"{loss.item() * self.config.gradient_accumulation_steps:.4f}",
                    "lr": f"{self.scheduler.get_last_lr()[0]:.2e}"
                })

                # Periodic console logging of the running-average loss.
                if self.global_step % self.config.logging_steps == 0:
                    avg_loss = total_loss / self.config.logging_steps
                    print(f"\n Step {self.global_step}: loss={avg_loss:.4f}, lr={self.scheduler.get_last_lr()[0]:.2e}")
                    total_loss = 0

                # Periodic validation; keep the best checkpoint seen so far.
                if self.global_step % self.config.eval_steps == 0:
                    val_loss = self.evaluate()
                    print(f" Validation loss: {val_loss:.4f}")

                    if val_loss < self.best_val_loss:
                        self.best_val_loss = val_loss
                        self.save_checkpoint("best")
                        print(f" β New best model saved!")

                    # evaluate() put the model in eval mode; switch back.
                    self.model.train()

                # Periodic step-numbered checkpoint.
                if self.global_step % self.config.save_steps == 0:
                    self.save_checkpoint(f"step-{self.global_step}")

    @torch.no_grad()
    def evaluate(self):
        """Evaluate on validation set.

        Returns the mean per-batch loss; leaves the model in eval mode
        (callers switch back to train mode as needed).
        """
        self.model.eval()
        total_loss = 0
        total_steps = 0

        for batch in tqdm(self.val_loader, desc="Evaluating", leave=False):
            batch = {k: v.to(self.config.device) for k, v in batch.items()}

            with torch.amp.autocast('cuda', enabled=self.config.mixed_precision and self.config.device == 'cuda'):
                loss, _ = self.model(
                    token_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )

            total_loss += loss.item()
            total_steps += 1

        return total_loss / total_steps

    def train(self):
        """Full training loop.

        Runs all epochs, evaluates and checkpoints after each one, saves a
        final checkpoint, and optionally pushes results to the Hub.
        """
        print(f"\nπ Starting Training")
        print("=" * 70)

        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            self.epoch = epoch
            print(f"\nπ Epoch {epoch + 1}/{self.config.num_epochs}")

            self.train_epoch()

            # End-of-epoch validation and checkpoint.
            val_loss = self.evaluate()
            print(f"\n Epoch {epoch + 1} validation loss: {val_loss:.4f}")

            self.save_checkpoint(f"epoch-{epoch + 1}")

        # Final evaluation and summary.
        final_val_loss = self.evaluate()
        print(f"\nβ Training Complete!")
        print(f" Final validation loss: {final_val_loss:.4f}")
        print(f" Best validation loss: {self.best_val_loss:.4f}")
        print(f" Total time: {(time.time() - start_time) / 3600:.2f} hours")

        self.save_checkpoint("final")

        # Optional upload of the best/final checkpoints.
        if self.config.push_to_hub:
            self.push_to_hub()

    def save_checkpoint(self, name: str):
        """Save model checkpoint.

        Writes weights (pytorch_model.bin), a minimal architecture config
        (config.json), and training progress (training_state.pt) under
        ``<output_dir>/<name>/``.
        """
        output_dir = Path(self.config.output_dir) / name
        output_dir.mkdir(parents=True, exist_ok=True)

        # Model weights only; architecture is reconstructed from config.json.
        torch.save(self.model.state_dict(), output_dir / "pytorch_model.bin")

        # Minimal architecture description needed to rebuild ThetisConfig.
        config_dict = {
            "crystal_dim": self.config.crystal_dim,
            "num_layers": self.config.num_layers,
            "num_attention_heads": self.config.num_attention_heads,
            "intermediate_size": self.config.intermediate_size,
            "vocab_size": self.config.vocab_size,
            "beatrix_levels": self.config.beatrix_levels,
            "max_position_embeddings": self.config.max_position_embeddings,
        }

        import json
        with open(output_dir / "config.json", "w") as f:
            json.dump(config_dict, f, indent=2)

        # Progress state for resuming / bookkeeping.
        state = {
            "global_step": self.global_step,
            "epoch": self.epoch,
            "best_val_loss": self.best_val_loss,
        }
        torch.save(state, output_dir / "training_state.pt")

    def push_to_hub(self):
        """Push model to HuggingFace Hub.

        Best-effort: requires a token; uploads the "best" checkpoint to the
        repo root and the "final" checkpoint under a "final/" subfolder.
        Failures are reported but never raised.
        """
        if not self.config.hub_token:
            print("β οΈ No HuggingFace token found. Skipping push to hub.")
            return

        print(f"\nπ€ Pushing to HuggingFace Hub: {self.config.hub_model_id}")

        try:
            from huggingface_hub import HfApi, create_repo

            api = HfApi(token=self.config.hub_token)

            # Create the repo if it does not exist yet (exist_ok ignores dupes).
            try:
                create_repo(
                    repo_id=self.config.hub_model_id,
                    token=self.config.hub_token,
                    exist_ok=True
                )
            except Exception as e:
                print(f" Repo creation: {e}")

            # Best checkpoint goes to the repo root.
            best_dir = Path(self.config.output_dir) / "best"
            if best_dir.exists():
                api.upload_folder(
                    folder_path=str(best_dir),
                    repo_id=self.config.hub_model_id,
                    token=self.config.hub_token
                )
                print(f" β Best model uploaded!")

            # Final checkpoint goes into a "final" subfolder.
            final_dir = Path(self.config.output_dir) / "final"
            if final_dir.exists():
                api.upload_folder(
                    folder_path=str(final_dir),
                    repo_id=self.config.hub_model_id,
                    path_in_repo="final",
                    token=self.config.hub_token
                )
                print(f" β Final model uploaded!")

        except Exception as e:
            print(f"β οΈ Failed to push to hub: {e}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main():
    """Entry point: build the config, data, and model, then run training."""
    cfg = TrainingConfig()

    # Data pipeline (downloads/caches WikiText-103 and the tokenizer).
    train_ds, val_ds, tokenizer = prepare_datasets(cfg)

    # Build the model from the training hyperparameters.
    print(f"\nποΈ Creating BERT-Thetis model...")
    arch = ThetisConfig(
        crystal_dim=cfg.crystal_dim,
        num_vertices=5,
        num_layers=cfg.num_layers,
        num_attention_heads=cfg.num_attention_heads,
        intermediate_size=cfg.intermediate_size,
        vocab_size=cfg.vocab_size,
        beatrix_levels=cfg.beatrix_levels,
        max_position_embeddings=cfg.max_position_embeddings,
    )
    model = ThetisForMaskedLM(arch)

    # Parameter counts for a quick sanity check before training starts.
    param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in param_info)
    trainable_params = sum(count for count, requires_grad in param_info if requires_grad)
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")

    # Run the full training loop (checkpointing and Hub upload included).
    ThetisTrainer(model, train_ds, val_ds, cfg).train()

    print("\nπ All done! BERT-Thetis is ready to sail!")
| |
|
| |
|
# Script entry point: run the full pretraining pipeline.
if __name__ == "__main__":
    main()