File size: 11,152 Bytes
6e14a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ee3fce
6e14a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
"""
Nano Reasoning Model (NRM) - Main Architecture

ARCHITECTURE DESIGN PHILOSOPHY:
================================
This model maximizes reasoning ability per parameter through several key innovations:

1. SHARED LAYERS: The middle layers are shared (looped through multiple times).
   This creates a form of "iterative refinement" - the model processes information
   in multiple passes, similar to how recurrent networks process sequences but
   applied to depth instead. This is inspired by Universal Transformers and ALBERT.
   
   WHY IT HELPS REASONING: Reasoning often requires iterative refinement of
   intermediate representations. Shared layers let the model "think more" without
   more parameters.

2. THINKING TOKENS: Special <THINK> and </THINK> tokens create a "scratchpad"
   where the model can show intermediate reasoning steps. The model is trained to
   use <STEP> tokens for each logical step.
   
   WHY IT HELPS: Decomposing complex problems into steps is THE key capability
   for reasoning. Even large models benefit from chain-of-thought prompting.

3. WEIGHT TYING: Input and output embeddings share the same weight matrix.
   This halves the embedding parameter count and creates a natural link between
   token understanding and token generation.
   
   WHY IT HELPS CPU: Fewer parameters = faster forward/backward passes.

4. LOW-RANK PROJECTIONS: All attention and MLP projections use LoRA-style
   factored matrices, cutting parameter count by ~8x in linear layers.

5. GROUPED QUERY ATTENTION: KV heads are shared across query heads,
   reducing KV projection parameters and memory.

PARAMETER BUDGET (~10M):
  Embedding: 2048 * 256 = 524K (shared with output head)
  Per unique layer: ~200K
  4 unique + 2 shared (run 2x) = 6 effective layers
  Total: ~2.1M (layers) + 524K (embed) ≈ 2.6M unique params
  Effective computation: ~3.1M param equivalent
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict
from components import TransformerBlock, RMSNorm


class NanoReasoningModel(nn.Module):
    """Compact reasoning transformer with entry / shared (looped) / exit layers.

    The shared middle layers can be run multiple passes per forward call
    ("thinking loops"), trading compute for effective depth without adding
    parameters. Input and output embeddings are optionally weight-tied.

    Expected ``config`` keys:
        Required: d_model, n_heads, n_layers, d_ff, vocab_size, max_seq_len.
        Optional (defaults): n_shared_layers (2), dropout (0.05),
        lora_rank (16), use_thinking_tokens (True), n_thinking_steps (2),
        n_kv_heads (n_heads // 2), weight_tying (True).
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        d_model = config['d_model']
        n_heads = config['n_heads']
        n_layers = config['n_layers']
        n_shared = config.get('n_shared_layers', 2)
        d_ff = config['d_ff']
        vocab_size = config['vocab_size']
        max_seq_len = config['max_seq_len']
        dropout = config.get('dropout', 0.05)
        rank = config.get('lora_rank', 16)
        self.use_thinking = config.get('use_thinking_tokens', True)
        self.n_thinking_steps = config.get('n_thinking_steps', 2)
        n_kv_heads = config.get('n_kv_heads', n_heads // 2)

        # Token embeddings (optionally tied with the output head below).
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.embedding_dropout = nn.Dropout(dropout)

        # Entry layers (unique weights): first half of the non-shared layers.
        n_unique = n_layers - n_shared
        self.entry_layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, rank, dropout, max_seq_len, n_kv_heads)
            for _ in range(n_unique // 2)
        ])

        # Shared layers: looped up to n_think_loops times in forward().
        self.shared_layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, rank, dropout, max_seq_len, n_kv_heads)
            for _ in range(n_shared)
        ])

        # Exit layers (unique weights): the remaining non-shared layers.
        self.exit_layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, rank, dropout, max_seq_len, n_kv_heads)
            for _ in range(n_unique - n_unique // 2)
        ])

        # Final pre-logits normalization.
        self.final_norm = RMSNorm(d_model)

        # Output head; tying shares the embedding matrix with the head,
        # halving embedding parameter count.
        self.output_head = nn.Linear(d_model, vocab_size, bias=False)

        if config.get('weight_tying', True):
            self.output_head.weight = self.token_embedding.weight

        # Thinking step gate: learned scalar for blending thinking iterations.
        if self.use_thinking:
            self.think_gate = nn.Parameter(torch.tensor(0.5))

        # Initialize weights (runs after tying; the tied matrix is simply
        # re-initialized with the same distribution, which is harmless).
        self.apply(self._init_weights)

        # Count and report parameters.
        self._count_parameters()

    def _init_weights(self, module: nn.Module):
        """Initialize Linear/Embedding weights N(0, 0.02) and zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _count_parameters(self):
        """Count parameters, store them on the instance, and print a summary.

        Note: nn.Module.parameters() already de-duplicates tied tensors, so
        a weight-tied embedding/output matrix is counted once.
        """
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)

        self.total_params = total
        self.trainable_params = trainable
        print(f"\n{'='*50}")
        print(f"NRM Model Configuration:")
        print(f"  d_model: {self.config['d_model']}")
        print(f"  n_heads: {self.config['n_heads']}")
        print(f"  n_layers: {self.config['n_layers']} "
              f"({len(self.entry_layers)} entry + {len(self.shared_layers)} shared + {len(self.exit_layers)} exit)")
        print(f"  d_ff: {self.config['d_ff']}")
        print(f"  vocab_size: {self.config['vocab_size']}")
        print(f"  LoRA rank: {self.config.get('lora_rank', 16)}")
        print(f"  Thinking: {'enabled' if self.use_thinking else 'disabled'}")
        print(f"  Total parameters: {total:,}")
        print(f"  Trainable parameters: {trainable:,}")
        print(f"{'='*50}\n")

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                n_think_loops: int = 1) -> Dict[str, torch.Tensor]:
        """Forward pass with optional thinking loops.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: optional (B, T) mask; positions equal to 0 are
                treated as padding.
            labels: optional (B, T) target ids; when given, a shifted
                cross-entropy loss is added to the result.
            n_think_loops: how many times to loop through the shared layers.
                Increase at inference to give the model more "thinking time".

        Returns:
            dict with "logits" (B, T, vocab) and, if labels were given, "loss".
        """
        B, T = input_ids.shape

        # Embeddings
        x = self.token_embedding(input_ids)
        x = self.embedding_dropout(x)

        # Padding mask: True where padded.
        pad_mask = None
        if attention_mask is not None:
            pad_mask = (attention_mask == 0)

        # Entry layers
        for layer in self.entry_layers:
            x = layer(x, pad_mask)

        # Shared layers with thinking loops
        actual_loops = max(1, n_think_loops)
        if self.use_thinking and actual_loops > 1:
            # Store the "pre-think" state so each loop can blend back toward it.
            x_original = x
            for loop in range(actual_loops):
                for layer in self.shared_layers:
                    x = layer(x, pad_mask)
                if loop < actual_loops - 1:
                    # Blend with original (residual thinking); gate is learned.
                    gate = torch.sigmoid(self.think_gate)
                    x = gate * x + (1 - gate) * x_original
        else:
            for layer in self.shared_layers:
                x = layer(x, pad_mask)

        # Exit layers
        for layer in self.exit_layers:
            x = layer(x, pad_mask)

        # Output
        x = self.final_norm(x)
        logits = self.output_head(x)

        result = {"logits": logits}

        if labels is not None:
            # Shift for autoregressive loss: predict token t+1 from prefix t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=0,  # assumes PAD token id is 0 — confirm with tokenizer
                label_smoothing=0.05  # slight smoothing for better generalization
            )
            result["loss"] = loss

        return result

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 0.7, top_k: int = 50, top_p: float = 0.9,
                 n_think_loops: int = 1, eos_token_id: int = 2) -> torch.Tensor:
        """Autoregressive generation with temperature, top-k, and top-p sampling.

        Uses nucleus (top-p) sampling for diverse but coherent generation.
        NOTE: this switches the model to eval() mode and does not restore
        the previous mode.
        """
        self.eval()
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            # Truncate context to the model's maximum sequence length.
            context = generated[:, -self.config['max_seq_len']:]

            outputs = self.forward(context, n_think_loops=n_think_loops)
            # Temperature-scale last-position logits; floor avoids div-by-zero.
            logits = outputs["logits"][:, -1, :] / max(temperature, 1e-5)

            # Top-k filtering: mask everything below the k-th largest logit.
            if top_k > 0:
                top_k_val = min(top_k, logits.size(-1))
                indices_to_remove = logits < torch.topk(logits, top_k_val)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering: drop the tail of the sorted
            # cumulative distribution, always keeping the top token.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)

            # Early stop on EOS. (The previous `.item()` call crashed for
            # batch size > 1; this form is identical for batch 1 and stops a
            # larger batch once every sequence emits EOS in the same step.)
            if (next_token == eos_token_id).all():
                break

        return generated

    def save(self, path: str):
        """Save the state dict (model.pt) and config (config.json) under path."""
        import os, json
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, "model.pt"))
        with open(os.path.join(path, "config.json"), 'w') as f:
            json.dump(self.config, f, indent=2)
        print(f"Model saved to {path}")

    @classmethod
    def load(cls, path: str, device: str = 'cpu') -> 'NanoReasoningModel':
        """Load a model previously written by save().

        SECURITY: torch.load deserializes via pickle — only load checkpoints
        from trusted sources (or pass weights_only=True on torch >= 1.13).
        """
        import os, json
        with open(os.path.join(path, "config.json"), 'r') as f:
            config = json.load(f)
        model = cls(config)
        model.load_state_dict(torch.load(os.path.join(path, "model.pt"),
                                          map_location=device))
        return model