| """ |
| İvme-Conversate — a stupidly small decoder-only language model. |
| |
| Philosophy: sub-100M params, trained from scratch on ultra-dense data, |
| built to punch above its weight on the Tiny-ML leaderboard. |
| |
| Architecture (v1, English-only): |
| - Decoder-only transformer |
| - RoPE positional encoding |
| - Grouped-Query Attention (GQA) |
| - SwiGLU feed-forward |
| - RMSNorm (pre-norm) |
| - Tied input/output embeddings |
| - No biases |
| - Flash Attention via HuggingFace Kernels when attn_backend="kernels" (PyTorch SDPA otherwise, and by default) |
| |
| Run `python model.py` to build the model and print the real parameter count. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
@dataclass
class IvmeConfig:
    """Hyperparameters for the İvme-Conversate decoder-only transformer.

    All fields have defaults, so ``IvmeConfig()`` yields the stock model.
    """

    # Number of token ids; sizes the embedding table and the output projection.
    vocab_size: int = 16_384
    # Width of the residual stream.
    hidden_dim: int = 384
    # Number of transformer blocks.
    n_layers: int = 10
    # Number of query heads.
    n_heads: int = 6
    # Number of key/value heads (grouped-query attention).
    n_kv_heads: int = 2
    # Inner width of the SwiGLU feed-forward.
    ffn_dim: int = 1024
    # Length of the RoPE table and of the context window used by generate().
    max_seq_len: int = 1024
    # Base of the RoPE frequency schedule.
    rope_theta: float = 10_000.0
    # Epsilon added inside RMSNorm before the reciprocal square root.
    norm_eps: float = 1e-5
    # Share one weight matrix between the token embedding and the LM head.
    tie_embeddings: bool = True

    # "kernels" routes attention through the flash-attention kernel; any
    # other value uses torch's scaled_dot_product_attention.
    attn_backend: str = "sdpa"

    @property
    def head_dim(self) -> int:
        """Return the per-head width, ``hidden_dim // n_heads``.

        Raises:
            ValueError: if ``n_heads`` is not positive or does not divide
                ``hidden_dim`` evenly.
        """
        # Raise explicitly instead of asserting: asserts vanish under ``-O``
        # and a floored head_dim would only fail later with a shape error.
        if self.n_heads <= 0 or self.hidden_dim % self.n_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        return self.hidden_dim // self.n_heads
|
|
|
|
| |
| |
| |
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The arithmetic runs in float32; the result is cast back to the dtype of
    the input.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-channel gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` by its RMS along the last axis and apply the gain."""
        in_dtype = x.dtype
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_square + self.eps)
        gain = self.weight.float()
        return (normed * gain).to(in_dtype)
|
|
|
|
| |
| |
| |
def build_rope_cache(seq_len: int, head_dim: int, theta: float, device, dtype):
    """Precompute the rotary-embedding cosine and sine tables.

    Args:
        seq_len: number of positions (rows) to tabulate.
        head_dim: per-head width; every second index gets its own frequency.
        theta: base of the geometric frequency schedule.
        device: device on which the tables are created.
        dtype: dtype of the returned tables (the math itself runs in float32).

    Returns:
        ``(cos, sin)``; each has ``seq_len`` rows, and ``head_dim`` columns
        when ``head_dim`` is even, with the frequency block repeated twice
        along the last axis.
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    table = torch.cat((angles, angles), dim=-1)
    return table.cos().to(dtype), table.sin().to(dtype)
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last axis into two chunks ``(a, b)`` and return ``(-b, a)``."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
def apply_rope(q, k, cos, sin):
    """Rotate ``q`` and ``k`` by the given rotary cosine/sine tables.

    ``cos`` and ``sin`` are 2-D tables; two leading singleton axes are added
    so they broadcast against 4-D ``q``/``k`` (the callers in this file pass
    them as (batch, heads, seq, head_dim)). The tables are cast to
    ``q.dtype`` before use.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos_b = cos[None, None, :, :].to(q.dtype)
    sin_b = sin[None, None, :, :].to(q.dtype)

    def _rotate(t):
        # Rotation written as t*cos + rotate_half(t)*sin.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
