theapemachine committed on
Commit
c7bee4d
·
verified ·
1 Parent(s): 82900ee

Add cortex/adaptive_depth.py

Browse files
Files changed (1) hide show
  1. cortex/adaptive_depth.py +142 -0
cortex/adaptive_depth.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdaptiveDepth: Dynamic layer skipping with learned gates.
3
+
4
+ Inspired by GateSkip (2025), Mixture of Depths (Raposo et al. 2024), and
5
+ Router-Tuning (2024).
6
+
7
+ Architecture:
8
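+ - Each transformer layer gets a lightweight learned gate g ∈ (0, 1); it is soft by default and can be binarized to a hard execute/skip decision with a straight-through estimator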
9
+ - The gate decides per-token whether to execute the layer or skip it
10
+ - Skip = identity (hidden states pass through unchanged)
11
+ - Execute = normal layer forward + gated residual
12
+ - Gates are trained to minimize computation while maintaining quality
13
+ - A budget constraint ensures the model uses a target % of layers per token
14
+
15
+ Failure mode addressed:
16
+ - Fixed compute: All tokens get the same computation depth regardless of difficulty.
17
+ "The" doesn't need 32 layers of processing, but a complex reasoning step might need all of them.
18
+ - Wasted compute: Many layers are near-identity for "easy" tokens.
19
+ - Latency: Dynamic depth enables significant speedup on average.
20
+ - Overthinking: Too many layers can sometimes HURT performance (representation collapse).
21
+ Adaptive depth protects against this.
22
+
23
+ Injection point: POST_FFN
24
+ - Rationale: The gate wraps the entire layer's contribution to the residual stream.
25
+ It decides: "Was this layer's update useful for this token?"
26
+ """
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ from typing import Optional, Union, List
32
+ from cortex.core import CortexModule, InjectionPoint
33
+
34
+
35
class AdaptiveDepth(CortexModule):
    """
    Token-wise layer gating for dynamic computation depth.

    A small MLP maps each token's hidden state to a scalar gate in (0, 1)
    that scales the wrapped layer's residual update:

        output = pre_layer + gate * (post_layer - pre_layer)

    ``pre_layer`` is the tensor most recently passed to ``store_input``.
    When no usable tensor has been stored, the layer input is unknown and the
    module cannot blend; it then leaves the forward value untouched and only
    scales the gradient flowing back through ``hidden_states`` by the gate.

    With ``gate_type="straight_through"`` the forward pass uses the hard
    decision ``gate > 0.5`` (in training and evaluation alike) while the
    gradient is taken from the soft sigmoid.

    A budget regularization term pulls the mean gate towards
    ``target_budget``; fetch it with ``get_budget_loss`` after each forward
    pass and add it to the training loss.

    Args:
        hidden_dim: Model hidden dimension
        target_budget: Target fraction of layers to use per token (0-1)
        gate_type: "sigmoid" (soft), "straight_through" (hard during forward, soft backward)
        temperature: Temperature for gating (lower = more binary); must be > 0
        budget_loss_weight: Weight for the budget regularization loss
        target_layers: Forwarded unchanged to the ``CortexModule`` constructor

    Raises:
        ValueError: If ``gate_type`` is not a supported mode or
            ``temperature`` is not strictly positive.
    """

    # Gate modes understood by forward(); anything else is rejected up front
    # so that a typo cannot silently fall back to soft gating.
    _GATE_TYPES = ("sigmoid", "straight_through")

    def __init__(
        self,
        hidden_dim: int,
        target_budget: float = 0.7,
        gate_type: str = "sigmoid",
        temperature: float = 1.0,
        budget_loss_weight: float = 0.01,
        target_layers: Union[List[int], str] = "all",
    ):
        if gate_type not in self._GATE_TYPES:
            raise ValueError(
                f"gate_type must be one of {self._GATE_TYPES}, got {gate_type!r}"
            )
        if temperature <= 0:
            # The gate logit is divided by the temperature: zero yields
            # inf/NaN gates and a negative value inverts every decision.
            raise ValueError(f"temperature must be > 0, got {temperature}")

        super().__init__(InjectionPoint.POST_FFN, target_layers)

        self.hidden_dim = hidden_dim
        self.target_budget = target_budget
        self.gate_type = gate_type
        self.temperature = temperature
        self.budget_loss_weight = budget_loss_weight

        # Gate network: maps hidden state to a scalar gate per token.
        # The bottleneck is clamped to one unit so that hidden_dim < 4 does
        # not produce zero-width linear layers.
        bottleneck_dim = max(1, hidden_dim // 4)
        self.gate_net = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck_dim),
            nn.GELU(),
            nn.Linear(bottleneck_dim, 1),
        )

        # Initialize gate to be "open" (execute layer) by default
        nn.init.constant_(self.gate_net[-1].bias, 2.0)  # sigmoid(2) ≈ 0.88

        # Non-persistent buffers: they follow the module across devices but
        # are never written to the state dict.
        self.register_buffer("_pre_layer_hidden", None, persistent=False)
        self.register_buffer("_gate_values", None, persistent=False)
        self.register_buffer("_budget_loss", torch.tensor(0.0), persistent=False)

    def store_input(self, hidden_states: torch.Tensor):
        """Store the input to the layer (called via pre-hook).

        The tensor is detached, so the skip branch of the next ``forward``
        call contributes to the output value but carries no gradient.
        """
        self._pre_layer_hidden = hidden_states.detach()

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_idx: int,
        **kwargs
    ) -> torch.Tensor:
        """
        Gate the layer's residual contribution.

            output = pre_layer + gate * (post_layer - pre_layer)

        When gate = 1: post_layer (use full layer output)
        When gate = 0: pre_layer  (skip layer entirely)

        Args:
            hidden_states: Output of the wrapped layer; the gate network is
                applied over its last dimension.
            layer_idx: Index of the layer being gated (accepted for the
                module interface; not used by the computation).
            **kwargs: Ignored.

        Returns:
            The gated hidden states, same shape as ``hidden_states``.
        """
        # Compute gate value per token
        gate_logit = self.gate_net(hidden_states) / self.temperature
        gate = torch.sigmoid(gate_logit)

        if self.gate_type == "straight_through":
            # Straight-through estimator: the value is the hard 0/1 decision,
            # the gradient is the sigmoid's. Casting to gate.dtype (rather
            # than float32) keeps half-precision activations from being
            # promoted.
            hard_gate = (gate > 0.5).to(gate.dtype)
            gate = hard_gate - gate.detach() + gate

        self._gate_values = gate.detach()

        # Consume the stored layer input exactly once so that a stale tensor
        # from an earlier layer or batch can never be blended in later.
        pre_layer = self._pre_layer_hidden
        self._pre_layer_hidden = None

        if pre_layer is not None and pre_layer.shape == hidden_states.shape:
            pre_layer = pre_layer.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            gated_output = pre_layer + gate * (hidden_states - pre_layer)
        else:
            # Layer input unknown: the value equals hidden_states, only the
            # gradient through hidden_states is scaled by the gate.
            gated_output = gate * hidden_states + (1 - gate) * hidden_states.detach()

        # Budget regularization loss. Kept attached to the autograd graph:
        # a detached value would add nothing to the gradient when summed
        # into the training loss.
        avg_gate = gate.mean()
        self._budget_loss = self.budget_loss_weight * (avg_gate - self.target_budget).pow(2)

        return gated_output

    def get_gate_stats(self) -> dict:
        """Return statistics about gate usage from the latest forward pass.

        Returns:
            A dict with the keys ``"mean"``, ``"std"`` and ``"skip_frac"``
            (fraction of gate values below 0.5). All three are 0.0 when no
            gate values are available.
        """
        gates = self._gate_values
        if gates is None or gates.numel() == 0:
            return {"mean": 0.0, "std": 0.0, "skip_frac": 0.0}

        gates = gates.float()
        # The sample standard deviation is undefined (NaN) for one element.
        std = gates.std().item() if gates.numel() > 1 else 0.0
        return {
            "mean": gates.mean().item(),
            "std": std,
            "skip_frac": (gates < 0.5).float().mean().item(),
        }

    def get_budget_loss(self) -> torch.Tensor:
        """Return the budget regularization loss (add to main loss).

        The tensor comes from the most recent ``forward`` call and is
        differentiable with respect to the gate network whenever that call
        ran with gradients enabled. Before the first call it is a zero scalar.
        """
        return self._budget_loss

    def extra_repr(self):
        return (f"hidden_dim={self.hidden_dim}, target_budget={self.target_budget}, "
                f"gate_type={self.gate_type}, {super().extra_repr()}")