"""
Nano Reasoning Model (NRM) - Main Architecture

ARCHITECTURE DESIGN PHILOSOPHY
==============================
This model aims to maximize reasoning ability per parameter through several
key design choices:

1. SHARED LAYERS: The middle layers are shared and can be looped through
   multiple times. This creates a form of "iterative refinement": the model
   processes its hidden state in several passes, similar to how recurrent
   networks process sequences, but applied along depth instead of time.
   Inspired by Universal Transformers and ALBERT.
   WHY IT HELPS REASONING: Reasoning often requires iterative refinement of
   intermediate representations. Shared layers let the model "think more"
   without adding parameters. (In this file the extra passes only happen
   when thinking is enabled and ``n_think_loops > 1`` is passed to
   ``forward``/``generate``; by default the shared layers run once.)

2. THINKING TOKENS: Special start/end "thinking" tokens delimit a
   "scratchpad" in which the model can write intermediate reasoning steps,
   and the model is trained to emit a step token for each logical step.
   (The literal token names were lost from this docstring -- presumably
   something like ``<think>``/``</think>`` and a ``<step>`` marker; TODO
   confirm against the tokenizer. This module never references those tokens:
   ``use_thinking_tokens`` here only enables the shared-layer thinking loops.)
   WHY IT HELPS: Decomposing complex problems into steps is THE key
   capability for reasoning; even large models benefit from chain-of-thought
   prompting.

3. WEIGHT TYING: Input and output embeddings share the same weight matrix.
   This halves the embedding-related parameter count and creates a natural
   link between token understanding and token generation.
   WHY IT HELPS ON CPU: Fewer parameters to store, update and keep in cache
   (the per-token compute of the output projection itself is unchanged).

4. LOW-RANK PROJECTIONS: Attention and MLP projections use LoRA-style
   factored matrices (implemented in ``components``), cutting the parameter
   count of the linear layers by roughly ~8x, depending on ``lora_rank``.

5. GROUPED QUERY ATTENTION: KV heads are shared across query heads, reducing
   KV projection parameters and memory.

PARAMETER BUDGET (~10M):
    Embedding: 2048 * 256 = 524K (shared with output head)
    Per unique layer: ~200K
    4 unique + 2 shared (run 2x) = 6 effective layers
    Total: ~2.1M (layers) + 524K (embed) ≈ 2.6M unique params
    Effective computation: ~3.1M param equivalent
NOTE(review): these figures are rough and not mutually consistent
(6 layer modules * ~200K ≈ 1.2M, not 2.1M; 4 unique + 2 shared run twice is
8 layer passes, not 6). The exact count for a given config is printed when
the model is constructed.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict

from components import TransformerBlock, RMSNorm


class NanoReasoningModel(nn.Module):
    """Decoder-only LM with entry, shared (loopable) and exit transformer blocks.

    Data flow: token embedding -> dropout -> entry blocks -> shared blocks
    (optionally looped, see ``forward``) -> exit blocks -> RMSNorm -> output
    head (optionally tied to the token embedding).

    Args:
        config: Hyper-parameters. Required keys: ``d_model``, ``n_heads``,
            ``n_layers``, ``d_ff``, ``vocab_size``, ``max_seq_len``.
            Optional keys (default): ``n_shared_layers`` (2), ``dropout``
            (0.05), ``lora_rank`` (16), ``use_thinking_tokens`` (True),
            ``n_thinking_steps`` (2), ``n_kv_heads`` (``n_heads // 2``, at
            least 1), ``weight_tying`` (True), ``pad_token_id`` (0, ignored
            by the loss and used to fill finished rows in ``generate``),
            ``label_smoothing`` (0.05).

    Raises:
        ValueError: if ``n_shared_layers`` is not in ``[0, n_layers]``.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        d_model = config['d_model']
        n_heads = config['n_heads']
        n_layers = config['n_layers']
        n_shared = config.get('n_shared_layers', 2)
        d_ff = config['d_ff']
        vocab_size = config['vocab_size']
        max_seq_len = config['max_seq_len']
        dropout = config.get('dropout', 0.05)
        rank = config.get('lora_rank', 16)
        self.use_thinking = config.get('use_thinking_tokens', True)
        # Stored for callers; not read anywhere else in this file.
        self.n_thinking_steps = config.get('n_thinking_steps', 2)
        # GQA default: half as many KV heads as query heads, but never zero
        # (n_heads // 2 would be 0 for a single-head model).
        n_kv_heads = config.get('n_kv_heads', max(1, n_heads // 2))

        if not 0 <= n_shared <= n_layers:
            raise ValueError(
                f"n_shared_layers ({n_shared}) must be between 0 and "
                f"n_layers ({n_layers})"
            )

        def make_block() -> nn.Module:
            # Every block (entry, shared, exit) uses identical hyper-parameters.
            return TransformerBlock(d_model, n_heads, d_ff, rank, dropout,
                                    max_seq_len, n_kv_heads)

        # Token embeddings (may be tied with the output head below).
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.embedding_dropout = nn.Dropout(dropout)

        # The non-shared layers are split between entry and exit; entry gets
        # the smaller half when the count is odd.
        n_unique = n_layers - n_shared
        n_entry = n_unique // 2
        self.entry_layers = nn.ModuleList(make_block() for _ in range(n_entry))
        # Shared layers: one set of weights, possibly run several times.
        self.shared_layers = nn.ModuleList(make_block() for _ in range(n_shared))
        self.exit_layers = nn.ModuleList(
            make_block() for _ in range(n_unique - n_entry))

        self.final_norm = RMSNorm(d_model)

        self.output_head = nn.Linear(d_model, vocab_size, bias=False)
        if config.get('weight_tying', True):
            # Same Parameter object, so it is stored and counted only once.
            self.output_head.weight = self.token_embedding.weight

        # Thinking gate: learned scalar; sigmoid(think_gate) is the weight of
        # the new state when blending between thinking iterations.
        if self.use_thinking:
            self.think_gate = nn.Parameter(torch.tensor(0.5))

        self.apply(self._init_weights)
        self._count_parameters()

    def _init_weights(self, module: nn.Module):
        """Initialize Linear/Embedding weights from N(0, 0.02), Linear biases to 0.

        Applied to every submodule via ``self.apply``, so it also overrides
        whatever initialization ``components`` gives its own nn.Linear /
        nn.Embedding layers. Other module types (e.g. RMSNorm) are untouched.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _count_parameters(self):
        """Store total/trainable parameter counts and print a config summary.

        ``self.parameters()`` yields each Parameter once, so tied weights and
        the looped shared layers are counted a single time.
        """
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        self.total_params = total
        self.trainable_params = trainable

        print(f"\n{'='*50}")
        print("NRM Model Configuration:")
        print(f"  d_model: {self.config['d_model']}")
        print(f"  n_heads: {self.config['n_heads']}")
        print(f"  n_layers: {self.config['n_layers']} "
              f"({len(self.entry_layers)} entry + {len(self.shared_layers)} shared + {len(self.exit_layers)} exit)")
        print(f"  d_ff: {self.config['d_ff']}")
        print(f"  vocab_size: {self.config['vocab_size']}")
        print(f"  LoRA rank: {self.config.get('lora_rank', 16)}")
        print(f"  Thinking: {'enabled' if self.use_thinking else 'disabled'}")
        print(f"  Total parameters: {total:,}")
        print(f"  Trainable parameters: {trainable:,}")
        print(f"{'='*50}\n")

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                n_think_loops: int = 1) -> Dict[str, torch.Tensor]:
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: 2-D tensor of token ids (batch, seq).
            attention_mask: Optional; positions where it equals 0 are passed
                to every block as padding (``True`` = padded).
            labels: Optional token ids aligned with ``input_ids``; the loss
                predicts ``labels[:, t + 1]`` from position ``t``.
            n_think_loops: How many times to run the shared layers. Values
                above 1 only take effect when thinking is enabled; otherwise
                the shared layers run exactly once.

        Returns:
            ``{"logits": (batch, seq, vocab)}`` plus ``"loss"`` when labels
            are given.

        Raises:
            ValueError: if ``input_ids`` is not 2-D.
        """
        if input_ids.dim() != 2:
            raise ValueError(
                f"input_ids must be 2-D (batch, seq), got shape {tuple(input_ids.shape)}")

        x = self.embedding_dropout(self.token_embedding(input_ids))
        pad_mask = None if attention_mask is None else (attention_mask == 0)

        for layer in self.entry_layers:
            x = layer(x, pad_mask)
        x = self._run_shared_layers(x, pad_mask, n_think_loops)
        for layer in self.exit_layers:
            x = layer(x, pad_mask)

        logits = self.output_head(self.final_norm(x))
        result = {"logits": logits}
        if labels is not None:
            result["loss"] = self._lm_loss(logits, labels)
        return result

    def _run_shared_layers(self, x: torch.Tensor, pad_mask: Optional[torch.Tensor],
                           n_think_loops: int) -> torch.Tensor:
        """Apply the shared blocks once, or ``n_think_loops`` times with gated blending."""
        n_loops = max(1, n_think_loops)
        if not (self.use_thinking and n_loops > 1):
            for layer in self.shared_layers:
                x = layer(x, pad_mask)
            return x

        # "Residual thinking": between passes (not after the last one), pull
        # the state back toward the pre-think representation.
        x_original = x
        gate = torch.sigmoid(self.think_gate)
        for loop in range(n_loops):
            for layer in self.shared_layers:
                x = layer(x, pad_mask)
            if loop < n_loops - 1:
                x = gate * x + (1 - gate) * x_original
        return x

    def _lm_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Shifted next-token cross-entropy ignoring the pad id, with label smoothing.

        Note: yields NaN when no target survives (seq length 1 or all padding),
        because the mean is taken over zero elements.
        """
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        return F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=self.config.get('pad_token_id', 0),
            label_smoothing=self.config.get('label_smoothing', 0.05),
        )

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 0.7, top_k: int = 50, top_p: float = 0.9,
                 n_think_loops: int = 1, eos_token_id: int = 2) -> torch.Tensor:
        """Autoregressively sample up to ``max_new_tokens`` tokens.

        Sampling uses temperature scaling, then top-k, then nucleus (top-p)
        filtering. The full context (truncated to ``max_seq_len``) is
        re-encoded every step; there is no KV cache.

        Works for any batch size: once a row emits ``eos_token_id``, its later
        positions are filled with the pad id, and generation stops when every
        row has finished. Pass ``eos_token_id=None`` to disable early stopping.
        The module's train/eval mode is restored on return.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1.
        """
        was_training = self.training
        self.eval()
        try:
            generated = input_ids.clone()
            max_len = self.config['max_seq_len']
            pad_id = self.config.get('pad_token_id', 0)
            finished = torch.zeros(generated.size(0), dtype=torch.bool,
                                   device=generated.device)

            for _ in range(max_new_tokens):
                context = generated[:, -max_len:]
                logits = self(context, n_think_loops=n_think_loops)["logits"][:, -1, :]
                logits = self._filter_logits(logits / max(temperature, 1e-5),
                                             top_k, top_p)
                next_token = torch.multinomial(F.softmax(logits, dim=-1),
                                               num_samples=1)

                if eos_token_id is not None:
                    # Rows that already emitted EOS only receive padding.
                    next_token = next_token.masked_fill(finished.unsqueeze(-1), pad_id)
                    finished |= next_token.squeeze(-1) == eos_token_id

                generated = torch.cat([generated, next_token], dim=1)
                if finished.all():
                    break
            return generated
        finally:
            # Don't leave dropout disabled if generate() is called mid-training.
            self.train(was_training)

    @staticmethod
    def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Set logits outside the top-k and outside the top-p nucleus to -inf."""
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_best, float('-inf'))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Shift right so the token that crosses the threshold is kept, and
            # always keep the most likely token.
            sorted_to_remove = cumulative_probs > top_p
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, float('-inf'))
        return logits

    def save(self, path: str):
        """Write ``model.pt`` (state dict) and ``config.json`` into directory ``path``."""
        import json
        import os

        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, "model.pt"))
        with open(os.path.join(path, "config.json"), 'w') as f:
            json.dump(self.config, f, indent=2)
        print(f"Model saved to {path}")

    @classmethod
    def load(cls, path: str, device: str = 'cpu') -> 'NanoReasoningModel':
        """Rebuild a model from a directory written by ``save`` and move it to ``device``."""
        import json
        import os

        with open(os.path.join(path, "config.json"), 'r') as f:
            config = json.load(f)
        model = cls(config)
        # NOTE(security): torch.load unpickles the file; on older torch versions
        # (default weights_only=False) this can execute arbitrary code. Only
        # load trusted checkpoints, or pass weights_only=True where supported.
        state_dict = torch.load(os.path.join(path, "model.pt"), map_location=device)
        model.load_state_dict(state_dict)
        # map_location only places the loaded tensors; load_state_dict copies
        # them into the freshly built (CPU) parameters, so move the model too.
        return model.to(device)