| |
| |
| |
class Attention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    Queries use ``n_heads`` heads and keys/values use ``n_kv_heads`` heads;
    all projections are bias-free. ``cfg.attn_backend == "kernels"`` runs the
    flash-attention kernel fetched through the ``kernels`` package; any other
    value uses ``F.scaled_dot_product_attention``.
    """

    # Whether F.scaled_dot_product_attention accepts ``enable_gqa``. This is a
    # property of the installed PyTorch, so it is learned once per process
    # rather than re-discovered (via a raised TypeError) on every call.
    _sdpa_accepts_gqa = True

    def __init__(self, cfg: IvmeConfig):
        super().__init__()
        # Fail early: a non-multiple would be floored into n_rep and only
        # surface later as an opaque shape error inside attention.
        if cfg.n_kv_heads <= 0 or cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({cfg.n_heads}) must be a positive multiple of "
                f"n_kv_heads ({cfg.n_kv_heads})"
            )
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.head_dim
        # How many query heads share each key/value head.
        self.n_rep = cfg.n_heads // cfg.n_kv_heads
        self.backend = cfg.attn_backend

        self.q_proj = nn.Linear(cfg.hidden_dim, cfg.n_heads * cfg.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_dim, cfg.n_kv_heads * cfg.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_dim, cfg.n_kv_heads * cfg.head_dim, bias=False)
        self.o_proj = nn.Linear(cfg.n_heads * cfg.head_dim, cfg.hidden_dim, bias=False)

        # Flash-attention callable, resolved on first use.
        self._flash = None

    def _get_flash_kernel(self):
        """Return ``flash_attn_func``, loading it on the first call.

        The ``kernels`` import is local so the package is only needed when
        the "kernels" backend is actually used.
        """
        if self._flash is None:
            from kernels import get_kernel

            mod = get_kernel("kernels-community/flash-attn2", version=1)
            self._flash = mod.flash_attn_func
        return self._flash

    def forward(self, x, cos, sin):
        """Attend over ``x`` of shape (batch, seq, hidden_dim).

        Args:
            x: input activations.
            cos, sin: rotary tables for the ``seq`` positions of ``x``.

        Returns:
            A tensor with the same (batch, seq, hidden_dim) shape as ``x``.
        """
        B, S, _ = x.shape

        q = self.q_proj(x).view(B, S, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, S, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, S, self.n_kv_heads, self.head_dim)

        if self.backend == "kernels":
            # The kernel consumes (batch, seq, heads, head_dim) in bfloat16.
            compute_dtype = q.dtype
            q, k = self._rope_bshd(q, k, cos, sin)
            q, k, v = (t.to(torch.bfloat16) for t in (q, k, v))
            flash_attn_func = self._get_flash_kernel()
            out = flash_attn_func(q, k, v, causal=True)
            # Cast back to the projection dtype; otherwise o_proj receives
            # bfloat16 against fp32/fp16 weights and raises a dtype mismatch.
            out = out.reshape(B, S, -1).to(compute_dtype)
        else:
            # SDPA expects (batch, heads, seq, head_dim).
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            q, k = apply_rope(q, k, cos, sin)
            out = self._sdpa(q, k, v)
            out = out.transpose(1, 2).reshape(B, S, -1)

        return self.o_proj(out)

    def _rope_bshd(self, q, k, cos, sin):
        """Apply RoPE to (batch, seq, heads, head_dim) tensors.

        ``apply_rope`` is called on the head-major layout, so the tensors are
        transposed in and transposed back out as contiguous tensors.
        """
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        q, k = apply_rope(q, k, cos, sin)
        return q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous()

    def _sdpa(self, q, k, v):
        """Causal attention on (batch, heads, seq, head_dim) tensors.

        Uses SDPA's native grouped-query support when the installed PyTorch
        accepts ``enable_gqa``; otherwise repeats the key/value heads so they
        line up with the query heads.
        """
        if Attention._sdpa_accepts_gqa:
            try:
                return F.scaled_dot_product_attention(
                    q, k, v, is_causal=True, enable_gqa=True
                )
            except TypeError:
                # Keyword unknown to this PyTorch: remember that, so later
                # calls go straight to the manual path.
                Attention._sdpa_accepts_gqa = False
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
|
|
|
| |
| |
| |
class SwiGLU(nn.Module):
    """Bias-free gated feed-forward: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, cfg: IvmeConfig):
        super().__init__()
        width, inner = cfg.hidden_dim, cfg.ffn_dim
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)

    def forward(self, x):
        """Project up through the SiLU gate, then back to the input width."""
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
|
|
|
| |
| |
| |
class Block(nn.Module):
    """One pre-norm transformer layer.

    Attention and the SwiGLU feed-forward each read an RMS-normalised copy of
    the stream and add their output back onto it.
    """

    def __init__(self, cfg: IvmeConfig):
        super().__init__()
        dim, eps = cfg.hidden_dim, cfg.norm_eps
        self.attn_norm = RMSNorm(dim, eps)
        self.attn = Attention(cfg)
        self.ffn_norm = RMSNorm(dim, eps)
        self.ffn = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        """Run attention then feed-forward, each wrapped in a residual add."""
        attn_out = self.attn(self.attn_norm(x), cos, sin)
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
|
|
|
|
| |
| |
| |
class IvmeConversate(nn.Module):
    """Decoder-only language model: embedding, ``n_layers`` blocks, norm, LM head.

    With ``cfg.tie_embeddings`` the LM head shares its weight with the token
    embedding.
    """

    def __init__(self, cfg: IvmeConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_dim)
        self.blocks = nn.ModuleList(Block(cfg) for _ in range(cfg.n_layers))
        self.norm = RMSNorm(cfg.hidden_dim, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_size, bias=False)

        if cfg.tie_embeddings:
            self.lm_head.weight = self.embed.weight

        # RoPE tables are built lazily by _rope() and kept as plain
        # attributes (not buffers or parameters).
        self._cos = None
        self._sin = None

        self.apply(self._init_weights)

        # Projections that write into the residual stream get a smaller std,
        # scaled by 1/sqrt(2 * n_layers).
        residual_std = 0.02 / math.sqrt(2 * cfg.n_layers)
        for name, p in self.named_parameters():
            if name.endswith(("o_proj.weight", "down_proj.weight")):
                nn.init.normal_(p, mean=0.0, std=residual_std)

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _rope(self, seq_len, device, dtype):
        """Return the RoPE ``(cos, sin)`` tables for the first ``seq_len`` positions.

        The tables are cached and rebuilt when missing, too short, or on a
        different device. ``dtype`` is only used when the cache is (re)built.
        """
        stale = (
            self._cos is None
            or self._cos.size(0) < seq_len
            or self._cos.device != device
        )
        if stale:
            # Size the table to cover the request. Building exactly
            # max_seq_len rows would leave a longer sequence with too few
            # rows (and trigger a rebuild on every call).
            cache_len = max(seq_len, self.cfg.max_seq_len)
            self._cos, self._sin = build_rope_cache(
                cache_len, self.cfg.head_dim, self.cfg.rope_theta, device, dtype
            )
        return self._cos[:seq_len], self._sin[:seq_len]

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: token ids of shape (batch, seq).
            targets: optional target ids aligned with ``idx``; positions
                equal to -100 are ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` is (batch, seq, vocab_size)
            and ``loss`` is ``None`` when ``targets`` is not given.
        """
        B, S = idx.shape
        x = self.embed(idx)
        cos, sin = self._rope(S, x.device, x.dtype)

        for block in self.blocks:
            x = block(x, cos, sin)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view): targets are often a non-contiguous slice
            # such as ``tokens[:, 1:]``, on which view(-1) raises.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-100,
            )
        return logits, loss

    def num_params(self, non_embedding: bool = False) -> int:
        """Count parameters.

        Args:
            non_embedding: when true, exclude the token embedding and, if the
                LM head is untied, the LM head as well.
        """
        n = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n -= self.embed.weight.numel()
            if not self.cfg.tie_embeddings:
                n -= self.lm_head.weight.numel()
        return n

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=40):
        """Sample ``max_new_tokens`` tokens autoregressively after ``idx``.

        Switches the module to eval mode. Each step feeds at most the last
        ``max_seq_len`` tokens, scales the final-position logits by
        ``temperature`` (floored at 1e-5), optionally keeps only the
        ``top_k`` largest logits, and samples from the softmax.

        Returns:
            ``idx`` with the sampled tokens appended along dim 1.
        """
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-5)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Mask everything below the k-th largest logit of each row.
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_id), dim=1)
        return idx
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Build the default model and report where its parameters live.
    cfg = IvmeConfig()
    model = IvmeConversate(cfg)

    total = model.num_params()
    non_emb = model.num_params(non_embedding=True)
    emb = cfg.vocab_size * cfg.hidden_dim

    heavy_rule = "=" * 52
    light_rule = "-" * 52
    report = [
        heavy_rule,
        " İvme-Conversate",
        heavy_rule,
        f" vocab_size : {cfg.vocab_size:,}",
        f" hidden_dim : {cfg.hidden_dim}",
        f" n_layers : {cfg.n_layers}",
        f" n_heads / kv : {cfg.n_heads} / {cfg.n_kv_heads} (GQA)",
        f" ffn_dim : {cfg.ffn_dim} (SwiGLU)",
        f" max_seq_len : {cfg.max_seq_len}",
        f" tied embeddings : {cfg.tie_embeddings}",
        light_rule,
        f" embedding params: {emb:,} ({100*emb/total:.1f}% of total)",
        f" transformer : {non_emb:,}",
        f" TOTAL PARAMS : {total:,} (~{total/1e6:.1f}M)",
        heavy_rule,
    ]
    print(*report, sep="\n")

    # Smoke test: one forward/backward pass on random token ids.
    batch_shape = (2, 128)
    x = torch.randint(0, cfg.vocab_size, batch_shape)
    y = torch.randint(0, cfg.vocab_size, batch_shape)
    logits, loss = model(x, y)
    loss.backward()

    baseline = math.log(cfg.vocab_size)
    print(f" forward ok : logits {tuple(logits.shape)}")
    print(f" initial loss : {loss.item():.3f} (random baseline ≈ {baseline:.3f})")
    print(" backward ok : grads populated")
    print(heavy_rule)