| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
| from typing import Optional, Tuple, Dict |
|
|
|
|
@dataclass
class HRMCosmicFishConfig:
    """Hyperparameters for the HRM CosmicFish language model.

    The field order is part of the public interface (positional construction),
    so new fields must only ever be appended.

    Raises:
        ValueError: from ``__post_init__`` when the attention head geometry is
            inconsistent (see the checks there).
    """

    # Vocabulary size, embedding width and maximum sequence length.
    vocab_size: int = 50304
    n_embd: int = 448
    block_size: int = 512

    # Depth of the plain transformer stacks before/after the reasoning core,
    # and the number of query heads used by every attention layer.
    n_input_layers: int = 6
    n_output_layers: int = 6
    n_head: int = 8

    # Reasoning core: layers per level, inner cycle counts, the cap on
    # reasoning steps, and the per-step probability (training only) of
    # imposing a random minimum step count before halting is allowed.
    hrm_H_layers: int = 4
    hrm_L_layers: int = 4
    hrm_H_cycles: int = 2
    hrm_L_cycles: int = 2
    hrm_max_steps: int = 16
    hrm_exploration_prob: float = 0.1

    # Regularisation and whether Linear layers carry a bias term.
    dropout: float = 0.1
    bias: bool = False

    # Architecture switches. n_kv_head is only used when use_gqa is True.
    use_rotary: bool = True
    use_gqa: bool = True
    use_swiglu: bool = True
    n_kv_head: int = 4

    # Epsilon used by every RMSNorm.
    eps: float = 1e-5

    # NOTE(review): not read anywhere in this file; presumably consumed by the
    # training script to choose an autocast dtype -- confirm.
    forward_dtype: str = "bfloat16"

    def __post_init__(self):
        """Fail fast on head geometries the attention layers cannot handle."""
        if self.n_head <= 0:
            raise ValueError(f"n_head must be positive, got {self.n_head}")
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )
        if self.use_gqa:
            # Each KV head is repeated n_head // n_kv_head times, so the
            # division has to be exact.
            if self.n_kv_head <= 0 or self.n_head % self.n_kv_head != 0:
                raise ValueError(
                    f"n_head ({self.n_head}) must be a positive multiple of "
                    f"n_kv_head ({self.n_kv_head}) when use_gqa is True"
                )
        if self.use_rotary and (self.n_embd // self.n_head) % 2 != 0:
            # Rotary embeddings pair up adjacent channels of each head.
            raise ValueError(
                "head dimension (n_embd // n_head) must be even when use_rotary is True"
            )
|
|
|
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the complex rotation table used for rotary position embeddings.

    Args:
        dim: Per-head channel count; ``dim // 2`` frequencies are produced.
        end: Number of positions to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[p, i]`` is
        the unit-magnitude number with angle ``p * theta ** (-2i / dim)``.
    """
    n_freqs = dim // 2
    exponents = torch.arange(0, dim, 2)[:n_freqs].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
|
|
|
|
def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query and key channel pairs by position-dependent angles.

    Adjacent channel pairs of the last dimension are treated as complex
    numbers and multiplied by the matching entries of ``freqs_cis``. The
    sequence axis is dimension 2 of ``xq``/``xk``; only the first
    ``xq.shape[2]`` rows of the table are used.

    Args:
        xq: Query tensor with an even-sized last dimension.
        xk: Key tensor laid out like ``xq`` (head count may differ).
        freqs_cis: 2-D complex table indexed ``[position, pair]``.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the shapes and dtypes of the inputs.
    """

    def _to_complex(t):
        # The math is done in float32 regardless of the input dtype.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = _to_complex(xq)
    k_complex = _to_complex(xk)

    seq_len = q_complex.shape[2]
    # Leading singleton axes broadcast over batch and heads.
    rotation = freqs_cis[None, None, :seq_len, :]

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the per-channel gain (initialised to ones) and store ``eps``."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` in float32 and return it in its original dtype."""
        orig_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = x_fp32 * inv_rms
        scaled = self.weight * normed
        return scaled.to(orig_dtype)
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Causal self-attention with optionally fewer key/value heads than query heads.

    When ``config.use_gqa`` is true, keys and values are projected to
    ``config.n_kv_head`` heads and each one is shared by
    ``n_head // n_kv_head`` query heads; otherwise this is ordinary
    multi-head attention.
    """

    def __init__(self, config):
        """Build the projections from ``config``.

        Reads ``n_embd``, ``n_head``, ``n_kv_head``, ``use_gqa``, ``bias`` and
        ``dropout``.

        Raises:
            AssertionError: if ``n_embd`` is not divisible by ``n_head`` or
                ``n_head`` is not divisible by the effective KV head count.
        """
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head if config.use_gqa else config.n_head
        # Without this check a non-divisible pair yields too few repeated KV
        # heads and only fails later, inside the attention kernel.
        assert self.n_kv_head > 0 and self.n_head % self.n_kv_head == 0, (
            f"n_head ({self.n_head}) must be a multiple of n_kv_head ({self.n_kv_head})"
        )
        self.n_rep = self.n_head // self.n_kv_head
        self.head_dim = config.n_embd // config.n_head
        self.n_embd = config.n_embd

        kv_dim = self.n_kv_head * self.head_dim
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.k_proj = nn.Linear(config.n_embd, kv_dim, bias=config.bias)
        self.v_proj = nn.Linear(config.n_embd, kv_dim, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Use the fused kernel when this PyTorch build provides it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def forward(self, x, freqs_cis=None):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, n_embd)``.
            freqs_cis: Optional rotary table passed to ``apply_rotary_emb``;
                when ``None`` no rotary embedding is applied.

        Returns:
            Tensor of shape ``(B, T, n_embd)``.
        """
        B, T, C = x.size()

        # (B, T, heads, head_dim) -> (B, heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis)

        if self.n_rep != 1:
            # Expand the shared KV heads so every query head has a partner.
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=True
            )
        else:
            # Reference path: explicit scores, causal mask, softmax, dropout.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
            future = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            att = att.masked_fill(future, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v

        # (B, heads, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward network with a 4x hidden expansion.

    ``config.use_swiglu`` selects between a gated SiLU variant (``gate``,
    ``up``, ``down``) and a plain GELU variant (``c_fc``, ``c_proj``).
    """

    def __init__(self, config):
        """Create the projections selected by ``config.use_swiglu``."""
        super().__init__()
        width = config.n_embd
        hidden_dim = 4 * width
        with_bias = config.bias

        if config.use_swiglu:
            self.gate = nn.Linear(width, hidden_dim, bias=with_bias)
            self.up = nn.Linear(width, hidden_dim, bias=with_bias)
            self.down = nn.Linear(hidden_dim, width, bias=with_bias)
            self.act = nn.SiLU()
        else:
            self.c_fc = nn.Linear(width, hidden_dim, bias=with_bias)
            self.c_proj = nn.Linear(hidden_dim, width, bias=with_bias)
            self.act = nn.GELU()

        self.dropout = nn.Dropout(config.dropout)
        self.use_swiglu = config.use_swiglu

    def forward(self, x):
        """Return the feed-forward transform of ``x`` (same shape as ``x``)."""
        if not self.use_swiglu:
            hidden = self.act(self.c_fc(x))
            return self.dropout(self.c_proj(hidden))

        # The activation is applied to the ``up`` branch and the result is
        # multiplied element-wise by the ``gate`` branch.
        activated = self.act(self.up(x))
        gated = activated * self.gate(x)
        return self.dropout(self.down(gated))
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm residual block: each sub-layer sees a normalised input and its
    output is added back onto the un-normalised stream."""

    def __init__(self, config):
        """Create the two norms, the attention layer and the MLP from ``config``."""
        super().__init__()
        width, norm_eps = config.n_embd, config.eps
        self.ln_1 = RMSNorm(width, eps=norm_eps)
        self.attn = GroupedQueryAttention(config)
        self.ln_2 = RMSNorm(width, eps=norm_eps)
        self.mlp = MLP(config)

    def forward(self, x, freqs_cis=None):
        """Run attention then the MLP, each wrapped in a residual connection."""
        attn_out = self.attn(self.ln_1(x), freqs_cis)
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
|
|
|
|
class HRMReasoningBlock(nn.Module):
    """Post-norm residual block used inside the reasoning levels.

    Unlike ``TransformerBlock``, each sub-layer reads the raw stream and the
    normalisation is applied after the residual addition.
    """

    def __init__(self, config):
        """Create the two norms, the attention layer and the MLP from ``config``."""
        super().__init__()
        width, norm_eps = config.n_embd, config.eps
        self.ln_1 = RMSNorm(width, eps=norm_eps)
        self.attn = GroupedQueryAttention(config)
        self.ln_2 = RMSNorm(width, eps=norm_eps)
        self.mlp = MLP(config)

    def forward(self, x, freqs_cis=None):
        """Apply attention and the MLP, normalising after each residual sum."""
        attended = x + self.attn(x, freqs_cis)
        x = self.ln_1(attended)
        transformed = x + self.mlp(x)
        return self.ln_2(transformed)
|
|
|
|
class HRMReasoningLevel(nn.Module):
    """A stack of ``HRMReasoningBlock`` layers forming one reasoning level."""

    def __init__(self, config, n_layers):
        """Create ``n_layers`` reasoning blocks, all built from ``config``."""
        super().__init__()
        blocks = []
        for _ in range(n_layers):
            blocks.append(HRMReasoningBlock(config))
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states, input_injection, freqs_cis=None):
        """Add ``input_injection`` to the state, then run every layer in order.

        Returns:
            The updated state, same shape as ``hidden_states``.
        """
        state = hidden_states + input_injection
        for block in self.layers:
            state = block(state, freqs_cis)
        return state
