"""
Configuration for a 1B-parameter LLaMA-style Transformer model.

Architecture: decoder-only Transformer with RoPE, GQA, SwiGLU, and RMSNorm.
"""
|
|
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class ModelConfig:
    """Architecture hyperparameters for the decoder-only Transformer.

    The defaults describe a model of roughly 1.1B parameters (see
    ``num_params_approx``). Only the fields used by the two properties
    below are interpreted in this file; the remaining fields are presumably
    consumed by the model code elsewhere.

    Raises:
        ValueError: If a head count is not positive, if ``hidden_dim`` is
            not a multiple of ``num_attention_heads``, or if
            ``num_attention_heads`` is not a multiple of ``num_kv_heads``.
    """

    vocab_size: int = 32000  # rows of the embedding / output projection
    hidden_dim: int = 2048  # model width
    intermediate_dim: int = 5504  # feed-forward inner width
    num_layers: int = 22  # number of Transformer blocks
    num_attention_heads: int = 32  # query heads
    num_kv_heads: int = 8  # key/value heads (fewer than query heads => GQA)
    max_seq_len: int = 2048
    rope_theta: float = 10000.0  # presumably the RoPE base; used elsewhere
    rms_norm_eps: float = 1e-5  # presumably the RMSNorm epsilon; used elsewhere
    dropout: float = 0.0
    tie_word_embeddings: bool = False  # if True, no separate output matrix is counted

    def __post_init__(self) -> None:
        """Reject head configurations that would make ``head_dim`` wrong."""
        if self.num_attention_heads <= 0 or self.num_kv_heads <= 0:
            raise ValueError(
                "num_attention_heads and num_kv_heads must be positive, got "
                f"{self.num_attention_heads} and {self.num_kv_heads}"
            )
        # head_dim uses floor division, so a remainder would silently shrink
        # the attention width and skew the parameter estimate.
        if self.hidden_dim % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        # Grouped-query attention shares each KV head across an equal-sized
        # group of query heads.
        if self.num_attention_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be "
                f"divisible by num_kv_heads ({self.num_kv_heads})"
            )

    @property
    def head_dim(self) -> int:
        """Width of a single attention head."""
        return self.hidden_dim // self.num_attention_heads

    @property
    def num_params_approx(self) -> int:
        """Return a rough parameter count estimate.

        Counts the token embedding, per-layer attention projections
        (Q, K, V, output), three feed-forward matrices, two norm weight
        vectors per layer, one final norm vector, and a separate output
        projection unless ``tie_word_embeddings`` is set. Bias terms are
        not counted.
        """
        embed = self.vocab_size * self.hidden_dim

        # Q and output projections span all query heads; K and V span only
        # the (smaller) set of KV heads.
        q_or_out_proj = self.hidden_dim * self.head_dim * self.num_attention_heads
        k_or_v_proj = self.hidden_dim * self.head_dim * self.num_kv_heads
        attn_per_layer = 2 * q_or_out_proj + 2 * k_or_v_proj

        ffn_per_layer = 3 * self.hidden_dim * self.intermediate_dim
        norm_per_layer = 2 * self.hidden_dim
        per_layer = attn_per_layer + ffn_per_layer + norm_per_layer

        final_norm = self.hidden_dim
        output_proj = 0 if self.tie_word_embeddings else embed

        return embed + self.num_layers * per_layer + final_norm + output_proj
|
|
|
|
@dataclass
class TrainConfig:
    """Default settings for a training run.

    This class only stores values; nothing in this file reads them. The
    section comments below are inferred from field names and should be
    confirmed against the training loop that consumes this config.
    """

    # Paths (defaults are environment-specific absolute paths)
    checkpoint_dir: str = "/jfs/deepak-kumar/checkpoints"
    data_cache_dir: str = "/jfs/deepak-kumar/data"
    log_dir: str = "/home/jovyan/training/logs"

    # Token budget and batching
    total_tokens: int = 20_000_000_000
    batch_size_per_gpu: int = 8
    gradient_accumulation_steps: int = 8
    # NOTE(review): ModelConfig.max_seq_len has the same default (2048);
    # presumably the two must be kept in sync -- confirm.
    max_seq_len: int = 2048

    # Optimizer and learning-rate schedule
    learning_rate: float = 3e-4
    min_lr: float = 3e-5
    warmup_steps: int = 1000
    weight_decay: float = 0.1
    beta1: float = 0.9  # presumably Adam-style betas -- confirm
    beta2: float = 0.95
    grad_clip: float = 1.0

    # Logging, checkpointing, and evaluation cadence (presumably in steps)
    log_interval: int = 10
    save_interval: int = 1000
    eval_interval: int = 500

    # Runtime
    num_workers: int = 4
    seed: int = 42
    bf16: bool = True
|
|