#!/usr/bin/env python3 """Public-facing TMLM-Haiku interactive CLI. Pulls models from the CompactAI-O HuggingFace collection: https://huggingface.co/collections/CompactAI-O/tmlm-haiku-series """ from __future__ import annotations #!/usr/bin/env python3 from __future__ import annotations import hashlib import json import math import os import string import sys from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterator, List, Optional, Sequence, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint HUGGINGFACE_MODELS = { "TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1", "TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3", "TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2", "TMLM-Haiku-2.3": "CompactAI-O/TMLM-Haiku-2.3", "Glint-1": "CompactAI-O/Glint-1", } # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- @dataclass class ModelConfig: dim: int = 128 n_unique_layers: int = 8 n_logical_layers: int = 16 n_heads: int = 4 n_kv_heads: int = 2 ffn_dim: int = 224 dropout: float = 0.0 seq_len: int = 2048 sliding_window_size: int = 512 mtp_horizons: Tuple[int, ...] = (2, 3, 4) rope_fraction: float = 0.5 embed_scale: bool = True logit_soft_cap: float = -1.0 quantization: str = "nvfp4" @property def head_dim(self) -> int: return self.dim // self.n_heads model_config = ModelConfig() MODEL_SERIES = { "haiku": { "dim": 64, "n_unique_layers": 12, "n_logical_layers": 24, "n_heads": 4, "n_kv_heads": 2, "ffn_dim": 384, "dropout": 0.0, "seq_len": 2048, "sliding_window_size": 2048, "mtp_horizons": (), "rope_fraction": 0.5, "engram_dim": 8, "engram_heads": 2, "engram_table_size": 64, "engram_max_ngram": 2, "mhc_expansion": 2, "sleep_gate_cap": 0, "sleep_gate_heads": 4, "latent_think_layers": 0, "prelude_layers": 0, "coda_layers": 0, "recurrent_loops": 0, "recurrent_act_threshold": 0.9, "recurrent_lora_rank": 0, "recurrent_loop_embed_dim": 0, }, "sonnet": { "dim": 1024, "n_unique_layers": 20, "n_logical_layers": 40, "n_heads": 16, "n_kv_heads": 4, "ffn_dim": 4096, "dropout": 0.0, "seq_len": 2048, "mtp_horizons": (2,), "engram_dim": 32, "engram_heads": 8, "engram_table_size": 4096, "engram_max_ngram": 2, "mhc_expansion": 2, "sleep_gate_cap": 0, "sleep_gate_heads": 8, "latent_think_layers": 0, "prelude_layers": 0, "coda_layers": 0, "recurrent_loops": 0, "recurrent_act_threshold": 0.99, "recurrent_lora_rank": 0, "recurrent_loop_embed_dim": 0, }, "opus": { "dim": 1536, "n_unique_layers": 18, "n_logical_layers": 36, "n_heads": 16, "n_kv_heads": 4, "ffn_dim": 5888, "dropout": 0.0, "seq_len": 2048, "mtp_horizons": (2,), "engram_dim": 64, "engram_heads": 8, "engram_table_size": 8192, "engram_max_ngram": 2, "mhc_expansion": 4, "sleep_gate_cap": 0, "sleep_gate_heads": 8, "latent_think_layers": 0, "prelude_layers": 0, "coda_layers": 0, "recurrent_loops": 0, "recurrent_act_threshold": 0.99, "recurrent_lora_rank": 0, "recurrent_loop_embed_dim": 0, }, } # --------------------------------------------------------------------------- # Tokenizer # --------------------------------------------------------------------------- FORMAT_TOKENS = [ "<|user|>", "<|assistant|>", "<|system|>", "<|start_header_id|>", "<|end_header_id|>", "<|begin_of_thought|>", "<|end_of_thought|>", "<|begin_of_solution|>", "<|end_of_solution|>", ] class WordTokenizer: def __init__( self, extra_chars: str = "", format_tokens: Optional[List[str]] = None ) -> None: base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r" fallback_chars = sorted(set(base + extra_chars)) self.core_special = ["", "", "", ""] self.format_tokens = ( list(format_tokens) if format_tokens else list(FORMAT_TOKENS) ) self.special = list(self.core_special) + list(self.format_tokens) self.id_to_token: List[str] = ( list(self.core_special) + self.format_tokens + fallback_chars ) self.token_to_id: Dict[str, int] = { t: i for i, t in enumerate(self.id_to_token) } self.special_multi_tokens = sorted( [t for t in self.special if len(t) > 1], key=len, reverse=True ) self.multi_char_tokens = self.special_multi_tokens self.dynamic_additions = 0 @property def pad_id(self) -> int: return self.token_to_id[""] @property def bos_id(self) -> int: return self.token_to_id[""] @property def eos_id(self) -> int: return self.token_to_id[""] @property def unk_id(self) -> int: return self.token_to_id[""] @property def vocab_size(self) -> int: return len(self.id_to_token) def maybe_add_char(self, ch: str) -> bool: if ch in self.token_to_id: return False self.token_to_id[ch] = len(self.id_to_token) self.id_to_token.append(ch) self.dynamic_additions += 1 return True def iter_lexical_tokens(self, text: str) -> Iterator[str]: i = 0 n = len(text) while i < n: matched_special = False for token in self.special_multi_tokens: if text.startswith(token, i): yield token i += len(token) matched_special = True break if matched_special: continue yield text[i] i += 1 def encode( self, text: str, add_bos: bool = False, add_eos: bool = False ) -> List[int]: out: List[int] = [] if add_bos: out.append(self.bos_id) unk = self.unk_id t2i = self.token_to_id for tok in self.iter_lexical_tokens(text): out.append(t2i.get(tok, unk)) if add_eos: out.append(self.eos_id) return out def decode(self, ids: Sequence[int], skip_special: bool = True) -> str: pieces: List[str] = [] for idx in ids: if int(idx) < 0 or int(idx) >= len(self.id_to_token): continue tok = self.id_to_token[int(idx)] if skip_special and tok in self.special: continue pieces.append(tok) return "".join(pieces) @classmethod def load(cls, path: Path) -> WordTokenizer: with path.open("r", encoding="utf-8") as f: data = json.load(f) format_tokens = data.get("format_tokens", FORMAT_TOKENS) tokenizer = cls(extra_chars="", format_tokens=format_tokens) tokenizer.id_to_token = data["id_to_token"] tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)} tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens) tokenizer.special_multi_tokens = sorted( [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True ) tokenizer.multi_char_tokens = tokenizer.special_multi_tokens return tokenizer LetterTokenizer = WordTokenizer # --------------------------------------------------------------------------- # Model # --------------------------------------------------------------------------- class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(torch.nn.functional, "rms_norm"): return torch.nn.functional.rms_norm( x, self.weight.shape, self.weight, self.eps ) x_fp = x.float() rms = torch.rsqrt(x_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps) return (x_fp * rms).to(dtype=x.dtype) * self.weight class RotaryEmbedding(nn.Module): def __init__(self, dim: int, base: float = 10000.0) -> None: super().__init__() inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv, persistent=False) def cos_sin( self, seq_len: int, device: torch.device, dtype: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor]: t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq) emb = torch.cat([freqs, freqs], dim=-1) cos = emb.cos()[None, None, :, :].to(dtype=dtype) sin = emb.sin()[None, None, :, :].to(dtype=dtype) return cos, sin def _rotate_half(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) class CausalSelfAttention(nn.Module): def __init__( self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, dropout: float, sliding_window: int, rope_fraction: float, ) -> None: super().__init__() self.dim = dim self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.head_dim = head_dim self.n_rep = n_heads // n_kv_heads self.dropout = dropout self.sliding_window = sliding_window self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2) self.rope = RotaryEmbedding(self.rope_dim) self.q_norm = RMSNorm(head_dim) self.k_norm = RMSNorm(head_dim) self.output_gate = nn.Parameter(torch.ones(n_heads)) def forward( self, x: torch.Tensor, is_global: bool, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: B, T, _ = x.shape q = self.wq(x).view(B, T, self.n_heads, self.head_dim) k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) q = self.q_norm(q) k = self.k_norm(k) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) past_len = past_kv[0].shape[2] if past_kv is not None else 0 cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype) cos_slice = cos[:, :, past_len : past_len + T, :] sin_slice = sin[:, :, past_len : past_len + T, :] q_rope = q[..., : self.rope_dim] q_pass = q[..., self.rope_dim :] k_rope = k[..., : self.rope_dim] k_pass = k[..., self.rope_dim :] q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice) k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice) q = torch.cat([q_rope, q_pass], dim=-1) k = torch.cat([k_rope, k_pass], dim=-1) if past_kv is not None: k = torch.cat([past_kv[0], k], dim=2) v = torch.cat([past_kv[1], v], dim=2) new_kv = (k, v) if use_cache else None S = k.shape[2] if self.n_rep > 1: k = ( k[:, :, None, :, :] .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) .reshape(B, self.n_heads, S, self.head_dim) ) v = ( v[:, :, None, :, :] .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) .reshape(B, self.n_heads, S, self.head_dim) ) drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0 if is_global: if past_kv is None and T > 1: out = F.scaled_dot_product_attention( q, k, v, is_causal=True, dropout_p=drop_p ) else: out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p) else: T_q = q.shape[2] q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1) k_pos = torch.arange(S, device=q.device).unsqueeze(0) mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window) out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p ) gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1) out = out * gate out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) out = self.wo(out) return out, new_kv class SwiGLU(nn.Module): def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None: super().__init__() self.gate = nn.Linear(dim, hidden_dim, bias=False) self.up = nn.Linear(dim, hidden_dim, bias=False) self.down = nn.Linear(hidden_dim, dim, bias=False) self.drop = nn.Dropout(dropout) nn.init.normal_(self.gate.weight, std=dim**-0.5) nn.init.normal_(self.up.weight, std=dim**-0.5) nn.init.normal_(self.down.weight, std=hidden_dim**-0.5) def forward(self, x: torch.Tensor) -> torch.Tensor: h = F.silu(self.gate(x)) * self.up(x) out = self.down(h) if self.training and torch.is_grad_enabled(): out = self.drop(out) return out def loop_index_embedding(h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0) -> torch.Tensor: if loop_dim <= 0: return h loop_dim = min(loop_dim, h.shape[-1]) if loop_dim % 2 == 1: loop_dim -= 1 if loop_dim <= 0: return h inv_freq = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim)) phase = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype) * inv_freq loop_embed = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, loop_dim) out = h.clone() out[..., :loop_dim] = out[..., :loop_dim] + loop_embed return out class DepthLoRAAdapter(nn.Module): def __init__(self, dim: int, rank: int, max_loops: int) -> None: super().__init__() self.rank = max(0, rank) if self.rank <= 0: self.down = None self.B = None self.scale = None return self.down = nn.Linear(dim, self.rank, bias=False) self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02) self.scale = nn.Embedding(max(1, max_loops), self.rank) nn.init.zeros_(self.scale.weight) def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor: if self.rank <= 0 or self.down is None or self.B is None or self.scale is None: return torch.zeros_like(x) t_idx = min(loop_t, self.scale.num_embeddings - 1) scale = self.scale(torch.tensor(t_idx, device=x.device)) return (self.down(x) * scale) @ self.B class StableRecurrentInjection(nn.Module): def __init__(self, dim: int) -> None: super().__init__() self.log_A = nn.Parameter(torch.full((dim,), -2.0)) self.log_dt = nn.Parameter(torch.full((dim,), -2.0)) self.input_gate = nn.Parameter(torch.zeros(dim)) def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor: A = torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20))).view(1, 1, -1) B = torch.sigmoid(self.input_gate).view(1, 1, -1) return A * h + B * e + transformer_out class AdaptiveHalting(nn.Module): def __init__(self, dim: int) -> None: super().__init__() self.halt = nn.Linear(dim, 1, bias=True) nn.init.zeros_(self.halt.weight) nn.init.constant_(self.halt.bias, -2.0) def forward(self, h: torch.Tensor) -> torch.Tensor: return torch.sigmoid(self.halt(h)).squeeze(-1) class EngramBlock(nn.Module): """DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup. Stores common token-pair/triplet patterns in an embedding table and retrieves them with multi-head hashing. A context-aware gate (using the current hidden state as query) decides how much of the retrieved memory to inject into the residual stream. Reference: DeepSeek-AI, "Conditional Memory via Scalable Lookup" (2025). """ def __init__( self, dim: int, engram_dim: int, n_heads: int = 4, table_size: int = 8192, max_ngram: int = 3, ) -> None: super().__init__() self.dim = dim self.engram_dim = engram_dim self.n_heads = n_heads self.table_size = table_size self.max_ngram = max_ngram # One embedding table per (ngram_order, hash_head) self.embeddings = nn.ParameterDict() for n in range(2, max_ngram + 1): for k in range(n_heads): self.embeddings[f"{n}_{k}"] = nn.Parameter( torch.randn(table_size, engram_dim) * (engram_dim**-0.5) ) # Fixed hash parameters (non-learnable, deterministic) for n in range(2, max_ngram + 1): for k in range(n_heads): seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16) rng = torch.Generator().manual_seed(seed) a = torch.randint(1, 2**31, (1,), generator=rng).item() b = torch.randint(0, 2**31, (1,), generator=rng).item() self.register_buffer( f"hash_a_{n}_{k}", torch.tensor(a), persistent=False ) self.register_buffer( f"hash_b_{n}_{k}", torch.tensor(b), persistent=False ) # Causal convolution over N-gram branch outputs (kernel=4, dilation=max_ngram) total_branch_dim = engram_dim * n_heads * (max_ngram - 1) self.branch_conv = nn.Conv1d( total_branch_dim, total_branch_dim, kernel_size=4, dilation=max_ngram, padding=0, groups=total_branch_dim, bias=True, ) nn.init.zeros_(self.branch_conv.weight) nn.init.zeros_(self.branch_conv.bias) # Context-aware gating: hidden state as query, memory as key/value self.gate_query = nn.Linear(dim, engram_dim, bias=False) self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False) self.gate_value = nn.Linear(total_branch_dim, dim, bias=False) self.gate_scale = engram_dim**-0.5 def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor: """Hash n-gram token sequences into table indices. Args: token_ids: (B, T) token IDs n: n-gram order (2 = bigram, 3 = trigram) k: hash head index Returns: indices: (B, T) integer indices into embedding table """ a = getattr(self, f"hash_a_{n}_{k}") b = getattr(self, f"hash_b_{n}_{k}") B, T = token_ids.shape # Pad left with zeros so every position has a valid n-gram padded = F.pad(token_ids, (n - 1, 0), value=0) # (B, T+n-1) # Polynomial rolling hash combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device) for i in range(n): combined = combined * 31 + padded[:, i : i + T].long() indices = ((a * combined) ^ b) % self.table_size return indices def forward( self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None ) -> torch.Tensor: """Forward pass. Args: hidden: (B, T, dim) current hidden state token_ids: (B, T) input token IDs for n-gram hashing. If None, uses argmax of hidden projections as proxy. Returns: output: (B, T, dim) memory injection for residual stream """ B, T, _ = hidden.shape if token_ids is None: # Fallback: derive pseudo-token-ids from hidden state token_ids = hidden.mean(dim=-1).long() % self.table_size # Retrieve and concatenate across n-gram orders and hash heads branch_outputs = [] for n in range(2, self.max_ngram + 1): for k in range(self.n_heads): indices = self._hash_ngram(token_ids, n, k) # (B, T) table = self.embeddings[f"{n}_{k}"] # (table_size, engram_dim) retrieved = table[indices] # (B, T, engram_dim) branch_outputs.append(retrieved) # (B, T, engram_dim * n_heads * (max_ngram - 1)) memory = torch.cat(branch_outputs, dim=-1) # Causal convolution over sequence dimension # Pad left for causality (kernel_size - 1 = 3) conv_in = memory.transpose(1, 2) # (B, C, T) conv_in = F.pad( conv_in, ((self.branch_conv.kernel_size[0] - 1) * self.branch_conv.dilation[0], 0), ) conv_out = self.branch_conv(conv_in) # (B, C, T) memory = conv_out.transpose(1, 2) # (B, T, C) # Context-aware gating query = self.gate_query(hidden) # (B, T, engram_dim) key = self.gate_key(memory) # (B, T, engram_dim) gate = torch.sigmoid( (query * key).sum(dim=-1, keepdim=True) * self.gate_scale ) # (B, T, 1) value = self.gate_value(memory) # (B, T, dim) return gate * value class SleepGate(nn.Module): """Persistent memory + periodic consolidation gate.""" def __init__( self, dim: int, cap: int = 128, n_heads: int = 4, retention_enabled: bool = True, retention_hidden: int = 0, ) -> None: super().__init__() self.dim = dim self.cap = cap self.n_heads = n_heads self.head_dim = dim // n_heads self.scale = self.head_dim ** -0.5 self.retention_enabled = retention_enabled self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16)) self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long)) self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32)) self.register_buffer("mem_count", torch.zeros((), dtype=torch.long)) self.register_buffer("mem_head", torch.zeros((), dtype=torch.long)) self.register_buffer("global_step", torch.zeros((), dtype=torch.long)) self.q_proj = nn.Linear(dim, dim, bias=False) self.k_proj = nn.Linear(dim, dim, bias=False) self.v_proj = nn.Linear(dim, dim, bias=False) self.o_proj = nn.Linear(dim, dim, bias=False) nn.init.zeros_(self.o_proj.weight) self.gate_scale = nn.Parameter(torch.zeros(())) if retention_enabled: if retention_hidden > 0: self.retention_gate: Optional[nn.Module] = nn.Sequential( nn.Linear(dim, retention_hidden, bias=False), nn.GELU(), nn.Linear(retention_hidden, 1, bias=True), ) nn.init.constant_(self.retention_gate[-1].bias, 2.2) else: self.retention_gate = nn.Linear(dim, 1, bias=True) nn.init.constant_(self.retention_gate.bias, 2.2) else: self.retention_gate = None self._last_beta: Optional[torch.Tensor] = None def write(self, hidden: torch.Tensor) -> None: B, T, _ = hidden.shape tail_full = hidden[:, max(0, T - 16):, :].float().mean(dim=1) if self.retention_gate is not None: beta_live = torch.sigmoid(self.retention_gate(tail_full).squeeze(-1)) self._last_beta = beta_live if self.training else None beta_store = beta_live.detach().float() else: self._last_beta = None beta_store = torch.ones(B, device=hidden.device, dtype=torch.float32) tail = tail_full.to(self.mem_emb.dtype).detach() with torch.no_grad(): head = int(self.mem_head.item()) count = int(self.mem_count.item()) step = int(self.global_step.item()) for b in range(B): self.mem_emb[head] = tail[b] self.mem_age[head] = step self.mem_beta[head] = beta_store[b] head = (head + 1) % self.cap if count < self.cap: count += 1 self.mem_head.fill_(head) self.mem_count.fill_(count) def read(self, x: torch.Tensor) -> torch.Tensor: count = int(self.mem_count.item()) if count == 0: return torch.zeros_like(x) B, T, D = x.shape mem = self.mem_emb[:count].clone().to(x.dtype) q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = self.k_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) v = self.v_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) attn = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale attn = F.softmax(attn, dim=-1) if self.retention_enabled: step = int(self.global_step.item()) ages = self.mem_age[:count].to(x.device) delta = (step - ages).clamp(min=0).to(x.dtype) betas = self.mem_beta[:count].to(x.dtype).clamp(min=1e-6, max=1.0) weights = betas.pow(delta) attn = attn * weights.view(1, 1, 1, count) attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9) out = torch.einsum("bhtm,hmd->bhtd", attn, v) out = out.transpose(1, 2).contiguous().view(B, T, D) out = self.o_proj(out) return torch.sigmoid(self.gate_scale) * out @torch.no_grad() def reset(self) -> None: self.mem_emb.zero_() self.mem_age.zero_() self.mem_beta.fill_(1.0) self.mem_count.zero_() self.mem_head.zero_() self.global_step.zero_() self._last_beta = None def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor: M = torch.exp(logits.clamp(-10, 10)) for _ in range(n_iters): M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10) M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10) return M class ManifoldHyperConnection(nn.Module): def __init__(self, dim: int, expansion: int = 2) -> None: super().__init__() self.dim = dim self.expansion = expansion n = expansion self.expand_fn = "duplicate" self.collapse_fn = "mean" self.bias_pre = nn.Parameter(torch.zeros(1, n)) self.bias_post = nn.Parameter(torch.zeros(1, n)) self.bias_res = nn.Parameter(torch.zeros(n, n)) self.theta_pre = nn.Linear(n * dim, n, bias=False) self.theta_post = nn.Linear(n * dim, n, bias=False) self.theta_res = nn.Linear(n * dim, n * n, bias=False) self.alpha_pre = nn.Parameter(torch.tensor(0.0)) self.alpha_post = nn.Parameter(torch.tensor(0.0)) self.alpha_res = nn.Parameter(torch.tensor(0.0)) def _compute_mappings( self, x_expanded: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, T, _ = x_expanded.shape n = self.expansion x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]]) d_pre = torch.tanh(self.theta_pre(x_norm)) d_post = torch.tanh(self.theta_post(x_norm)) d_res = self.theta_res(x_norm) H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre) H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post) H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape( B, T, n, n ) H_res = _sinkhorn_knopp(H_res_raw) return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res def expand_stream(self, x: torch.Tensor) -> torch.Tensor: return x.repeat(1, 1, self.expansion) def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor: B, T, _ = x_expanded.shape n = self.expansion C = self.dim return x_expanded.view(B, T, n, C).mean(dim=-2) def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor: B, T, _ = x_expanded.shape n = self.expansion x_streams = x_expanded.view(B, T, n, self.dim) return (H_pre @ x_streams).squeeze(-2) def post_res_mix( self, layer_output: torch.Tensor, x_expanded: torch.Tensor, H_post: torch.Tensor, H_res: torch.Tensor, ) -> torch.Tensor: B, T, _ = x_expanded.shape n = self.expansion C = self.dim x_streams = x_expanded.view(B, T, n, C) mixed = torch.matmul(H_res, x_streams) post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2)) result = mixed + post_out return result.reshape(B, T, n * C) class TransformerBlock(nn.Module): def __init__( self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, ffn_dim: int, dropout: float, sliding_window: int, rope_fraction: float, engram_dim: int = 0, engram_heads: int = 4, engram_table_size: int = 8192, engram_max_ngram: int = 3, mhc_expansion: int = 1, ) -> None: super().__init__() self.norm1 = RMSNorm(dim) self.attn = CausalSelfAttention( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, ) self.norm2 = RMSNorm(dim) self.ffn = SwiGLU(dim, ffn_dim, dropout) self.use_engram = engram_dim > 0 if self.use_engram: self.engram = EngramBlock( dim=dim, engram_dim=engram_dim, n_heads=engram_heads, table_size=engram_table_size, max_ngram=engram_max_ngram, ) self.engram_norm = RMSNorm(dim) self.use_mhc = mhc_expansion > 1 if self.use_mhc: self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion) self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion) def forward( self, x: torch.Tensor, is_global: bool, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, token_ids: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: if self.use_mhc: x_exp = self.mhc_attn.expand_stream(x) H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp) attn_in = self.mhc_attn.pre_mix(x_exp, H_pre) attn_out, new_kv = self.attn( self.norm1(attn_in), is_global, past_kv, use_cache ) x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res) if self.use_engram: collapsed = self.mhc_attn.collapse_stream(x_exp) collapsed = collapsed + self.engram( self.engram_norm(collapsed), token_ids=token_ids ) x_exp = self.mhc_attn.expand_stream(collapsed) H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp) ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2) ffn_out = self.ffn(self.norm2(ffn_in)) x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2) x = self.mhc_attn.collapse_stream(x_exp) else: attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache) x = x + attn_out if self.use_engram: x = x + self.engram(self.engram_norm(x), token_ids=token_ids) x = x + self.ffn(self.norm2(x)) return x, new_kv class RecurrentDepthBlock(nn.Module): def __init__( self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, ffn_dim: int, dropout: float, sliding_window: int, rope_fraction: float, n_loops: int, act_threshold: float, lora_rank: int, loop_embed_dim: int, ) -> None: super().__init__() self.n_loops = max(1, n_loops) self.act_threshold = act_threshold self.loop_embed_dim = max(0, loop_embed_dim) self.norm = RMSNorm(dim) self.block = TransformerBlock( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, ) self.injection = StableRecurrentInjection(dim) self.act = AdaptiveHalting(dim) self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops) def forward( self, h: torch.Tensor, e: torch.Tensor, token_ids: Optional[torch.Tensor] = None, past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, use_cache: bool = False, n_loops: Optional[int] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: loops = max(1, n_loops or self.n_loops) B, T, _ = h.shape halted = torch.zeros(B, T, device=h.device, dtype=torch.bool) cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype) output = torch.zeros_like(h) new_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None current = h final_halt = None for t in range(loops): h_loop = loop_index_embedding(current, t, self.loop_embed_dim) combined = self.norm(h_loop + e) past_kv = None if past_key_values is not None and t < len(past_key_values): past_kv = past_key_values[t] trans_out, layer_kv = self.block(combined, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=token_ids) trans_out = trans_out + self.lora(trans_out, t) next_h = self.injection(current, e, trans_out) p = self.act(next_h) p = p * (~halted).to(p.dtype) final_halt = p should_halt = (~halted) & ((cumulative_p + p) >= self.act_threshold) update_weight = torch.where(should_halt, (1.0 - cumulative_p).clamp(min=0.0), p) output = output + next_h * update_weight.unsqueeze(-1) cumulative_p = cumulative_p + update_weight current = torch.where(halted.unsqueeze(-1), current, next_h) halted = halted | should_halt if new_past is not None: new_past.append(layer_kv) if not use_cache and bool(halted.all()): break remainder = (1.0 - cumulative_p).clamp(min=0.0) output = output + current * remainder.unsqueeze(-1) aux: Dict[str, torch.Tensor] = {} if final_halt is not None: aux["recurrent_halt_mean"] = final_halt.mean() return output, aux, new_past class TinyMemoryLM(nn.Module): def __init__( self, vocab_size: int, dim: int, n_unique_layers: int, n_logical_layers: int, n_heads: int, n_kv_heads: int, ffn_dim: int, dropout: float, mtp_horizons: Sequence[int], grad_checkpoint: bool, sliding_window: int = 512, rope_fraction: float = 0.5, embed_scale: bool = True, engram_dim: int = 0, engram_heads: int = 4, engram_table_size: int = 8192, engram_max_ngram: int = 3, mhc_expansion: int = 1, sleep_gate_cap: int = 0, sleep_gate_heads: int = 4, sleep_retention_enabled: bool = True, sleep_retention_hidden: int = 0, latent_think_layers: int = 0, prelude_layers: int = 0, coda_layers: int = 0, recurrent_loops: int = 0, recurrent_act_threshold: float = 0.99, recurrent_lora_rank: int = 0, recurrent_loop_embed_dim: int = 0, ) -> None: super().__init__() self.dim = dim self.n_unique_layers = n_unique_layers self.n_logical_layers = n_logical_layers self.grad_checkpoint = grad_checkpoint self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0 head_dim = dim // n_heads self.embed_tokens = nn.Embedding(vocab_size, dim) self.head = nn.Linear(dim, vocab_size, bias=False) self.head.weight = self.embed_tokens.weight self.output_bias = nn.Parameter(torch.zeros(vocab_size)) self.use_recurrent_depth = recurrent_loops > 0 self.prelude_layers = max(0, prelude_layers) self.coda_layers = max(0, coda_layers) self.recurrent_loops = max(0, recurrent_loops) self.blocks: Optional[nn.ModuleList] = None self.prelude: Optional[nn.ModuleList] = None self.recurrent: Optional[RecurrentDepthBlock] = None self.coda: Optional[nn.ModuleList] = None def _make_blocks(n: int) -> nn.ModuleList: return nn.ModuleList([ TransformerBlock( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, engram_dim=engram_dim, engram_heads=engram_heads, engram_table_size=engram_table_size, engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion, ) for _ in range(n) ]) if self.use_recurrent_depth: if self.prelude_layers > 0: self.prelude = _make_blocks(self.prelude_layers) self.recurrent = RecurrentDepthBlock( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, n_loops=self.recurrent_loops, act_threshold=recurrent_act_threshold, lora_rank=recurrent_lora_rank, loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8), ) if self.coda_layers > 0: self.coda = _make_blocks(self.coda_layers) else: self.blocks = _make_blocks(max(1, n_unique_layers)) self.norm = RMSNorm(dim) self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1}) self.mtp_adapters = nn.ModuleDict( {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons} ) self.mtp_norms = nn.ModuleDict( {str(h): RMSNorm(dim) for h in self.mtp_horizons} ) res_scale = (2 * max(1, n_logical_layers)) ** -0.5 for group in (self.blocks, self.prelude, self.coda): if group is None: continue for block in group: block.attn.wo.weight.data.mul_(res_scale) block.ffn.down.weight.data.mul_(res_scale) if self.recurrent is not None: self.recurrent.block.attn.wo.weight.data.mul_(res_scale) self.recurrent.block.ffn.down.weight.data.mul_(res_scale) self.sleep_gate: Optional[SleepGate] = None if sleep_gate_cap > 0: self.sleep_gate = SleepGate( dim=dim, cap=sleep_gate_cap, n_heads=sleep_gate_heads, retention_enabled=sleep_retention_enabled, retention_hidden=sleep_retention_hidden, ) self.think_blocks: Optional[nn.ModuleList] = None self.think_norm: Optional[RMSNorm] = None if latent_think_layers > 0: self.think_blocks = nn.ModuleList([ TransformerBlock( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, ffn_dim=ffn_dim, dropout=0.0, sliding_window=2048, rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, ) for _ in range(latent_think_layers) ]) self.think_norm = RMSNorm(dim) def resize_token_embeddings(self, new_vocab_size: int) -> None: old_vocab_size = self.embed_tokens.num_embeddings if new_vocab_size == old_vocab_size: return device = self.embed_tokens.weight.device old_embed_weight = self.embed_tokens.weight.data.clone() self.embed_tokens = nn.Embedding(new_vocab_size, self.embed_tokens.embedding_dim).to(device) self.head = nn.Linear(self.embed_tokens.embedding_dim, new_vocab_size, bias=False).to(device) self.head.weight = self.embed_tokens.weight old_bias = self.output_bias.data.clone() self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device)) copy_size = min(old_vocab_size, new_vocab_size) self.output_bias.data[:copy_size] = old_bias[:copy_size] self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size] def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]: if self.blocks is None: return [] blocks_list = list(self.blocks) full_sequence = blocks_list + blocks_list return [(block, i) for i, block in enumerate(full_sequence[: self.n_logical_layers])] def forward( self, ids: torch.Tensor, use_cache: bool = False, past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, return_hidden: bool = False, ) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: B, T = ids.shape x = self.embed_tokens(ids) * self.embed_scale_factor new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None aux: Dict[str, torch.Tensor] = {} if self.use_recurrent_depth: offset = 0 if self.prelude is not None: for block in self.prelude: past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) if new_past_key_values is not None: new_past_key_values.append(layer_kv) offset += 1 encoded = x recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None x, recurrent_aux, recurrent_kv = self.recurrent( x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache, ) aux.update(recurrent_aux) if new_past_key_values is not None and recurrent_kv is not None: new_past_key_values.extend(recurrent_kv) offset += self.recurrent_loops if self.coda is not None: for block in self.coda: past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) if new_past_key_values is not None: new_past_key_values.append(layer_kv) offset += 1 else: logical_layers = self._build_logical_layers() last_logical_idx = len(logical_layers) - 1 for layer_idx, (block, logical_idx) in enumerate(logical_layers): is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None if self.grad_checkpoint and self.training and not use_cache: x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True) else: x, layer_kv = block(x, is_global, past_kv, use_cache, ids) if new_past_key_values is not None: new_past_key_values.append(layer_kv) x = self.norm(x) if self.sleep_gate is not None: x = x + self.sleep_gate.read(x) if self.training: self.sleep_gate.write(x) if self.think_blocks is not None: for think_block in self.think_blocks: x, _ = think_block(x, is_global=True) x = self.think_norm(x) h_out = x if return_hidden else None logits = self.head(x) if self.embed_scale_factor != 1.0: logits = logits / self.embed_scale_factor logits = logits + self.output_bias mtp: Dict[int, torch.Tensor] = {} if self.mtp_horizons and self.training: for horizon in self.mtp_horizons: if horizon > 1 and horizon <= T - 1: shifted_h = x[:, :-horizon, :] adapted_h = self.mtp_adapters[str(horizon)](shifted_h) adapted_h = self.mtp_norms[str(horizon)](adapted_h) mtp_logits = self.head(adapted_h) if self.embed_scale_factor != 1.0: mtp_logits = mtp_logits / self.embed_scale_factor mtp_logits = mtp_logits + self.output_bias mtp[horizon] = mtp_logits return logits, mtp, aux, h_out, new_past_key_values # --------------------------------------------------------------------------- # Generation # --------------------------------------------------------------------------- def build_stop_token_ids(tokenizer: WordTokenizer) -> set: stop_tokens = {tokenizer.eos_id} for tok in ("<|user|>", "<|system|>", "<|assistant|>"): tid = tokenizer.token_to_id.get(tok) if tid is not None: stop_tokens.add(int(tid)) return stop_tokens def apply_no_repeat_ngram( logits: torch.Tensor, token_history: Sequence[int], ngram_size: int, ) -> torch.Tensor: if ngram_size <= 1 or len(token_history) < max(0, ngram_size - 1): return logits prefix = tuple(token_history[-(ngram_size - 1) :]) if ngram_size > 1 else tuple() banned: set = set() for i in range(len(token_history) - ngram_size + 1): if tuple(token_history[i : i + ngram_size - 1]) == prefix: banned.add(int(token_history[i + ngram_size - 1])) if not banned: return logits out = logits.clone() banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long) out[banned_ids] = float("-inf") return out def apply_loop_penalty( logits: torch.Tensor, tokenizer: WordTokenizer, generated_text: str, penalty: float = 5.0, ) -> torch.Tensor: """Detect repeated substring loops and penalise continuation tokens.""" if len(generated_text) < 16: return logits out = logits.clone() for span_len in [24, 16, 12, 8]: if len(generated_text) < span_len * 2: continue suffix = generated_text[-span_len:] prev = generated_text[:-span_len].rfind(suffix) if prev == -1: continue next_pos = prev + span_len if next_pos < len(generated_text): next_char = generated_text[next_pos] tid = tokenizer.token_to_id.get(next_char) if tid is not None: out[tid] -= penalty break return out def apply_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor: """Filter tokens below min_p fraction of the top token probability.""" if min_p <= 0.0: return logits probs = torch.softmax(logits, dim=-1) threshold = probs.max() * min_p out = logits.clone() out[probs < threshold] = float("-inf") return out def generate( model: TinyMemoryLM, tokenizer: WordTokenizer, prompt: str, max_new_tokens: int = 256, temperature: float = 0.8, top_k: int = 16, top_p: float = 0.95, repetition_penalty: float = 1.0, device: str = "cuda", sft_mode: bool = True, stream: bool = True, no_repeat_ngram_size: int = 0, context_window: int = 2048, logit_soft_cap: float = 15.0, min_p: float = 0.05, loop_penalty: float = 5.0, ) -> str: if sft_mode: full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" else: full_prompt = prompt input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device) visible_tokens: List[str] = [] stop_token_ids = build_stop_token_ids(tokenizer) generated_text = "" generated_ids: List[int] = [] # Full history (prompt + generated) for ngram blocking — prevents echoing prompt full_ids_history: List[int] = list(input_ids) with torch.no_grad(): for _ in range(max_new_tokens): ctx_ids = ( input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t ) logits, *_ = model(ctx_ids) next_logits = logits[0, -1, :].clone() # Logit soft-capping (Gemma-style) — prevents overconfident collapse if logit_soft_cap > 0: next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap) raw_next_logits = next_logits.clone() # Repetition penalty on previously generated tokens if repetition_penalty != 1.0 and generated_ids: for tok_id in set(generated_ids): if next_logits[tok_id] > 0: next_logits[tok_id] /= repetition_penalty else: next_logits[tok_id] *= repetition_penalty # No-repeat n-gram blocking on generated tokens only if no_repeat_ngram_size > 0 and generated_ids: next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size) # Substring loop detection next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty) # Temperature scaling if temperature != 1.0: next_logits = next_logits / max(temperature, 1e-6) # Min-p filtering — remove tokens below min_p * max_prob if min_p > 0: next_logits = apply_min_p(next_logits, min_p) # Top-k filtering if top_k > 0: v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0))) next_logits[next_logits < v[-1]] = float("-inf") # Top-p (nucleus) filtering if 0.0 < top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) sorted_probs = torch.softmax(sorted_logits, dim=-1) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) remove_mask = cumulative_probs > top_p remove_mask[0] = False indices_to_remove = sorted_indices[remove_mask] next_logits[indices_to_remove] = float("-inf") # Fallback if all tokens masked if not torch.isfinite(next_logits).any(): next_logits = raw_next_logits if temperature != 1.0: next_logits = next_logits / max(temperature, 1e-6) if temperature == 0: next_id = torch.argmax(next_logits).item() else: probs = torch.softmax(next_logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1).item() if next_id in stop_token_ids: break token_str = ( tokenizer.id_to_token[next_id] if next_id < len(tokenizer.id_to_token) else "" ) generated_ids.append(next_id) full_ids_history.append(next_id) if token_str not in tokenizer.special: visible_tokens.append(token_str) generated_text += token_str if stream: print(token_str, end="", flush=True) input_ids_t = torch.cat( [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1 ) if stream: print() return "".join(visible_tokens) def generate_stream( model: TinyMemoryLM, tokenizer: WordTokenizer, prompt: str, max_new_tokens: int = 256, temperature: float = 0.8, top_k: int = 16, top_p: float = 0.95, repetition_penalty: float = 1.0, device: str = "cpu", sft_mode: bool = True, no_repeat_ngram_size: int = 0, context_window: int = 2048, logit_soft_cap: float = 15.0, min_p: float = 0.05, loop_penalty: float = 5.0, ) -> "Iterator[str]": """Yield the accumulated response string after each new token (for Gradio streaming).""" if sft_mode: full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" else: full_prompt = prompt input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device) stop_token_ids = build_stop_token_ids(tokenizer) generated_ids: list = [] generated_text = "" with torch.no_grad(): for _ in range(max_new_tokens): ctx_ids = input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t logits, *_ = model(ctx_ids) next_logits = logits[0, -1, :].clone() if logit_soft_cap > 0: next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap) raw_next_logits = next_logits.clone() if repetition_penalty != 1.0 and generated_ids: for tok_id in set(generated_ids): if next_logits[tok_id] > 0: next_logits[tok_id] /= repetition_penalty else: next_logits[tok_id] *= repetition_penalty if no_repeat_ngram_size > 0 and generated_ids: next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size) next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty) if temperature != 1.0: next_logits = next_logits / max(temperature, 1e-6) if min_p > 0: next_logits = apply_min_p(next_logits, min_p) if top_k > 0: v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0))) next_logits[next_logits < v[-1]] = float("-inf") if 0.0 < top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) sorted_probs = torch.softmax(sorted_logits, dim=-1) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) remove_mask = cumulative_probs > top_p remove_mask[0] = False next_logits[sorted_indices[remove_mask]] = float("-inf") if not torch.isfinite(next_logits).any(): next_logits = raw_next_logits if temperature != 1.0: next_logits = next_logits / max(temperature, 1e-6) if temperature == 0: next_id = int(torch.argmax(next_logits).item()) else: probs = torch.softmax(next_logits, dim=-1) next_id = int(torch.multinomial(probs, num_samples=1).item()) if next_id in stop_token_ids: break token_str = tokenizer.id_to_token[next_id] if next_id < len(tokenizer.id_to_token) else "" generated_ids.append(next_id) if token_str not in tokenizer.special: generated_text += token_str yield generated_text input_ids_t = torch.cat( [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1 ) # --------------------------------------------------------------------------- # Local model loading # --------------------------------------------------------------------------- def series_from_name(name: str) -> str | None: lower = (name or "").lower() if "haiku" in lower: return "Haiku" if "sonnet" in lower: return "Sonnet" if "opus" in lower: return "Opus" return None def series_config(series: str) -> dict[str, object]: return MODEL_SERIES.get(series.lower(), MODEL_SERIES["sonnet"]) def discover_models(runs_dir: Path) -> List[dict]: models = [] if not runs_dir.is_dir(): return models for child in sorted(runs_dir.iterdir()): if not child.is_dir(): continue tokenizer_path = child / "tokenizer.json" if not tokenizer_path.exists(): continue name = child.name series = None for ckpt_name in ("model.pt", "pretrain.pt"): ckpt_path = child / ckpt_name if ckpt_path.exists(): series = _fast_series_from_checkpoint(ckpt_path) break if series is None: series = series_from_name(name) or "Sonnet" found = False for ckpt_name in ("model.pt", "model_rep.pt", "pretrain.pt"): ckpt_path = child / ckpt_name if ckpt_path.exists(): models.append( { "name": name, "checkpoint": ckpt_name, "series": series, "model_path": ckpt_path, "tokenizer_path": tokenizer_path, } ) found = True if not found: step_ckpts = sorted( child.glob("checkpoint_step_*.pt"), key=lambda p: int(p.stem.rsplit("_", 1)[-1]), ) if step_ckpts: ckpt_path = step_ckpts[-1] models.append( { "name": name, "checkpoint": ckpt_path.name, "series": series, "model_path": ckpt_path, "tokenizer_path": tokenizer_path, } ) return models def _detect_engram(state_dict): for key in state_dict: if ".engram." in key: if ".embeddings." in key: return state_dict[key].shape[-1] return 0 def _detect_mhc(state_dict): for key, val in state_dict.items(): if ".mhc_attn.bias_pre" in key and val.dim() == 2: return val.shape[-1] # (1, expansion) return 1 def _detect_sleep_gate(state_dict) -> Tuple[int, int]: for key, val in state_dict.items(): if key == "sleep_gate.mem_emb" and val.dim() == 2: cap = val.shape[0] return cap, 4 return 0, 4 def _detect_latent_think(state_dict) -> int: indices = { int(k.split(".")[1]) for k in state_dict if k.startswith("think_blocks.") and k.split(".")[1].isdigit() } return max(indices) + 1 if indices else 0 def _detect_prelude_layers(state_dict) -> int: indices = { int(k.split(".")[1]) for k in state_dict if k.startswith("prelude.") and k.split(".")[1].isdigit() } return max(indices) + 1 if indices else 0 def _detect_coda_layers(state_dict) -> int: indices = { int(k.split(".")[1]) for k in state_dict if k.startswith("coda.") and k.split(".")[1].isdigit() } return max(indices) + 1 if indices else 0 def _detect_recurrent_loops(state_dict) -> int: if "recurrent.norm.weight" in state_dict or "recurrent.block.attn.wq.weight" in state_dict: if "recurrent.lora.scale.weight" in state_dict: return state_dict["recurrent.lora.scale.weight"].shape[0] return 1 return 0 def _detect_recurrent_lora_rank(state_dict) -> int: for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): if key in state_dict: shape = state_dict[key].shape if len(shape) == 2: return int(shape[0]) return 0 def _infer_series_from_lora_rank(rank: int) -> str | None: if rank == 0: return None if rank <= 8: return "haiku" if rank <= 16: return "sonnet" return "opus" def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None: try: cp = torch.load(ckpt_path, map_location="cpu", weights_only=False) sd = cp.get("model_state", cp.get("state_dict", {})) rank = 0 for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): if key in sd: rank = int(sd[key].shape[0]) break if rank == 0: return None if rank <= 8: return "Haiku" if rank <= 16: return "Sonnet" return "Opus" except Exception: pass return None def _infer_arch_from_state_dict(state_dict, cfg): """Infer architecture hyper-parameters directly from checkpoint weights, falling back to *cfg* (series config) when a key is not found.""" overrides = {} has_prelude = any(k.startswith("prelude.") for k in state_dict) has_blocks = any(k.startswith("blocks.") for k in state_dict) has_recurrent = any(k.startswith("recurrent.") for k in state_dict) uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks # dim from embed_tokens.weight [vocab, dim] if "embed_tokens.weight" in state_dict: overrides["dim"] = state_dict["embed_tokens.weight"].shape[1] if uses_recurrent_arch: if "prelude.0.ffn.gate.weight" in state_dict: overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0] overrides["n_unique_layers"] = 0 src = "prelude.0" else: if "blocks.0.ffn.gate.weight" in state_dict: overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0] block_ids = { int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.") and k.split(".")[1].isdigit() } if block_ids: overrides["n_unique_layers"] = max(block_ids) + 1 src = "blocks.0" dim = overrides.get("dim", int(cfg.get("dim", model_config.dim))) if f"{src}.attn.wq.weight" in state_dict: wq_rows = state_dict[f"{src}.attn.wq.weight"].shape[0] if f"{src}.attn.q_norm.weight" in state_dict: head_dim = state_dict[f"{src}.attn.q_norm.weight"].shape[0] overrides["n_heads"] = wq_rows // head_dim if f"{src}.attn.wk.weight" in state_dict: wk_rows = state_dict[f"{src}.attn.wk.weight"].shape[0] if f"{src}.attn.k_norm.weight" in state_dict: head_dim = state_dict[f"{src}.attn.k_norm.weight"].shape[0] overrides["n_kv_heads"] = wk_rows // head_dim # engram params for key, val in state_dict.items(): if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2: overrides["engram_table_size"] = val.shape[0] overrides["engram_dim"] = val.shape[1] break engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0))) engram_max_ngram = int(cfg.get("engram_max_ngram", 2)) if engram_dim > 0: for key, val in state_dict.items(): if ".engram.branch_conv.weight" in key and val.dim() == 3: total_branch_dim = val.shape[0] denom = engram_dim * (engram_max_ngram - 1) if denom > 0 and total_branch_dim % denom == 0: overrides["engram_heads"] = total_branch_dim // denom break merged = dict(cfg) merged.update(overrides) return merged def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dict: tokenizer = WordTokenizer.load(tokenizer_path) ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False) cfg = series_config(series) vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size)) state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt cfg = _infer_arch_from_state_dict(state_dict, cfg) engram_dim = int(cfg.get("engram_dim", 0)) if _detect_engram(state_dict) == 0: engram_dim = 0 mhc_expansion = _detect_mhc(state_dict) if mhc_expansion == 1: mhc_expansion = int(cfg.get("mhc_expansion", 1)) ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict) sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0)) sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4)) sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True)) sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0)) latent_think_layers = _detect_latent_think(state_dict) if latent_think_layers == 0: latent_think_layers = int(cfg.get("latent_think_layers", 0)) prelude_layers = _detect_prelude_layers(state_dict) coda_layers = _detect_coda_layers(state_dict) recurrent_loops = _detect_recurrent_loops(state_dict) ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict) if ckpt_lora_rank > 0: inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank) if inferred_series and inferred_series != series.lower(): series = inferred_series.capitalize() cfg = series_config(series) recurrent_lora_rank = ckpt_lora_rank else: recurrent_lora_rank = int(cfg.get("recurrent_lora_rank", 0)) recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99)) recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0)) n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers)) model = TinyMemoryLM( vocab_size=vocab_size, dim=int(cfg.get("dim", model_config.dim)), n_unique_layers=n_unique, n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)), n_heads=int(cfg.get("n_heads", model_config.n_heads)), n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)), ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)), dropout=float(cfg.get("dropout", model_config.dropout)), mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)), grad_checkpoint=False, sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))), rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))), embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))), engram_dim=engram_dim, engram_heads=int(cfg.get("engram_heads", 4)), engram_table_size=int(cfg.get("engram_table_size", 8192)), engram_max_ngram=int(cfg.get("engram_max_ngram", 3)), mhc_expansion=mhc_expansion, sleep_gate_cap=sleep_gate_cap, sleep_gate_heads=sleep_gate_heads, sleep_retention_enabled=sleep_retention_enabled, sleep_retention_hidden=sleep_retention_hidden, latent_think_layers=latent_think_layers, prelude_layers=prelude_layers, coda_layers=coda_layers, recurrent_loops=recurrent_loops, recurrent_act_threshold=recurrent_act_threshold, recurrent_lora_rank=recurrent_lora_rank, recurrent_loop_embed_dim=recurrent_loop_embed_dim, ) model.load_state_dict(state_dict, strict=False) model.eval() if tokenizer.vocab_size > vocab_size: model.resize_token_embeddings(tokenizer.vocab_size) device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) return { "model": model, "tokenizer": tokenizer, "device": device, "series": series, "sft_mode": ckpt.get("sft_mode", None), "phase": ckpt.get("phase", None), } # --------------------------------------------------------------------------- # HuggingFace Model Download & Loading # --------------------------------------------------------------------------- def download_huggingface_model(hf_id: str, cache_dir: Path) -> dict: try: from huggingface_hub import snapshot_download except ImportError: print("huggingface_hub not installed. Install with: pip install huggingface_hub") sys.exit(1) print(f"Downloading {hf_id}...") try: local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) except Exception as e: print(f"Failed to download {hf_id}: {e}") return None print(f"Using cached {hf_id} from {local_dir}") # Check common subdirectory names: "models/", "model/" if (local_dir / "models").exists(): model_dir = local_dir / "models" elif (local_dir / "model").exists(): model_dir = local_dir / "model" else: model_dir = local_dir model_path = model_dir / "model.pt" pretrain_path = model_dir / "pretrain.pt" tokenizer_path = model_dir / "tokenizer.json" ckpt_path = None for p in [model_path, pretrain_path]: if p.exists(): ckpt_path = p break if ckpt_path is None or not tokenizer_path.exists(): print(f"Missing model files in {model_dir}") print(f" model.pt exists: {model_path.exists()}") print(f" pretrain.pt exists: {pretrain_path.exists()}") print(f" tokenizer.json exists: {tokenizer_path.exists()}") return None return { "model_path": ckpt_path, "tokenizer_path": tokenizer_path, "model_name": ckpt_path.stem, } def load_huggingface_model(hf_id: str, cache_dir: Path) -> dict: files = download_huggingface_model(hf_id, cache_dir) if files is None: return None return load_local_model(files["model_path"], files["tokenizer_path"], "Haiku") # --------------------------------------------------------------------------- # Compare All Models # --------------------------------------------------------------------------- _hf_model_cache: Dict[str, dict] = {} def prefetch_huggingface_models() -> None: root = Path(__file__).resolve().parent cache_dir = root / "cache" / "huggingface" cache_dir.mkdir(parents=True, exist_ok=True) print("Downloading/preparing HuggingFace models...") for name, hf_id in HUGGINGFACE_MODELS.items(): print(f" {name}...") bundle = load_huggingface_model(hf_id, cache_dir) if bundle: _hf_model_cache[name] = bundle print(f"Prepared {len(_hf_model_cache)} HuggingFace models") def compare_all_models(prompt: str, cfg: dict) -> None: root = Path(__file__).resolve().parent runs_dir = root / "runs" all_models = discover_models(runs_dir) is_pretrain = not cfg.get("sft_mode", True) local_models = [ m for m in all_models if ("pretrain" in m["checkpoint"]) == is_pretrain ] if not local_models: print("No models found matching mode.") return results: List[dict] = [] for m in local_models: print(f"\n{'='*60}") print(f"Loading local {m['name']}/{m['checkpoint']}...") try: bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) except Exception as e: print(f"Failed to load {m['name']}: {e}") continue model = bundle["model"] tokenizer = bundle["tokenizer"] device = bundle["device"] print(f"Generating on '{prompt}'...") output = generate( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=cfg["max_new_tokens"], temperature=cfg["temperature"], top_k=cfg["top_k"], top_p=cfg["top_p"], min_p=cfg["min_p"], no_repeat_ngram_size=cfg["no_repeat_ngram_size"], repetition_penalty=cfg["repetition_penalty"], logit_soft_cap=cfg["logit_soft_cap"], loop_penalty=cfg["loop_penalty"], device=str(device), sft_mode=cfg["sft_mode"], stream=True, context_window=cfg["context_window"], ) results.append({ "name": f"[LOCAL] {m['name']}/{m['checkpoint']}", "output": output, "device": device, }) for name, bundle in _hf_model_cache.items(): print(f"\n{'='*60}") print(f"Loading {name} (cached)...") model = bundle["model"] tokenizer = bundle["tokenizer"] device = bundle["device"] print(f"Generating on '{prompt}'...") output = generate( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=cfg["max_new_tokens"], temperature=cfg["temperature"], top_k=cfg["top_k"], top_p=cfg["top_p"], min_p=cfg["min_p"], no_repeat_ngram_size=cfg["no_repeat_ngram_size"], repetition_penalty=cfg["repetition_penalty"], logit_soft_cap=cfg["logit_soft_cap"], loop_penalty=cfg["loop_penalty"], device=str(device), sft_mode=cfg["sft_mode"], stream=True, context_window=cfg["context_window"], ) results.append({ "name": name, "output": output, "device": device, }) print(f"\n{'='*60}") print("=" * 60) print("SIDE-BY-SIDE COMPARISON") print("=" * 60) for r in results: print(f"\n--- {r['name']} ---") print(r["output"]) print(f"\n{'='*60}") # --------------------------------------------------------------------------- # Benchmark # --------------------------------------------------------------------------- BENCHMARKS = { "blimp": { "label": "BLiMP", "desc": "Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.", "hf_dataset": ("nyu-mll/blimp", None), "metric": "accuracy", }, "wikitext2": { "label": "WikiText-2", "desc": "LM perplexity on Wikipedia test split. Lower is better.", "hf_dataset": ("Salesforce/wikitext", "wikitext-2-raw-v1"), "metric": "perplexity", }, "arc_easy": { "label": "ARC-Easy", "desc": "Multiple-choice science QA (~2.4K). Perplexity-based answer selection.", "hf_dataset": ("allenai/ai2_arc", "ARC-Easy"), "metric": "accuracy", }, } def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float: ids = tokenizer.encode(text, add_bos=True, add_eos=False) if len(ids) < 2: return float("nan") ids_t = torch.tensor([ids], dtype=torch.long, device=device) with torch.no_grad(): logits, *_ = model(ids_t) log_probs = F.log_softmax(logits[0], dim=-1) targets = ids_t[0, 1:] nll = -log_probs[range(len(targets)), targets].mean().item() return nll def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float: full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False) ctx_ids = tokenizer.encode(context, add_bos=True, add_eos=False) n_ctx = len(ctx_ids) n_ref = len(full_ids) - n_ctx if n_ref <= 0: return float("nan") ids_t = torch.tensor([full_ids], dtype=torch.long, device=device) with torch.no_grad(): logits, *_ = model(ids_t) log_probs = F.log_softmax(logits[0], dim=-1) targets = ids_t[0, 1:] ref_start = n_ctx - 1 ref_end = min(ref_start + n_ref, log_probs.shape[0]) if ref_start >= ref_end: return float("nan") nll = -log_probs[ref_start:ref_end][range(ref_end - ref_start), targets[ref_start:ref_end]].mean().item() return nll BLIMP_PARADIGMS = [ "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement", "animate_subject_passive", "animate_subject_trans", "causative", "complex_NP_island", "coordinate_structure_constraint_complex_left_branch", "coordinate_structure_constraint_object_extraction", "determiner_noun_agreement_1", "determiner_noun_agreement_2", "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2", "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1", "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1", "distractor_agreement_relational_noun", "distractor_agreement_relative_clause", "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2", "existential_there_object_raising", "existential_there_quantifiers_1", "existential_there_quantifiers_2", "existential_there_subject_raising", "expletive_it_object_raising", "inchoative", "intransitive", "irregular_past_participle_adjectives", "irregular_past_participle_verbs", "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2", "left_branch_island_echo_question", "left_branch_island_simple_question", "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2", "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2", "principle_A_c_command", "principle_A_case_1", "principle_A_case_2", "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3", "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1", "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present", "sentential_negation_npi_scope", "sentential_subject_island", "superlative_quantifiers_1", "superlative_quantifiers_2", "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island", "wh_questions_object_gap", "wh_questions_subject_gap", "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap", "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap", "wh_vs_that_with_gap_long_distance", ] def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]: from datasets import load_dataset # type: ignore accuracies: List[float] = [] for paradigm in BLIMP_PARADIGMS: try: ds = load_dataset("nyu-mll/blimp", paradigm, split="train") except Exception as e: print(f" {paradigm}: skip ({e})") accuracies.append(float("nan")) continue items = list(ds)[:n_samples] correct = 0 for ex in items: good_nll = _score_text(model, tokenizer, ex["sentence_good"], device) bad_nll = _score_text(model, tokenizer, ex["sentence_bad"], device) if math.isnan(good_nll) or math.isnan(bad_nll): continue if good_nll < bad_nll: correct += 1 acc = correct / len(items) if items else float("nan") accuracies.append(acc) print(f" {paradigm:50s} acc={acc:.3f}") return BLIMP_PARADIGMS, accuracies def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]: from datasets import load_dataset # type: ignore ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") full_text = "\n".join(ex["text"] for ex in ds if ex["text"].strip()) chunks = [full_text[i:i + chunk_chars] for i in range(0, len(full_text), chunk_chars)] chunks = [c for c in chunks if len(c) > 20][:max_chunks] labels: List[str] = [] ppls: List[float] = [] for i, chunk in enumerate(chunks): nll = _score_text(model, tokenizer, chunk, device) ppl = math.exp(nll) if not math.isnan(nll) else float("nan") labels.append(f"chunk {i + 1}") ppls.append(ppl) if (i + 1) % 10 == 0: valid = [v for v in ppls if not math.isnan(v)] mean = sum(valid) / len(valid) if valid else float("nan") print(f" chunk {i + 1}/{len(chunks)} running mean ppl={mean:.2f}") return labels, ppls def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]: from datasets import load_dataset # type: ignore ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test") items = list(ds)[:max_samples] labels: List[str] = [] scores: List[float] = [] for i, ex in enumerate(items): question = ex["question"] choices = ex["choices"]["text"] choice_labels = ex["choices"]["label"] answer_key = ex["answerKey"] context = f"Question: {question}\nAnswer:" nlls = [_score_completion(model, tokenizer, context, f" {c}", device) for c in choices] if all(math.isnan(v) for v in nlls): scores.append(float("nan")) else: best_idx = min(range(len(nlls)), key=lambda j: nlls[j] if not math.isnan(nlls[j]) else float("inf")) predicted = choice_labels[best_idx] scores.append(1.0 if predicted == answer_key else 0.0) labels.append(f"Q{i + 1}") n_valid = sum(1 for s in scores if not math.isnan(s)) acc = sum(s for s in scores if not math.isnan(s)) / n_valid if n_valid else float("nan") print(f" {n_valid} questions evaluated, accuracy={acc:.3f}") return labels, scores def run_benchmark_mode() -> None: try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt except ImportError: print("matplotlib not installed. pip install matplotlib") return bench_keys = list(BENCHMARKS.keys()) print("\nBenchmarks:") for i, k in enumerate(bench_keys): b = BENCHMARKS[k] print(f" [{i + 1}] {b['label']} — {b['desc']}") print("Select benchmark [1]:", end=" ", flush=True) try: b_choice = input().strip() or "1" except (EOFError, KeyboardInterrupt): print() return if not (b_choice.isdigit() and 1 <= int(b_choice) <= len(bench_keys)): print("Invalid selection.") return bench_key = bench_keys[int(b_choice) - 1] bench = BENCHMARKS[bench_key] print(f"Benchmark: {bench['label']}") root = Path(__file__).resolve().parent runs_dir = root / "runs" all_models = discover_models(runs_dir) model_entries: List[dict] = [] for m in all_models: model_entries.append({"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m}) for hf_name, hf_id in HUGGINGFACE_MODELS.items(): model_entries.append({"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name}) if not model_entries: print("No models found.") return print("\nAvailable models:") for i, e in enumerate(model_entries): print(f" [{i + 1}] {e['label']}") print(" [a] All models") print("Select models (comma-separated or 'a'):", end=" ", flush=True) try: raw = input().strip() except (EOFError, KeyboardInterrupt): print() return if raw.lower() == "a": selected = list(range(len(model_entries))) else: selected = [] for tok in raw.split(","): tok = tok.strip() if tok.isdigit() and 1 <= int(tok) <= len(model_entries): selected.append(int(tok) - 1) if not selected: print("No valid selection.") return all_results: List[dict] = [] shared_x_labels: Optional[List[str]] = None for idx in selected: entry = model_entries[idx] print(f"\n{'='*60}\nLoading {entry['label']}...") try: if entry["type"] == "local": m = entry["meta"] bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) else: bundle = load_huggingface_model(entry["hf_id"], root / ".hf_cache") except Exception as e: print(f" Failed: {e}") continue model = bundle["model"] tokenizer = bundle["tokenizer"] device = str(bundle["device"]) model.eval() if bench_key == "blimp": x_labels, y_vals = _run_blimp(model, tokenizer, device) elif bench_key == "wikitext2": x_labels, y_vals = _run_wikitext2(model, tokenizer, device) else: x_labels, y_vals = _run_arc_easy(model, tokenizer, device) if shared_x_labels is None: shared_x_labels = x_labels valid = [v for v in y_vals if not math.isnan(v)] summary = sum(valid) / len(valid) if valid else float("nan") all_results.append({"label": entry["label"], "y": y_vals, "summary": summary}) if not all_results or shared_x_labels is None: print("No results to plot.") return metric = bench["metric"] paired = sorted(zip([r["summary"] for r in all_results], [r["label"] for r in all_results]), reverse=(metric != "perplexity")) summaries, model_labels = zip(*paired) if paired else ([], []) n = len(summaries) colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)] fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6)) bars = ax.bar(range(n), summaries, color=colors, edgecolor="black") for bar, val in zip(bars, summaries): ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold") ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)" ax.set_ylabel(ylabel) ax.set_title(f"{bench['label']} Benchmark — Model Comparison") ax.set_xticks(range(n)) ax.set_xticklabels(model_labels, rotation=20, ha="right", fontsize=9) if metric == "accuracy": ax.set_ylim(0, 1.05) ax.grid(True, axis="y", alpha=0.3) plt.tight_layout() out_path = root / f"benchmark_{bench_key}.png" plt.savefig(str(out_path), dpi=150) print(f"\nChart saved to {out_path}") try: import subprocess subprocess.Popen(["xdg-open", str(out_path)]) except Exception: pass # --------------------------------------------------------------------------- # Interactive CLI # --------------------------------------------------------------------------- def _pick_series(detected: str) -> str: series_list = list(MODEL_SERIES.keys()) detected_lower = detected.lower() default_idx = next( (i + 1 for i, s in enumerate(series_list) if s == detected_lower), 1 ) # Skip selection if only one series available if len(series_list) == 1: return series_list[0].capitalize() print("Series:") for i, s in enumerate(series_list): marker = " (detected)" if s == detected_lower else "" print(f" [{i + 1}] {s.capitalize()}{marker}") while True: try: choice = input(f"Select series [{default_idx}]: ").strip() except (EOFError, KeyboardInterrupt): print() sys.exit(0) if not choice: choice = str(default_idx) if choice.isdigit() and 1 <= int(choice) <= len(series_list): return series_list[int(choice) - 1].capitalize() print(f"Enter a number 1-{len(series_list)}") def pick_model(runs_dir: Path) -> tuple[dict, str]: models = discover_models(runs_dir) if not models: print(f"No models found in {runs_dir}") print("Expected layout: runs//model.pt (or pretrain.pt) + tokenizer.json") sys.exit(1) if len(models) == 1: m = models[0] print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) return bundle, m["checkpoint"] print("Available models:") for i, m in enumerate(models): print(f" [{i + 1}] {m['name']}/{m['checkpoint']} ({m['series']})") while True: try: choice = input("Select model [1]: ").strip() except (EOFError, KeyboardInterrupt): print() sys.exit(0) if not choice: choice = "1" if choice.isdigit() and 1 <= int(choice) <= len(models): m = models[int(choice) - 1] print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) return bundle, m["checkpoint"] print(f"Enter a number 1-{len(models)}") # --------------------------------------------------------------------------- # Generation mode configs # --------------------------------------------------------------------------- MODES = { "chat-coherent": { "label": "Chat — Coherent", "desc": "structured, consistent, strong repetition control", "sft_mode": "chat", "temperature": 0.35, "top_k": 20, "top_p": 0.88, "min_p": 0.10, "no_repeat_ngram_size": 4, "repetition_penalty": 1.22, "logit_soft_cap": 20.0, "loop_penalty": 20.0, "max_new_tokens": 4096, "context_window": 2048, }, "chat-variants": { "label": "Chat — Variants", "desc": "creative, diverse, more surprising outputs", "sft_mode": "chat", "temperature": 0.65, "top_k": 60, "top_p": 0.92, "min_p": 0.05, "no_repeat_ngram_size": 4, "repetition_penalty": 1.12, "logit_soft_cap": 20.0, "loop_penalty": 14.0, "max_new_tokens": 4096, "context_window": 2048, }, "pretrain-coherent": { "label": "Pretrain — Coherent", "desc": "grounded continuation, low temperature, tight sampling", "sft_mode": False, "temperature": 0.3, "top_k": 20, "top_p": 0.85, "min_p": 0.10, "no_repeat_ngram_size": 4, "repetition_penalty": 1.2, "logit_soft_cap": 20.0, "loop_penalty": 20.0, "max_new_tokens": 4096, "context_window": 2048, }, "pretrain-variants": { "label": "Pretrain — Variants", "desc": "free-form continuation, higher temperature, more exploration", "sft_mode": False, "temperature": 0.7, "top_k": 60, "top_p": 0.93, "min_p": 0.04, "no_repeat_ngram_size": 4, "repetition_penalty": 1.12, "logit_soft_cap": 20.0, "loop_penalty": 12.0, "max_new_tokens": 4096, "context_window": 2048, }, } _MODE_LIST = list(MODES.keys()) def pick_mode(is_pretrain: bool) -> dict: """Prompt the user to choose a generation mode. Returns a config dict.""" # Filter to relevant modes based on checkpoint type candidates = [k for k in _MODE_LIST if ("pretrain" in k) == is_pretrain] print("\nGeneration mode:") for i, key in enumerate(candidates): cfg = MODES[key] print(f" [{i + 1}] {cfg['label']} — {cfg['desc']}") while True: try: choice = input("Select mode [1]: ").strip() except (EOFError, KeyboardInterrupt): print() sys.exit(0) if not choice: choice = "1" if choice.isdigit() and 1 <= int(choice) <= len(candidates): key = candidates[int(choice) - 1] cfg = MODES[key] print(f"Mode: {cfg['label']}") return cfg print(f"Enter a number 1-{len(candidates)}") def _run_loop(bundle: dict, cfg: dict) -> None: model = bundle["model"] tokenizer = bundle["tokenizer"] device = bundle["device"] sft = cfg["sft_mode"] prompt_label = "You" if sft else "Prompt" print(f"\nModel ready on {device}. Type your message, or /quit to exit.") print(f" temp={cfg['temperature']} top_k={cfg['top_k']} top_p={cfg['top_p']}") print(f" min_p={cfg['min_p']} ng={cfg['no_repeat_ngram_size']} rp={cfg['repetition_penalty']}") print(f" cap={cfg['logit_soft_cap']} loop_penalty={cfg['loop_penalty']}\n") while True: try: prompt = input(f"{prompt_label}: ").strip() except (EOFError, KeyboardInterrupt): print() break if not prompt: continue if prompt in ("/quit", "/exit", "/q"): break if prompt == "/help": print("Commands: /quit /exit /q /help /mode") if sft: print("Anything else is sent as a chat prompt.") else: print("Anything else is sent as a raw continuation prompt.") continue if prompt == "/mode": print(f"Current: {cfg['label']} — {cfg['desc']}") continue print("AI: ", end="", flush=True) generate( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=cfg["max_new_tokens"], temperature=cfg["temperature"], top_k=cfg["top_k"], top_p=cfg["top_p"], min_p=cfg["min_p"], no_repeat_ngram_size=cfg["no_repeat_ngram_size"], repetition_penalty=cfg["repetition_penalty"], logit_soft_cap=cfg["logit_soft_cap"], loop_penalty=cfg["loop_penalty"], device=str(device), sft_mode=cfg["sft_mode"], stream=True, context_window=cfg["context_window"], ) # --------------------------------------------------------------------------- # Dynamic collection discovery # --------------------------------------------------------------------------- _COLLECTION_SLUG = "CompactAI-O/tmlm-haiku-series" _AUTHOR = "CompactAI-O" _SEARCH = "TMLM-Haiku" _FALLBACK_COLLECTION = [ {"version": "TMLM-Haiku-2.3", "hf_id": "CompactAI-O/TMLM-Haiku-2.3"}, {"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"}, {"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"}, {"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"}, {"version": "Glint-1", "hf_id": "CompactAI-O/Glint-1"}, ] _EXTRA_REPOS = ["CompactAI-O/Glint-1"] def _probe_repo(hf_id: str) -> dict | None: """Return entry dict for one repo, or None if no usable checkpoints found.""" from huggingface_hub import list_repo_files try: files = set(list_repo_files(hf_id)) except Exception: return None # Detect which subdirectory holds the checkpoints subdir: str | None = None for candidate in ("models", "model"): if any(f.startswith(f"{candidate}/") for f in files): subdir = candidate break prefix = f"{subdir}/" if subdir else "" # Collect all .pt files in the checkpoint directory pt_files = sorted( f[len(prefix):] for f in files if f.startswith(prefix) and f.endswith(".pt") ) _LABELS = { "model.pt": ("Chat (SFT)", False), "model_rep.pt": ("Chat (anti-repetition)", False), "pretrain.pt": ("Pretrain (base)", True), } checkpoints = [] for fname in pt_files: label, is_pretrain = _LABELS.get(fname, (fname.removesuffix(".pt"), "pretrain" in fname)) checkpoints.append((label, fname, is_pretrain)) if not checkpoints: return None return { "version": hf_id.split("/")[-1], "hf_id": hf_id, "subdir": subdir, "checkpoints": checkpoints, "desc": "", } def fetch_collection() -> list[dict]: """Query HF for all CompactAI-O TMLM-Haiku models, newest first.""" from huggingface_hub import HfApi print("Checking HuggingFace collection for available models...") try: api = HfApi() infos = list( api.list_models( author=_AUTHOR, search=_SEARCH, sort="lastModified", ) ) infos.sort(key=lambda m: getattr(m, "lastModified", ""), reverse=True) except Exception as exc: print(f" Could not reach HuggingFace ({exc}); using fallback list.") infos = [type("M", (), {"id": e["hf_id"]})() for e in _FALLBACK_COLLECTION] entries = [] seen_ids: set = set() for info in infos: repo_id = info.id if _SEARCH.lower() not in repo_id.lower(): continue entry = _probe_repo(repo_id) if entry: entries.append(entry) seen_ids.add(repo_id) # Always include extra repos (e.g. Glint-1) not caught by TMLM-Haiku search for repo_id in _EXTRA_REPOS: if repo_id not in seen_ids: entry = _probe_repo(repo_id) if entry: entries.append(entry) seen_ids.add(repo_id) if not entries: print(" No models found; using fallback list.") for fb in _FALLBACK_COLLECTION: e = _probe_repo(fb["hf_id"]) if e: entries.append(e) return entries # --------------------------------------------------------------------------- # Download helper # --------------------------------------------------------------------------- def _download_version(entry: dict, cache_dir: Path) -> Path: """Download full repo snapshot; return the directory containing model files.""" try: from huggingface_hub import snapshot_download except ImportError: print("huggingface_hub not installed. Run: pip install huggingface_hub") sys.exit(1) hf_id = entry["hf_id"] print(f"Fetching {hf_id} ...") try: local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) except Exception as exc: print(f"Download failed: {exc}") sys.exit(1) subdir = entry.get("subdir") model_dir = (local_dir / subdir) if subdir else local_dir if not model_dir.exists(): # Fallback to root model_dir = local_dir return model_dir # --------------------------------------------------------------------------- # Selection prompts # --------------------------------------------------------------------------- def _prompt_int(prompt: str, lo: int, hi: int, default: int = 1) -> int: while True: try: raw = input(f"{prompt} [{default}]: ").strip() except (EOFError, KeyboardInterrupt): print() sys.exit(0) if not raw: return default if raw.isdigit() and lo <= int(raw) <= hi: return int(raw) print(f" Enter a number {lo}–{hi}.") def pick_version(collection: list[dict]) -> dict: print("\nTMLM-Haiku series (CompactAI-O)\n") for i, entry in enumerate(collection): desc = f" — {entry['desc']}" if entry["desc"] else "" print(f" [{i + 1}] {entry['version']}{desc}") idx = _prompt_int("Select version", 1, len(collection)) return collection[idx - 1] def pick_checkpoint(entry: dict) -> tuple[str, bool]: """Return (filename, is_pretrain).""" ckpts = entry["checkpoints"] if len(ckpts) == 1: label, fname, is_pretrain = ckpts[0] print(f" Using: {label} ({fname})") return fname, is_pretrain print(f"\nCheckpoints for {entry['version']}:") for i, (label, fname, _) in enumerate(ckpts): print(f" [{i + 1}] {label} ({fname})") idx = _prompt_int("Select checkpoint", 1, len(ckpts)) label, fname, is_pretrain = ckpts[idx - 1] return fname, is_pretrain # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- # --------------------------------------------------------------------------- # Gradio Space # --------------------------------------------------------------------------- import gradio as gr _CACHE_DIR = Path(__file__).parent / ".hf_cache" _CACHE_DIR.mkdir(parents=True, exist_ok=True) _collection_cache: list = [] _model_cache: dict = {} def _get_collection() -> list: global _collection_cache if not _collection_cache: try: _collection_cache = fetch_collection() except Exception as e: print(f"Warning: fetch_collection failed ({e}); using fallback.") _collection_cache = [ _probe_repo(e["hf_id"]) or {"version": e["version"], "hf_id": e["hf_id"], "subdir": None, "checkpoints": [("Chat (SFT)", "model.pt", False)], "desc": ""} for e in _FALLBACK_COLLECTION ] return _collection_cache def _collection_versions() -> list[str]: return [e["version"] for e in _get_collection()] def _checkpoints_for(version: str) -> list[tuple[str, str, bool]]: for e in _get_collection(): if e["version"] == version: return e["checkpoints"] return [] def _ckpt_labels(version: str) -> list[str]: return [label for label, _, _ in _checkpoints_for(version)] def _ckpt_is_pretrain(version: str, label: str) -> bool: for lbl, _, is_pt in _checkpoints_for(version): if lbl == label: return is_pt return False def _ckpt_fname(version: str, label: str) -> str: for lbl, fname, _ in _checkpoints_for(version): if lbl == label: return fname return "model.pt" def _load_bundle(version: str, ckpt_label: str) -> dict: key = f"{version}/{ckpt_label}" if key not in _model_cache: fname = _ckpt_fname(version, ckpt_label) for entry in _get_collection(): if entry["version"] == version: model_dir = _download_version(entry, _CACHE_DIR) model_path = model_dir / fname tokenizer_path = model_dir / "tokenizer.json" _model_cache[key] = load_local_model(model_path, tokenizer_path, "Haiku") break return _model_cache[key] def _build_conversation_prompt(history: list[dict], new_message: str) -> str: """Flatten Gradio messages-format history + new turn into a raw prompt.""" parts = [] # history is [{role, content}, ...] pairs already in order i = 0 while i < len(history) - 1: u = history[i] a = history[i + 1] if u["role"] == "user" and a["role"] == "assistant": parts.append(f"<|user|>\n{u['content']}\n<|assistant|>\n{a['content']}") i += 2 parts.append(f"<|user|>\n{new_message}\n<|assistant|>\n") return "".join(parts) # ---- chat ---- def _on_version_change(version): labels = _ckpt_labels(version) return gr.update(choices=labels, value=labels[0] if labels else None) def _chat_submit(message, history): history = history or [] history.append({"role": "user", "content": message}) history.append({"role": "assistant", "content": ""}) return "", history def _chat_stream(history, version, ckpt_label, mode_key, use_custom, temperature, top_k, top_p, min_p, rep_penalty, ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode): if not history or history[-1]["role"] != "assistant": yield history return try: bundle = _load_bundle(version, ckpt_label) except Exception as e: history[-1]["content"] = f"[Error loading model: {e}]" yield history return prior_msgs = history[:-2] # exclude the current user+empty-assistant pair new_msg = history[-2]["content"] if use_custom: cfg = { "sft_mode": not raw_mode, "temperature": temperature, "top_k": top_k, "top_p": top_p, "min_p": min_p, "repetition_penalty": rep_penalty, "no_repeat_ngram_size": ngram_size, "logit_soft_cap": soft_cap, "loop_penalty": loop_pen, "max_new_tokens": max_tokens, "context_window": ctx_win, } else: cfg = dict(MODES[mode_key]) # Max new tokens slider always applies (independent of preset override) cfg["max_new_tokens"] = int(max_tokens) cfg["context_window"] = int(ctx_win) if prior_msgs: prompt = _build_conversation_prompt(prior_msgs, new_msg) sft = False else: prompt = new_msg sft = cfg["sft_mode"] for partial in generate_stream( model=bundle["model"], tokenizer=bundle["tokenizer"], prompt=prompt, device=str(bundle["device"]), sft_mode=sft, temperature=cfg["temperature"], top_k=cfg["top_k"], top_p=cfg["top_p"], min_p=cfg["min_p"], repetition_penalty=cfg["repetition_penalty"], no_repeat_ngram_size=cfg["no_repeat_ngram_size"], logit_soft_cap=cfg["logit_soft_cap"], loop_penalty=cfg["loop_penalty"], max_new_tokens=cfg["max_new_tokens"], context_window=cfg["context_window"], ): history[-1]["content"] = partial yield history # ---- compare ---- def _compare_fn(prompt, selected_versions, mode_key, use_custom, temperature, top_k, top_p, min_p, rep_penalty, ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode): if use_custom: cfg = { "sft_mode": not raw_mode, "temperature": temperature, "top_k": top_k, "top_p": top_p, "min_p": min_p, "repetition_penalty": rep_penalty, "no_repeat_ngram_size": ngram_size, "logit_soft_cap": soft_cap, "loop_penalty": loop_pen, "max_new_tokens": max_tokens, "context_window": ctx_win, } else: cfg = dict(MODES[mode_key]) # Max new tokens slider always applies (independent of preset override) cfg["max_new_tokens"] = int(max_tokens) cfg["context_window"] = int(ctx_win) # Iterate + emit oldest → newest (Haiku-1 first, Glint-1 last) so the order # matches the output-box layout in the UI. all_versions = _sort_oldest_to_newest(_collection_versions()) selected = set(selected_versions or []) state = {v: ("⏳ Queued…" if v in selected else "") for v in all_versions} def _emit(): return [state[v] for v in all_versions] yield _emit() for version in all_versions: if version not in selected: continue labels = _ckpt_labels(version) ckpt_label = labels[0] if labels else None if not ckpt_label: state[version] = "[No checkpoint found]" yield _emit() continue state[version] = "⏳ Loading…" yield _emit() try: bundle = _load_bundle(version, ckpt_label) except Exception as e: state[version] = f"[Load error: {e}]" yield _emit() continue state[version] = "" yield _emit() try: for partial in generate_stream( model=bundle["model"], tokenizer=bundle["tokenizer"], prompt=prompt, device=str(bundle["device"]), sft_mode=cfg["sft_mode"], temperature=cfg["temperature"], top_k=cfg["top_k"], top_p=cfg["top_p"], min_p=cfg["min_p"], repetition_penalty=cfg["repetition_penalty"], no_repeat_ngram_size=cfg["no_repeat_ngram_size"], logit_soft_cap=cfg["logit_soft_cap"], loop_penalty=cfg["loop_penalty"], max_new_tokens=cfg["max_new_tokens"], context_window=cfg["context_window"], ): state[version] = partial yield _emit() except Exception as e: state[version] = f"[Generation error: {e}]" yield _emit() # ---- benchmark ---- def _benchmark_fn(bench_key, selected_versions, max_samples, progress=gr.Progress(track_tqdm=True)): import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt if not selected_versions: return "No models selected.", None bench = BENCHMARKS[bench_key] all_results = [] log_lines = [f"Benchmark: {bench['label']}", ""] for version in progress.tqdm(selected_versions, desc="Benchmarking"): log_lines.append(f"--- {version} ---") labels = _ckpt_labels(version) ckpt_label = labels[0] if labels else None if not ckpt_label: log_lines.append(" (no checkpoint)") continue try: bundle = _load_bundle(version, ckpt_label) model, tokenizer, device = bundle["model"], bundle["tokenizer"], str(bundle["device"]) model.eval() if bench_key == "blimp": _, y = _run_blimp(model, tokenizer, device, n_samples=max_samples) elif bench_key == "wikitext2": _, y = _run_wikitext2(model, tokenizer, device, max_chunks=max_samples) else: _, y = _run_arc_easy(model, tokenizer, device, max_samples=max_samples) valid = [v for v in y if not math.isnan(v)] summary = sum(valid) / len(valid) if valid else float("nan") all_results.append({"label": version, "summary": summary}) log_lines.append(f" score: {summary:.4f}") except Exception as e: log_lines.append(f" error: {e}") if not all_results: return "\n".join(log_lines), None metric = bench["metric"] paired = sorted( zip([r["summary"] for r in all_results], [r["label"] for r in all_results]), reverse=(metric != "perplexity"), ) summaries, labels_ = zip(*paired) n = len(summaries) colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)] fig, ax = plt.subplots(figsize=(max(6, n * 1.6), 5)) bars = ax.bar(range(n), summaries, color=colors, edgecolor="black") for bar, val in zip(bars, summaries): ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold") ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)" ax.set_ylabel(ylabel) ax.set_title(f"{bench['label']} — Model Comparison") ax.set_xticks(range(n)) ax.set_xticklabels(labels_, rotation=20, ha="right", fontsize=9) if metric == "accuracy": ax.set_ylim(0, 1.05) ax.grid(True, axis="y", alpha=0.3) plt.tight_layout() out_path = "/tmp/benchmark_result.png" plt.savefig(out_path, dpi=150) plt.close(fig) log_lines += ["", "Done."] return "\n".join(log_lines), out_path # ---- shared advanced params ---- def _advanced_block(): with gr.Accordion("Advanced parameters", open=False): use_custom = gr.Checkbox(label="Override preset with custom values below", value=False) raw_mode = gr.Checkbox(label="Raw / pretrain mode (no <|user|> wrapping)", value=False) with gr.Row(): temperature = gr.Slider(0.0, 2.0, value=0.5, step=0.01, label="Temperature") top_k = gr.Slider(0, 200, value=20, step=1, label="Top-k") with gr.Row(): top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p") min_p = gr.Slider(0.0, 0.5, value=0.05, step=0.01, label="Min-p") with gr.Row(): rep_penalty = gr.Slider(1.0, 2.0, value=1.15, step=0.01, label="Repetition penalty") ngram_size = gr.Slider(0, 8, value=4, step=1, label="No-repeat n-gram size") with gr.Row(): soft_cap = gr.Slider(0.0, 50.0, value=20.0, step=0.5, label="Logit soft cap") loop_pen = gr.Slider(0.0, 50.0, value=15.0, step=0.5, label="Loop penalty") with gr.Row(): max_tokens = gr.Slider(16, 4096, value=512, step=16, label="Max new tokens") ctx_win = gr.Slider(128, 4096, value=2048, step=128, label="Context window") return use_custom, temperature, top_k, top_p, min_p, rep_penalty, ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode # ---- build UI ---- def _sort_oldest_to_newest(versions: list[str]) -> list[str]: """Sort versions oldest→newest using HUGGINGFACE_MODELS key order.""" order = {name: i for i, name in enumerate(HUGGINGFACE_MODELS)} return sorted( versions, key=lambda v: (order.get(v, len(order)), versions.index(v)), ) _initial_versions = _sort_oldest_to_newest(_collection_versions()) _initial_version = _initial_versions[0] if _initial_versions else None _initial_ckpt_labels = _ckpt_labels(_initial_version) if _initial_version else [] _mode_keys = list(MODES.keys()) # Hugging Face style theme — yellow primary + warm slate neutrals. _HF_THEME = gr.themes.Default( primary_hue=gr.themes.Color( c50="#FFFBEA", c100="#FFF3C4", c200="#FCE588", c300="#FADB5F", c400="#F7C948", c500="#FFD21E", # HF brand yellow c600="#F0B429", c700="#CB6E17", c800="#B44D12", c900="#8D2B0B", c950="#5C1A04", ), secondary_hue="orange", neutral_hue="slate", font=[gr.themes.GoogleFont("IBM Plex Sans"), "ui-sans-serif", "system-ui", "sans-serif"], font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"], ).set( body_background_fill="#FFFFFF", body_background_fill_dark="#0B0F19", background_fill_primary="#FFFFFF", background_fill_primary_dark="#0B0F19", background_fill_secondary="#F5F6F8", background_fill_secondary_dark="#1B1F2A", block_background_fill="#FFFFFF", block_background_fill_dark="#11151F", block_border_color="#E5E7EB", block_border_color_dark="#22273A", block_label_background_fill="#FFFBEA", block_label_background_fill_dark="#1B1F2A", block_label_text_color="#5C1A04", block_label_text_color_dark="#FFD21E", block_title_text_color="#0B0F19", block_title_text_color_dark="#F5F6F8", button_primary_background_fill="#FFD21E", button_primary_background_fill_hover="#F0B429", button_primary_text_color="#0B0F19", button_primary_text_color_hover="#0B0F19", button_secondary_background_fill="#F5F6F8", button_secondary_background_fill_dark="#1B1F2A", button_secondary_text_color="#0B0F19", button_secondary_text_color_dark="#F5F6F8", border_color_accent="#FFD21E", border_color_primary="#E5E7EB", border_color_primary_dark="#22273A", color_accent_soft="#FFFBEA", color_accent_soft_dark="#1B1F2A", ) with gr.Blocks(title="CompactAI Models", theme=_HF_THEME) as demo: gr.Markdown( "# CompactAI — TinyMemoryLM\n" "Tiny recurrent-depth language models from [CompactAI-O](https://huggingface.co/CompactAI-O)." ) # ── Chat ────────────────────────────────────────────────────────────────── with gr.Tab("Chat"): with gr.Row(): with gr.Column(scale=1, min_width=240): chat_version = gr.Dropdown( choices=_initial_versions, value=_initial_version, label="Model version", ) chat_ckpt = gr.Dropdown( choices=_initial_ckpt_labels, value=_initial_ckpt_labels[0] if _initial_ckpt_labels else None, label="Checkpoint", ) chat_mode = gr.Radio( choices=_mode_keys, value="chat-coherent", label="Mode preset", info="Ignored when 'Override preset' is checked.", ) c_use_custom, c_temp, c_topk, c_topp, c_minp, c_rep, c_ng, c_cap, c_lp, c_maxt, c_ctx, c_raw = _advanced_block() with gr.Column(scale=3): chatbot = gr.Chatbot(label="Conversation", height=500) with gr.Row(): msg_box = gr.Textbox(placeholder="Type a message…", show_label=False, scale=5) send_btn = gr.Button("Send", variant="primary", scale=1) clear_btn = gr.Button("Clear") chat_version.change(_on_version_change, chat_version, chat_ckpt) _chat_adv = [chat_version, chat_ckpt, chat_mode, c_use_custom, c_temp, c_topk, c_topp, c_minp, c_rep, c_ng, c_cap, c_lp, c_maxt, c_ctx, c_raw] msg_box.submit(_chat_submit, [msg_box, chatbot], [msg_box, chatbot], queue=False).then( _chat_stream, [chatbot] + _chat_adv, chatbot ) send_btn.click(_chat_submit, [msg_box, chatbot], [msg_box, chatbot], queue=False).then( _chat_stream, [chatbot] + _chat_adv, chatbot ) clear_btn.click(lambda: [], None, chatbot, queue=False) # ── Compare ─────────────────────────────────────────────────────────────── with gr.Tab("Compare All Models"): gr.Markdown("Run the same prompt on every selected model. Outputs stream live one model at a time.") with gr.Row(): cmp_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt…", lines=4, scale=3) with gr.Column(scale=1): cmp_models = gr.CheckboxGroup( choices=_initial_versions, value=_initial_versions, label="Models to run" ) cmp_mode = gr.Dropdown( choices=_mode_keys, value="chat-coherent", label="Mode preset" ) cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw = _advanced_block() cmp_run = gr.Button("▶ Run comparison", variant="primary") # 2-column grid of output boxes cmp_outputs = [] for row_start in range(0, len(_initial_versions), 2): with gr.Row(): for v in _initial_versions[row_start:row_start + 2]: cmp_outputs.append(gr.Textbox(label=v, lines=10, interactive=False)) cmp_run.click( _compare_fn, inputs=[cmp_prompt, cmp_models, cmp_mode, cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw], outputs=cmp_outputs, ) # ── Benchmark ───────────────────────────────────────────────────────────── with gr.Tab("Benchmark"): gr.Markdown( "Evaluate models on standard benchmarks.\n\n" "- **BLiMP** — grammaticality minimal pairs (accuracy)\n" "- **WikiText-2** — LM perplexity (lower = better)\n" "- **ARC-Easy** — multiple-choice science QA (accuracy)" ) with gr.Row(): bench_type = gr.Radio( choices=list(BENCHMARKS.keys()), value="arc_easy", label="Benchmark" ) bench_models = gr.CheckboxGroup( choices=_initial_versions, value=[_initial_versions[0]] if _initial_versions else [], label="Models", ) bench_samples = gr.Slider(10, 500, value=100, step=10, label="Max samples (fewer = faster)") bench_run = gr.Button("Run benchmark", variant="primary") with gr.Row(): bench_log = gr.Textbox(label="Progress log", lines=12, interactive=False) bench_plot = gr.Image(label="Results chart", type="filepath") bench_run.click( _benchmark_fn, inputs=[bench_type, bench_models, bench_samples], outputs=[bench_log, bench_plot], ) if __name__ == "__main__": demo.launch(theme=_HF_THEME)