ereniko's picture
Upload model.py with huggingface_hub
b792941 verified
"""
İvme-Conversate — a stupidly small decoder-only language model.
Philosophy: sub-100M params, trained from scratch on ultra-dense data,
built to punch above its weight on the Tiny-ML leaderboard.
Architecture (v1, English-only):
- Decoder-only transformer
- RoPE positional encoding
- Grouped-Query Attention (GQA)
- SwiGLU feed-forward
- RMSNorm (pre-norm)
- Tied input/output embeddings
- No biases
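- Attention via PyTorch SDPA (default) or flash-attn2 from the HuggingFace Kernel Hub (opt-in with `attn_backend="kernels"`)
Run `python model.py` to build the model, print the real parameter count, and run a forward/backward sanity check.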
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
# --------------------------------------------------------------------------- #
# Config
# --------------------------------------------------------------------------- #
@dataclass
class IvmeConfig:
    """Hyperparameters for İvme-Conversate.

    The defaults land at ~22.0M parameters (with tied embeddings). Setting
    ``ffn_dim=896`` or ``n_layers=9`` brings the total down to roughly 20.5M.
    The embedding table (``vocab_size * hidden_dim``, ~6.3M at the defaults)
    dominates at this scale, so the vocab size is the single biggest lever on
    total params.
    """

    vocab_size: int = 16_384  # English-only v1; BPE trained from scratch
    hidden_dim: int = 384
    n_layers: int = 10
    n_heads: int = 6  # head_dim = hidden_dim / n_heads = 64
    n_kv_heads: int = 2  # GQA: each KV head is shared across 3 Q heads
    ffn_dim: int = 1024  # SwiGLU intermediate size
    max_seq_len: int = 1024
    rope_theta: float = 10_000.0
    norm_eps: float = 1e-5
    tie_embeddings: bool = True
    # Attention backend: "kernels" (HF Kernel Hub flash-attn2) or "sdpa".
    attn_backend: str = "sdpa"

    @property
    def head_dim(self) -> int:
        """Per-head width, ``hidden_dim // n_heads``.

        Raises:
            ValueError: if ``n_heads`` is not positive or ``hidden_dim`` is not
                divisible by ``n_heads``.
        """
        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, after which floor division would silently yield a
        # head width that no longer multiplies back to hidden_dim.
        if self.n_heads <= 0 or self.hidden_dim % self.n_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        return self.hidden_dim // self.n_heads
# --------------------------------------------------------------------------- #
# Normalization
# --------------------------------------------------------------------------- #
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain (no bias).

    The statistics are computed in float32 regardless of the input dtype, and
    the result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        x32 = x.float()
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_sq + self.eps)
        scaled = normed * self.weight.float()
        return scaled.to(in_dtype)
# --------------------------------------------------------------------------- #
# Rotary positional embeddings
# --------------------------------------------------------------------------- #
def build_rope_cache(seq_len: int, head_dim: int, theta: float, device, dtype):
    """Precompute RoPE cosine/sine tables.

    Returns ``(cos, sin)``, each of shape (seq_len, head_dim) and cast to
    ``dtype``. Position ``p`` and frequency ``i`` give the angle
    ``p / theta**(2i / head_dim)``. The half-width frequency table is
    duplicated along the last axis to match the ``rotate_half`` layout.
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len, device=device).float()
    angles = positions[:, None] * inv_freq[None, :]  # (seq_len, head_dim/2)
    angles = angles.repeat(1, 2)  # (seq_len, head_dim): [freqs | freqs]
    return angles.cos().to(dtype), angles.sin().to(dtype)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last axis into halves ``[a, b]`` and return ``[-b, a]`` (RoPE helper)."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)
def apply_rope(q, k, cos, sin):
    """Rotate query and key tensors by the RoPE angles.

    ``cos``/``sin`` must be 2-D (positions, head_dim) tables. They are cast to
    ``q``'s dtype and broadcast over the two leading axes of ``q``/``k``.
    Returns the rotated ``(q, k)``.
    """
    cos_4d = cos.to(q.dtype)[None, None, :, :]
    sin_4d = sin.to(q.dtype)[None, None, :, :]

    def _rotate(t):
        # t * cos + rotate_half(t) * sin, with rotate_half([a, b]) == [-b, a].
        lo, hi = torch.chunk(t, 2, dim=-1)
        return (t * cos_4d) + (torch.cat((-hi, lo), dim=-1) * sin_4d)

    return _rotate(q), _rotate(k)
# --------------------------------------------------------------------------- #
# Attention (GQA)
# --------------------------------------------------------------------------- #
class Attention(nn.Module):
    """Causal grouped-query self-attention (GQA) with rotary position embeddings.

    ``n_heads`` query heads attend using ``n_kv_heads`` shared key/value heads,
    so each KV head serves ``n_heads // n_kv_heads`` query heads. The backend is
    chosen by ``cfg.attn_backend``:

    * ``"kernels"`` -- flash-attn2 fetched lazily from the HuggingFace Kernel
      Hub, run on (B, S, H, D) tensors in half precision.
    * any other value -- ``F.scaled_dot_product_attention`` on (B, H, S, D).

    Raises:
        ValueError: if ``n_heads`` is not a multiple of ``n_kv_heads``.
    """

    def __init__(self, cfg: IvmeConfig):
        super().__init__()
        # Fail at construction instead of deep inside the attention call.
        if cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({cfg.n_heads}) must be a multiple of "
                f"n_kv_heads ({cfg.n_kv_heads}) for grouped-query attention"
            )
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.head_dim
        self.n_rep = cfg.n_heads // cfg.n_kv_heads
        self.backend = cfg.attn_backend
        self.q_proj = nn.Linear(cfg.hidden_dim, cfg.n_heads * cfg.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_dim, cfg.n_kv_heads * cfg.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_dim, cfg.n_kv_heads * cfg.head_dim, bias=False)
        self.o_proj = nn.Linear(cfg.n_heads * cfg.head_dim, cfg.hidden_dim, bias=False)
        self._flash = None  # flash_attn_func, loaded on first use of the kernels backend

    def _get_flash_kernel(self):
        """Return flash-attn2's ``flash_attn_func``, fetching it on the first call."""
        if self._flash is None:
            # Imported lazily so the optional `kernels` package is only needed
            # when the kernels backend is actually used.
            from kernels import get_kernel
            # kernels >= 0.14 requires an explicit version. version=1 pins a
            # stable major API. flash_attn_func is the differentiable entry
            # point (raw .fwd has no backward, which would break training).
            mod = get_kernel("kernels-community/flash-attn2", version=1)
            self._flash = mod.flash_attn_func
        return self._flash

    def forward(self, x, cos, sin):
        """Attend over ``x`` of shape (B, S, hidden_dim) and return the same shape.

        ``cos``/``sin`` are RoPE tables expected to be (S, head_dim), as sliced
        from ``build_rope_cache``.
        """
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(B, S, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, S, self.n_kv_heads, self.head_dim)
        if self.backend == "kernels":
            q, k = self._rope_bshd(q, k, cos, sin)
            # flash-attn2 only accepts fp16/bf16. Cast full-precision inputs to
            # bf16 for the kernel, and leave half-precision inputs as they are.
            if q.dtype not in (torch.float16, torch.bfloat16):
                q, k, v = (t.to(torch.bfloat16) for t in (q, k, v))
            flash_attn_func = self._get_flash_kernel()
            out = flash_attn_func(q, k, v, causal=True)
            # Restore the activation dtype. Otherwise an fp32 model would feed
            # bf16 into o_proj and fail with a dtype mismatch.
            out = out.to(x.dtype).reshape(B, S, -1)
        else:
            # SDPA path expects (B, H, S, D).
            q = q.transpose(1, 2)  # (B, n_heads, S, D)
            k = k.transpose(1, 2)  # (B, n_kv_heads, S, D)
            v = v.transpose(1, 2)
            q, k = apply_rope(q, k, cos, sin)
            out = self._sdpa(q, k, v)
            out = out.transpose(1, 2).reshape(B, S, -1)
        return self.o_proj(out)

    def _rope_bshd(self, q, k, cos, sin):
        """Apply RoPE to (B, S, H, D) tensors and return them contiguous in that layout."""
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        q, k = apply_rope(q, k, cos, sin)
        return q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous()

    def _sdpa(self, q, k, v):
        """Causal SDPA on (B, H, S, D) tensors, with K/V having ``n_kv_heads`` heads."""
        # Prefer native GQA support (PyTorch >= 2.5). Older versions reject the
        # `enable_gqa` keyword with TypeError, so expand K/V to n_heads instead.
        try:
            return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
        except TypeError:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)
            return F.scaled_dot_product_attention(q, k, v, is_causal=True)
# --------------------------------------------------------------------------- #
# SwiGLU feed-forward
# --------------------------------------------------------------------------- #
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``, all projections bias-free."""

    def __init__(self, cfg: IvmeConfig):
        super().__init__()
        d_model, d_ff = cfg.hidden_dim, cfg.ffn_dim
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        hidden = gated * self.up_proj(x)
        return self.down_proj(hidden)
# --------------------------------------------------------------------------- #
# Transformer block (pre-norm)
# --------------------------------------------------------------------------- #
class Block(nn.Module):
    """Pre-norm transformer layer: ``x + attn(norm(x))``, then ``h + ffn(norm(h))``."""

    def __init__(self, cfg: IvmeConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.hidden_dim, cfg.norm_eps)
        self.attn = Attention(cfg)
        self.ffn_norm = RMSNorm(cfg.hidden_dim, cfg.norm_eps)
        self.ffn = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        attn_out = self.attn(self.attn_norm(x), cos, sin)
        h = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(h))
        return h + ffn_out
# --------------------------------------------------------------------------- #
# İvme-Conversate
# --------------------------------------------------------------------------- #
class IvmeConversate(nn.Module):
    """Decoder-only LM: embedding -> ``n_layers`` pre-norm blocks -> RMSNorm -> LM head.

    ``forward(idx, targets=None)`` returns ``(logits, loss)``, where ``loss`` is
    ``None`` unless ``targets`` is given.
    """

    def __init__(self, cfg: IvmeConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_dim)
        self.blocks = nn.ModuleList(Block(cfg) for _ in range(cfg.n_layers))
        self.norm = RMSNorm(cfg.hidden_dim, cfg.norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_dim, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            # Input embedding and output projection share one weight matrix.
            self.lm_head.weight = self.embed.weight
        # RoPE tables are plain attributes, not buffers. They are not part of the
        # state dict and are not moved by .to(); _rope rebuilds them on demand.
        self._cos = None
        self._sin = None
        self.apply(self._init_weights)
        # Scale residual projections by 1/sqrt(2*n_layers) — standard GPT-2 trick
        # that keeps activation variance stable through deep residual stacks.
        for name, p in self.named_parameters():
            if name.endswith(("o_proj.weight", "down_proj.weight")):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * cfg.n_layers))

    def _init_weights(self, module):
        """Initialize Linear and Embedding weights from N(0, 0.02)."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _rope(self, seq_len, device, dtype):
        """Return ``(cos, sin)`` RoPE tables for the first ``seq_len`` positions.

        The cache is rebuilt when it is missing, too short, or on a different
        device or dtype. It covers at least ``max_seq_len`` positions, and more
        when a longer sequence is requested. Previously that case rebuilt a
        too-short table on every call and then failed in ``apply_rope``.
        """
        stale = (
            self._cos is None
            or self._cos.size(0) < seq_len
            or self._cos.device != device
            or self._cos.dtype != dtype
        )
        if stale:
            self._cos, self._sin = build_rope_cache(
                max(seq_len, self.cfg.max_seq_len),
                self.cfg.head_dim,
                self.cfg.rope_theta,
                device,
                dtype,
            )
        return self._cos[:seq_len], self._sin[:seq_len]

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, S) integer token ids.
            targets: optional target ids with B*S elements. Entries equal to
                -100 are ignored by the loss.

        Returns:
            ``(logits, loss)``: logits is (B, S, vocab_size). loss is the mean
            cross-entropy, or ``None`` when ``targets`` is omitted.
        """
        _, S = idx.shape
        x = self.embed(idx)
        cos, sin = self._rope(S, x.device, x.dtype)
        for block in self.blocks:
            x = block(x, cos, sin)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets such as ids[:, 1:] work.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100
            )
        return logits, loss

    # -- utilities ---------------------------------------------------------- #
    def num_params(self, non_embedding: bool = False) -> int:
        """Count parameters (a tied embedding/head weight is counted once).

        With ``non_embedding=True`` the token embedding is excluded. The untied
        LM head is excluded too.
        """
        n = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n -= self.embed.weight.numel()
            if not self.cfg.tie_embeddings:
                n -= self.lm_head.weight.numel()
        return n

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=40):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: (B, S) prompt token ids.
            max_new_tokens: number of tokens to sample.
            temperature: logits are divided by this value (floored at 1e-5).
            top_k: sample only from the ``top_k`` highest logits. ``None`` or a
                value <= 0 disables the filter.

        Returns:
            A (B, S + max_new_tokens) tensor of token ids.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent max_seq_len tokens fit in the context window.
                idx_cond = idx[:, -self.cfg.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")
                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_id), dim=1)
        finally:
            # Restore the caller's train/eval mode instead of leaving eval on.
            self.train(was_training)
        return idx
# --------------------------------------------------------------------------- #
# Smoke test
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    # Smoke test: build the default model, report parameter counts, then run
    # one forward/backward pass on random tokens.
    cfg = IvmeConfig()
    model = IvmeConversate(cfg)
    total = model.num_params()
    non_emb = model.num_params(non_embedding=True)
    emb = cfg.vocab_size * cfg.hidden_dim
    print("=" * 52)
    print(" İvme-Conversate")
    print("=" * 52)
    print(f" vocab_size : {cfg.vocab_size:,}")
    print(f" hidden_dim : {cfg.hidden_dim}")
    print(f" n_layers : {cfg.n_layers}")
    print(f" n_heads / kv : {cfg.n_heads} / {cfg.n_kv_heads} (GQA)")
    print(f" ffn_dim : {cfg.ffn_dim} (SwiGLU)")
    print(f" max_seq_len : {cfg.max_seq_len}")
    print(f" tied embeddings : {cfg.tie_embeddings}")
    print("-" * 52)
    print(f" embedding params: {emb:,} ({100*emb/total:.1f}% of total)")
    print(f" transformer : {non_emb:,}")
    print(f" TOTAL PARAMS : {total:,} (~{total/1e6:.1f}M)")
    print("=" * 52)
    # Forward + backward sanity check.
    x = torch.randint(0, cfg.vocab_size, (2, 128))
    y = torch.randint(0, cfg.vocab_size, (2, 128))
    logits, loss = model(x, y)
    loss.backward()
    print(f" forward ok : logits {tuple(logits.shape)}")
    print(f" initial loss : {loss.item():.3f} (random baseline ≈ {math.log(cfg.vocab_size):.3f})")
    # Actually verify the claim printed below instead of assuming it.
    missing = [name for name, p in model.named_parameters() if p.grad is None]
    if missing:
        raise RuntimeError(f"parameters without gradients after backward: {missing}")
    print(" backward ok : grads populated")
    print("=" * 52)