|
|
|
|
class HRMCore(nn.Module):
    """Two-level recurrent reasoning core with a learned halting head.

    A low-level stack (``L_level``) and a high-level stack (``H_level``) are
    applied repeatedly for up to ``config.hrm_max_steps`` reasoning steps.
    After each step a two-logit head (index 0 = halt, index 1 = continue),
    fed with the sequence-mean of the high-level state, decides per sequence
    whether to halt. The loop ends once every sequence has halted.
    """

    def __init__(self, config):
        """Build both levels, the learned initial states and the halting head."""
        super().__init__()
        self.config = config

        self.H_level = HRMReasoningLevel(config, config.hrm_H_layers)
        self.L_level = HRMReasoningLevel(config, config.hrm_L_layers)

        # Learned starting states, broadcast over batch and sequence in forward().
        self.H_init = nn.Parameter(torch.randn(config.n_embd) * 0.02)
        self.L_init = nn.Parameter(torch.randn(config.n_embd) * 0.02)

        self.q_head = nn.Linear(config.n_embd, 2, bias=True)

        # Start with identical, strongly negative halt/continue logits.
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5.0)

    def forward(self, x, freqs_cis=None, training=False):
        """Run the reasoning loop.

        Args:
            x: Input of shape ``(B, T, C)``; injected into the low-level state
                on every inner cycle.
            freqs_cis: Optional rotary table forwarded to every layer.
            training: Selects training-time halting (no halt bias, random
                minimum-step exploration) versus inference-time halting
                (halt logit biased by +0.35).

        Returns:
            ``(z_H, steps_taken, q_logits)``: the final high-level state
            ``(B, T, C)``; a long tensor ``(B,)`` with the number of reasoning
            steps executed while each sequence was still active; and the
            halting logits ``(B, 2)`` from the last executed step (``None`` if
            no step ran).
        """
        B, T, C = x.size()
        device = x.device
        cfg = self.config
        max_steps = cfg.hrm_max_steps
        last_step = max_steps - 1

        # Only the final step is differentiated, and only if the caller has
        # not disabled autograd: set_grad_enabled(True) would otherwise
        # override an enclosing torch.no_grad() (e.g. during generation).
        grad_allowed = torch.is_grad_enabled()

        z_H = self.H_init.expand(B, T, C)
        z_L = self.L_init.expand(B, T, C)

        steps_taken = torch.zeros(B, dtype=torch.long, device=device)
        halted = torch.zeros(B, dtype=torch.bool, device=device)
        q_logits = None

        # NOTE(review): the states of sequences that already halted keep being
        # updated while other sequences in the batch continue, and if every
        # sequence halts before the final step the returned z_H carries no
        # autograd graph. Both look unintended -- confirm before relying on
        # per-sequence halting during training.
        for step in range(max_steps):
            if halted.all():
                break

            with torch.set_grad_enabled(grad_allowed and step == last_step):
                for _h in range(cfg.hrm_H_cycles):
                    for _l in range(cfg.hrm_L_cycles):
                        z_L = self.L_level(z_L, z_H + x, freqs_cis)
                    z_H = self.H_level(z_H, z_L, freqs_cis)

            q_logits = self.q_head(z_H.mean(dim=1).float())

            # Count this step for every sequence that was active when it
            # started, before deciding whether to halt, so a sequence that
            # halts on its first step reports 1 rather than 0.
            steps_taken = steps_taken + (~halted).long()

            if max_steps > 1:
                q_halt = q_logits[:, 0]
                q_continue = q_logits[:, 1]

                if not training:
                    # Inference-time bias towards halting.
                    q_halt = q_halt + 0.35

                should_halt = q_halt > q_continue

                if training and torch.rand(1).item() < cfg.hrm_exploration_prob:
                    # Exploration: forbid halting until a random minimum
                    # number of steps has been completed.
                    min_steps = torch.randint(2, max_steps + 1, (1,)).item()
                    should_halt = should_halt & (steps_taken >= min_steps)

                halted = halted | should_halt

            if step == last_step:
                halted = torch.ones_like(halted)

        return z_H, steps_taken, q_logits
