""" İvme-Conversate — a stupidly small decoder-only language model. Philosophy: sub-100M params, trained from scratch on ultra-dense data, built to punch above its weight on the Tiny-ML leaderboard. Architecture (v1, English-only): - Decoder-only transformer - RoPE positional encoding - Grouped-Query Attention (GQA) - SwiGLU feed-forward - RMSNorm (pre-norm) - Tied input/output embeddings - No biases - Flash Attention via HuggingFace Kernels (with SDPA fallback) Run `python model.py` to build the model and print the real parameter count. """ from __future__ import annotations import math from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # # Config # --------------------------------------------------------------------------- # @dataclass class IvmeConfig: # These defaults land at ~22M params. Drop ffn_dim to 896 or n_layers to 9 # to hit ~20M exactly. The embedding table dominates at this scale, so the # vocab size is the single biggest lever on total params. vocab_size: int = 16_384 # English-only v1; BPE trained from scratch hidden_dim: int = 384 n_layers: int = 10 n_heads: int = 6 # head_dim = hidden_dim / n_heads = 64 n_kv_heads: int = 2 # GQA: each KV head is shared across 3 Q heads ffn_dim: int = 1024 # SwiGLU intermediate size max_seq_len: int = 1024 rope_theta: float = 10_000.0 norm_eps: float = 1e-5 tie_embeddings: bool = True # Attention backend: "kernels" (HF Kernel Hub flash-attn2) or "sdpa". attn_backend: str = "sdpa" @property def head_dim(self) -> int: assert self.hidden_dim % self.n_heads == 0, "hidden_dim must divide n_heads" return self.hidden_dim // self.n_heads # --------------------------------------------------------------------------- # # Normalization # --------------------------------------------------------------------------- # class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: # Compute in fp32 for stability, cast back to input dtype. dtype = x.dtype x = x.float() x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return (x * self.weight.float()).to(dtype) # --------------------------------------------------------------------------- # # Rotary positional embeddings # --------------------------------------------------------------------------- # def build_rope_cache(seq_len: int, head_dim: int, theta: float, device, dtype): inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)) t = torch.arange(seq_len, device=device).float() freqs = torch.outer(t, inv_freq) # (seq_len, head_dim/2) emb = torch.cat((freqs, freqs), dim=-1) # (seq_len, head_dim) return emb.cos().to(dtype), emb.sin().to(dtype) def rotate_half(x: torch.Tensor) -> torch.Tensor: x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rope(q, k, cos, sin): cos = cos[None, None, :, :].to(q.dtype) sin = sin[None, None, :, :].to(q.dtype) q = (q * cos) + (rotate_half(q) * sin) k = (k * cos) + (rotate_half(k) * sin) return q, k # --------------------------------------------------------------------------- # # Attention (GQA) # --------------------------------------------------------------------------- # class Attention(nn.Module): def __init__(self, cfg: IvmeConfig): super().__init__() self.n_heads = cfg.n_heads self.n_kv_heads = cfg.n_kv_heads self.head_dim = cfg.head_dim self.n_rep = cfg.n_heads // cfg.n_kv_heads self.backend = cfg.attn_backend self.q_proj = nn.Linear(cfg.hidden_dim, cfg.n_heads * cfg.head_dim, bias=False) self.k_proj = nn.Linear(cfg.hidden_dim, cfg.n_kv_heads * cfg.head_dim, bias=False) self.v_proj = nn.Linear(cfg.hidden_dim, cfg.n_kv_heads * cfg.head_dim, bias=False) self.o_proj = nn.Linear(cfg.n_heads * cfg.head_dim, cfg.hidden_dim, bias=False) self._flash = None # lazily loaded HF kernel def _get_flash_kernel(self): if self._flash is None: from kernels import get_kernel # kernels >= 0.14 requires an explicit version. version=1 pins a # stable major API. flash_attn_func is the differentiable entry # point (raw .fwd has no backward, which would break training). mod = get_kernel("kernels-community/flash-attn2", version=1) self._flash = mod.flash_attn_func return self._flash def forward(self, x, cos, sin): B, S, _ = x.shape q = self.q_proj(x).view(B, S, self.n_heads, self.head_dim) k = self.k_proj(x).view(B, S, self.n_kv_heads, self.head_dim) v = self.v_proj(x).view(B, S, self.n_kv_heads, self.head_dim) if self.backend == "kernels": q, k = self._rope_bshd(q, k, cos, sin) q, k, v = (t.to(torch.bfloat16) for t in (q, k, v)) # <-- add this line flash_attn_func = self._get_flash_kernel() out = flash_attn_func(q, k, v, causal=True) out = out.reshape(B, S, -1) else: # SDPA path expects (B, H, S, D). q = q.transpose(1, 2) # (B, n_heads, S, D) k = k.transpose(1, 2) # (B, n_kv_heads, S, D) v = v.transpose(1, 2) q, k = apply_rope(q, k, cos, sin) out = self._sdpa(q, k, v) out = out.transpose(1, 2).reshape(B, S, -1) return self.o_proj(out) def _rope_bshd(self, q, k, cos, sin): # Apply RoPE while tensors are in (B, S, H, D) layout for the flash path. q = q.transpose(1, 2) k = k.transpose(1, 2) q, k = apply_rope(q, k, cos, sin) return q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous() def _sdpa(self, q, k, v): # Prefer native GQA support (PyTorch >= 2.5); fall back to repeat_kv. try: return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True) except TypeError: k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) return F.scaled_dot_product_attention(q, k, v, is_causal=True) # --------------------------------------------------------------------------- # # SwiGLU feed-forward # --------------------------------------------------------------------------- # class SwiGLU(nn.Module): def __init__(self, cfg: IvmeConfig): super().__init__() self.gate_proj = nn.Linear(cfg.hidden_dim, cfg.ffn_dim, bias=False) self.up_proj = nn.Linear(cfg.hidden_dim, cfg.ffn_dim, bias=False) self.down_proj = nn.Linear(cfg.ffn_dim, cfg.hidden_dim, bias=False) def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) # --------------------------------------------------------------------------- # # Transformer block (pre-norm) # --------------------------------------------------------------------------- # class Block(nn.Module): def __init__(self, cfg: IvmeConfig): super().__init__() self.attn_norm = RMSNorm(cfg.hidden_dim, cfg.norm_eps) self.attn = Attention(cfg) self.ffn_norm = RMSNorm(cfg.hidden_dim, cfg.norm_eps) self.ffn = SwiGLU(cfg) def forward(self, x, cos, sin): x = x + self.attn(self.attn_norm(x), cos, sin) x = x + self.ffn(self.ffn_norm(x)) return x # --------------------------------------------------------------------------- # # İvme-Conversate # --------------------------------------------------------------------------- # class IvmeConversate(nn.Module): def __init__(self, cfg: IvmeConfig): super().__init__() self.cfg = cfg self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_dim) self.blocks = nn.ModuleList(Block(cfg) for _ in range(cfg.n_layers)) self.norm = RMSNorm(cfg.hidden_dim, cfg.norm_eps) self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_size, bias=False) if cfg.tie_embeddings: self.lm_head.weight = self.embed.weight # RoPE cache is registered as a buffer-free attribute, rebuilt on device. self._cos = None self._sin = None self.apply(self._init_weights) # Scale residual projections by 1/sqrt(2*n_layers) — standard GPT-2 trick # that keeps activation variance stable through deep residual stacks. for name, p in self.named_parameters(): if name.endswith("o_proj.weight") or name.endswith("down_proj.weight"): nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * cfg.n_layers)) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) def _rope(self, seq_len, device, dtype): if self._cos is None or self._cos.size(0) < seq_len or self._cos.device != device: self._cos, self._sin = build_rope_cache( self.cfg.max_seq_len, self.cfg.head_dim, self.cfg.rope_theta, device, dtype ) return self._cos[:seq_len], self._sin[:seq_len] def forward(self, idx, targets=None): B, S = idx.shape x = self.embed(idx) cos, sin = self._rope(S, x.device, x.dtype) for block in self.blocks: x = block(x, cos, sin) x = self.norm(x) logits = self.lm_head(x) loss = None if targets is not None: loss = F.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100 ) return logits, loss # -- utilities ---------------------------------------------------------- # def num_params(self, non_embedding: bool = False) -> int: n = sum(p.numel() for p in self.parameters()) if non_embedding: n -= self.embed.weight.numel() if not self.cfg.tie_embeddings: n -= self.lm_head.weight.numel() return n @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=0.8, top_k=40): self.eval() for _ in range(max_new_tokens): idx_cond = idx[:, -self.cfg.max_seq_len:] logits, _ = self(idx_cond) logits = logits[:, -1, :] / max(temperature, 1e-5) if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf") probs = F.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) idx = torch.cat((idx, next_id), dim=1) return idx # --------------------------------------------------------------------------- # # Smoke test # --------------------------------------------------------------------------- # if __name__ == "__main__": cfg = IvmeConfig() model = IvmeConversate(cfg) total = model.num_params() non_emb = model.num_params(non_embedding=True) emb = cfg.vocab_size * cfg.hidden_dim print("=" * 52) print(" İvme-Conversate") print("=" * 52) print(f" vocab_size : {cfg.vocab_size:,}") print(f" hidden_dim : {cfg.hidden_dim}") print(f" n_layers : {cfg.n_layers}") print(f" n_heads / kv : {cfg.n_heads} / {cfg.n_kv_heads} (GQA)") print(f" ffn_dim : {cfg.ffn_dim} (SwiGLU)") print(f" max_seq_len : {cfg.max_seq_len}") print(f" tied embeddings : {cfg.tie_embeddings}") print("-" * 52) print(f" embedding params: {emb:,} ({100*emb/total:.1f}% of total)") print(f" transformer : {non_emb:,}") print(f" TOTAL PARAMS : {total:,} (~{total/1e6:.1f}M)") print("=" * 52) # Forward + backward sanity check. x = torch.randint(0, cfg.vocab_size, (2, 128)) y = torch.randint(0, cfg.vocab_size, (2, 128)) logits, loss = model(x, y) loss.backward() print(f" forward ok : logits {tuple(logits.shape)}") print(f" initial loss : {loss.item():.3f} (random baseline ≈ {math.log(cfg.vocab_size):.3f})") print(f" backward ok : grads populated") print("=" * 52)