| """ |
| Nano Reasoning Model (NRM) - Main Architecture |
| |
| ARCHITECTURE DESIGN PHILOSOPHY: |
| ================================ |
| This model maximizes reasoning ability per parameter through several key innovations: |
| |
| 1. SHARED LAYERS: The middle layers are shared (looped through multiple times). |
This creates a form of "iterative refinement" - the model processes information
in multiple passes, similar to how recurrent networks process sequences, but
applied to depth instead. This is inspired by Universal Transformers and ALBERT.
| |
| WHY IT HELPS REASONING: Reasoning often requires iterative refinement of |
| intermediate representations. Shared layers let the model "think more" without |
| more parameters. |
| |
| 2. THINKING TOKENS: Special <THINK> and </THINK> tokens create a "scratchpad" |
| where the model can show intermediate reasoning steps. The model is trained to |
| use <STEP> tokens for each logical step. |
| |
| WHY IT HELPS: Decomposing complex problems into steps is THE key capability |
| for reasoning. Even large models benefit from chain-of-thought prompting. |
| |
| 3. WEIGHT TYING: Input and output embeddings share the same weight matrix. |
| This halves the embedding parameter count and creates a natural link between |
| token understanding and token generation. |
| |
| WHY IT HELPS CPU: Fewer parameters = faster forward/backward passes. |
| |
| 4. LOW-RANK PROJECTIONS: All attention and MLP projections use LoRA-style |
| factored matrices, cutting parameter count by ~8x in linear layers. |
| |
| 5. GROUPED QUERY ATTENTION: KV heads are shared across query heads, |
| reducing KV projection parameters and memory. |
| |
PARAMETER BUDGET (target ~10M; this config comes in well under):
   Embedding: 2048 * 256 = 524K (shared with output head)
   Per unique layer: ~200K
   4 entry/exit + 2 shared layers (shared run 2x) = 6 unique layers, 8 effective passes
   Total: ~1.2M (6 layers) + 524K (embed) ≈ 1.7M unique params
   Effective computation: ~2.1M param equivalent (8 layer passes)
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Dict |
| from components import TransformerBlock, RMSNorm |
|
|
|
|
class NanoReasoningModel(nn.Module):
    """Compact decoder-only transformer with a looped ("shared") middle block.

    Layer layout: entry -> shared (optionally run multiple times) -> exit.
    Input and output embeddings are weight-tied by default, and a learnable
    scalar gate blends looped activations back toward the pre-loop state
    between thinking loops.
    """

    def __init__(self, config: dict):
        """Build the model from a plain-dict config.

        Required keys: d_model, n_heads, n_layers, d_ff, vocab_size,
        max_seq_len.  Optional keys (defaults): n_shared_layers (2),
        dropout (0.05), lora_rank (16), use_thinking_tokens (True),
        n_thinking_steps (2), n_kv_heads (n_heads // 2), weight_tying (True).

        Raises:
            ValueError: if n_shared_layers is negative or exceeds n_layers.
        """
        super().__init__()
        self.config = config

        d_model = config['d_model']
        n_heads = config['n_heads']
        n_layers = config['n_layers']
        n_shared = config.get('n_shared_layers', 2)
        d_ff = config['d_ff']
        vocab_size = config['vocab_size']
        max_seq_len = config['max_seq_len']
        dropout = config.get('dropout', 0.05)
        rank = config.get('lora_rank', 16)
        self.use_thinking = config.get('use_thinking_tokens', True)
        self.n_thinking_steps = config.get('n_thinking_steps', 2)
        n_kv_heads = config.get('n_kv_heads', n_heads // 2)

        # Fail fast on an impossible layer split: a negative unique-layer
        # count would silently produce empty entry/exit stacks below
        # (range() of a negative number is empty), yielding a model with
        # fewer layers than configured.
        if not 0 <= n_shared <= n_layers:
            raise ValueError(
                f"n_shared_layers ({n_shared}) must be between 0 and "
                f"n_layers ({n_layers})")

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.embedding_dropout = nn.Dropout(dropout)

        # Unique (non-shared) layers are split roughly in half around the
        # shared block: entry gets the floor, exit gets the remainder.
        n_unique = n_layers - n_shared
        self.entry_layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, rank, dropout, max_seq_len, n_kv_heads)
            for _ in range(n_unique // 2)
        ])

        # Middle block: these layers are re-run n_think_loops times in forward().
        self.shared_layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, rank, dropout, max_seq_len, n_kv_heads)
            for _ in range(n_shared)
        ])

        self.exit_layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, rank, dropout, max_seq_len, n_kv_heads)
            for _ in range(n_unique - n_unique // 2)
        ])

        self.final_norm = RMSNorm(d_model)

        # Bias-free output head so its weight can alias the embedding matrix.
        self.output_head = nn.Linear(d_model, vocab_size, bias=False)
        if config.get('weight_tying', True):
            self.output_head.weight = self.token_embedding.weight

        # Scalar gate (sigmoid-squashed in forward) controlling how much of
        # the pre-loop representation is mixed back in between thinking loops.
        if self.use_thinking:
            self.think_gate = nn.Parameter(torch.tensor(0.5))

        # Initialization runs after tying; embedding and head share one
        # tensor, so both see the same init.
        self.apply(self._init_weights)

        self._count_parameters()

    def _init_weights(self, module: nn.Module):
        """Initialize linear/embedding weights ~ N(0, 0.02); zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _count_parameters(self):
        """Count parameters, stash the totals on self, and print a summary."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        # NOTE: the previous version also computed an unused duplicate of
        # `total` here; removed as dead code.

        self.total_params = total
        self.trainable_params = trainable
        print(f"\n{'='*50}")
        print(f"NRM Model Configuration:")
        print(f"  d_model: {self.config['d_model']}")
        print(f"  n_heads: {self.config['n_heads']}")
        print(f"  n_layers: {self.config['n_layers']} "
              f"({len(self.entry_layers)} entry + {len(self.shared_layers)} shared + {len(self.exit_layers)} exit)")
        print(f"  d_ff: {self.config['d_ff']}")
        print(f"  vocab_size: {self.config['vocab_size']}")
        print(f"  LoRA rank: {self.config.get('lora_rank', 16)}")
        print(f"  Thinking: {'enabled' if self.use_thinking else 'disabled'}")
        print(f"  Total parameters: {total:,}")
        print(f"  Trainable parameters: {trainable:,}")
        print(f"{'='*50}\n")

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                n_think_loops: int = 1) -> Dict[str, torch.Tensor]:
        """
        Forward pass with optional thinking loops.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: optional (B, T) mask; positions equal to 0 are
                treated as padding and passed to the layers as a pad mask.
            labels: optional (B, T) targets; when given, a shifted
                next-token cross-entropy loss is added to the result
                (ignore_index=0 -- token id 0 is treated as padding here).
            n_think_loops: how many times to loop through the shared layers.
                During reasoning, increase this to give the model more
                "thinking time".

        Returns:
            Dict with "logits" (B, T, vocab_size) and, if labels were
            provided, "loss" (scalar).
        """
        B, T = input_ids.shape

        x = self.token_embedding(input_ids)
        x = self.embedding_dropout(x)

        # True where attention_mask == 0 (padding); layers receive this mask.
        pad_mask = None
        if attention_mask is not None:
            pad_mask = (attention_mask == 0)

        for layer in self.entry_layers:
            x = layer(x, pad_mask)

        actual_loops = max(1, n_think_loops)
        if self.use_thinking and actual_loops > 1:
            # Between loops, blend the refined state with the pre-loop state
            # via a learned gate; this stabilizes repeated application of
            # the same weights.
            x_original = x
            for loop in range(actual_loops):
                for layer in self.shared_layers:
                    x = layer(x, pad_mask)
                if loop < actual_loops - 1:
                    gate = torch.sigmoid(self.think_gate)
                    x = gate * x + (1 - gate) * x_original
        else:
            for layer in self.shared_layers:
                x = layer(x, pad_mask)

        for layer in self.exit_layers:
            x = layer(x, pad_mask)

        x = self.final_norm(x)
        logits = self.output_head(x)

        result = {"logits": logits}

        if labels is not None:
            # Standard causal LM shift: predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=0,
                label_smoothing=0.05
            )
            result["loss"] = loss

        return result

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 0.7, top_k: int = 50, top_p: float = 0.9,
                 n_think_loops: int = 1, eos_token_id: int = 2) -> torch.Tensor:
        """
        Autoregressive generation with temperature, top-k, and top-p sampling.

        Uses nucleus (top-p) sampling for diverse but coherent generation.
        Batch-safe: generation stops once every sequence in the batch has
        emitted eos_token_id (the previous `.item()` check crashed for
        batch size > 1). Sequences that finish early continue to receive
        sampled tokens until the loop exits; callers should truncate at the
        first EOS per row.
        """
        was_training = self.training  # restore mode afterwards instead of
        self.eval()                   # permanently flipping the model to eval
        generated = input_ids.clone()
        finished = torch.zeros(generated.size(0), dtype=torch.bool,
                               device=generated.device)
        try:
            for _ in range(max_new_tokens):
                # Truncate context to the model's maximum sequence length.
                context = generated[:, -self.config['max_seq_len']:]

                outputs = self.forward(context, n_think_loops=n_think_loops)
                logits = outputs["logits"][:, -1, :] / max(temperature, 1e-5)

                # Top-k: mask everything below the k-th largest logit.
                if top_k > 0:
                    top_k_val = min(top_k, logits.size(-1))
                    kth = torch.topk(logits, top_k_val)[0][..., -1, None]
                    logits[logits < kth] = float('-inf')

                # Top-p (nucleus): mask the tail of the sorted distribution,
                # always keeping at least the most likely token.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float('-inf')

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated = torch.cat([generated, next_token], dim=1)

                finished |= next_token.squeeze(-1) == eos_token_id
                if bool(finished.all()):
                    break
        finally:
            self.train(was_training)

        return generated

    def save(self, path: str):
        """Save state dict (model.pt) and config (config.json) under path."""
        import os, json
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, "model.pt"))
        with open(os.path.join(path, "config.json"), 'w') as f:
            json.dump(self.config, f, indent=2)
        print(f"Model saved to {path}")

    @classmethod
    def load(cls, path: str, device: str = 'cpu') -> 'NanoReasoningModel':
        """Rebuild a model from a directory written by save()."""
        import os, json
        with open(os.path.join(path, "config.json"), 'r') as f:
            config = json.load(f)
        model = cls(config)
        model.load_state_dict(torch.load(os.path.join(path, "model.pt"),
                                         map_location=device))
        return model