|
|
|
|
class HRMCosmicFish(nn.Module):
    """
    Architecture: Input Blocks → HRM Reasoning Core → Output Blocks → LM Head

    Token embeddings pass through ``n_input_layers`` transformer blocks, the
    adaptive-depth ``HRMCore``, ``n_output_layers`` transformer blocks, a final
    RMSNorm and a language-model head whose weight is shared with the token
    embedding.
    """

    def __init__(self, config):
        """Build and initialise the model from ``config``; prints a size summary."""
        super().__init__()
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)

        if config.use_rotary:
            # Kept as a plain attribute rather than a registered buffer:
            # Module.to(dtype) also casts complex buffers, which would discard
            # the imaginary part of this table.
            self.freqs_cis = precompute_freqs_cis(
                config.n_embd // config.n_head,
                config.block_size
            )
        else:
            self.freqs_cis = None
            self.wpe = nn.Embedding(config.block_size, config.n_embd)

        self.drop = nn.Dropout(config.dropout)

        self.input_blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_input_layers)
        ])

        self.hrm_core = HRMCore(config)

        self.output_blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_output_layers)
        ])

        self.ln_f = RMSNorm(config.n_embd, eps=config.eps)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the embedding and the output projection share one tensor.
        self.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

        # The blanket init above also re-initialises the halting head, wiping
        # the deliberate zero-weight / -5 bias start that HRMCore sets up.
        # Restore it.
        with torch.no_grad():
            self.hrm_core.q_head.weight.zero_()
            self.hrm_core.q_head.bias.fill_(-5.0)

        # Scaled-down init for the projections that write into the residual stream.
        total_layers = (
            config.n_input_layers + config.n_output_layers
            + config.hrm_H_layers + config.hrm_L_layers
        )
        resid_std = 0.02 / math.sqrt(2 * total_layers)
        for pn, p in self.named_parameters():
            if pn.endswith(('c_proj.weight', 'down.weight')):
                torch.nn.init.normal_(p, mean=0.0, std=resid_std)

        print(f"Model initialized with {self.get_num_params() / 1e6:.2f}M parameters")
        print(f" Input blocks: {config.n_input_layers} layers")
        print(f" HRM Core: H={config.hrm_H_layers} L={config.hrm_L_layers} (max {config.hrm_max_steps} steps)")
        print(f" Output blocks: {config.n_output_layers} layers")

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_num_params(self, non_embedding=True):
        """Return the parameter count.

        With ``non_embedding=True`` the learned position embedding (present
        only when rotary embeddings are disabled) is excluded. The tied
        token-embedding / LM-head tensor is counted once.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and hasattr(self, 'wpe'):
            n_params -= self.wpe.weight.numel()
        return n_params

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the training loss.

        Args:
            idx: Long tensor ``(B, T)`` of token ids, ``T <= block_size``.
            targets: Optional long tensor ``(B, T)``; entries equal to ``-1``
                are ignored by the cross-entropy.

        Returns:
            ``(logits, loss, steps_taken, q_logits)`` where ``logits`` is
            ``(B, T, vocab_size)``, ``loss`` is ``None`` without targets, and
            the last two values come straight from the reasoning core.
        """
        device = idx.device
        B, T = idx.size()
        assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}"

        x = self.wte(idx)

        if self.config.use_rotary:
            freqs_cis = self.freqs_cis
            if freqs_cis is not None and freqs_cis.device != device:
                # Move the table once and keep it there instead of copying it
                # to the device on every forward pass.
                freqs_cis = self.freqs_cis = freqs_cis.to(device)
        else:
            pos = torch.arange(0, T, dtype=torch.long, device=device)
            x = x + self.wpe(pos)
            freqs_cis = None

        x = self.drop(x)

        for block in self.input_blocks:
            x = block(x, freqs_cis)

        x, steps_taken, q_logits = self.hrm_core(x, freqs_cis, training=self.training)

        for block in self.output_blocks:
            x = block(x, freqs_cis)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            task_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1
            )
            # steps_taken is an integer count, so this term shifts the
            # reported loss but contributes no gradient.
            step_penalty = 0.01 * steps_taken.float().mean()
            loss = task_loss + step_penalty

        return logits, loss, steps_taken, q_logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively.

        Args:
            idx: Long tensor ``(B, T)`` used as the prompt.
            max_new_tokens: Number of tokens to append.
            temperature: Divisor applied to the final-position logits.
            top_k: If given, only the ``top_k`` most likely tokens are sampled.

        Returns:
            Long tensor ``(B, T + max_new_tokens)``.
        """
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens.
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            logits, _, _, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx