# LLaDA-8B-Recursive-ARC / modeling_recursive.py
# Uploaded by Fraser using huggingface_hub (commit f98df9d, verified).
from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import NamedTuple, Optional
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from transformers import AutoConfig, AutoModelForMaskedLM, PreTrainedModel
from transformers.modeling_outputs import MaskedLMOutput
from transformers.utils import ModelOutput
from .configuration_recursive import RecursiveMLMConfig
@dataclass
class IterationMetrics(ModelOutput):
    """Metrics for a single iteration of recursive refinement.

    Every field is optional and defaults to ``None``. The code that fills these
    values in is not part of this section; the per-field notes below are read
    off the field names and should be confirmed against that producer.
    """

    # presumably token accuracy at masked positions — TODO confirm
    accuracy: Optional[float] = None
    # presumably mean entropy of the predicted distribution — TODO confirm
    entropy: Optional[float] = None
    # presumably cross-entropy computed from softmax probabilities — TODO confirm
    softmax_ce: Optional[float] = None
    # presumably fraction of sequences predicted entirely correctly — TODO confirm
    full_sequence_accuracy: Optional[float] = None
    # presumably the lowest per-sequence confidence in the batch — TODO confirm
    min_sequence_confidence: Optional[float] = None
@dataclass
class RecursiveMaskedLMOutput(MaskedLMOutput):
    """Recursive-MLM output: the standard ``MaskedLMOutput`` fields plus recursion extras.

    All additional fields default to ``None``.
    """

    # Maps iteration index to that iteration's metrics.
    iteration_metrics: Optional[dict[int, IterationMetrics]] = None
    # Soft embeddings for the next iteration, kept for caching between training steps.
    next_soft_embeds: Optional[torch.Tensor] = None
    # All T iterations' logits, for the trainer's loss computation.
    all_logits: Optional[list[torch.Tensor]] = None
    # Flow matching state (for distillation — compact H-dim, not V-dim). These are the
    # values compute_flow_matching_distillation_loss takes as flow_noise_embed / flow_t.
    flow_noise_embed: Optional[torch.Tensor] = None  # (num_masked, H) noise embedding
    flow_t: Optional[torch.Tensor] = None  # (num_masked,) per-token time levels
class SelfDistillationOutput(NamedTuple):
    """Output from a self-distillation forward pass.

    Returned by both ``compute_self_distillation_loss`` (temperature degradation)
    and ``compute_flow_matching_distillation_loss``. Field order is part of the
    interface: the no-masked-token paths construct this tuple positionally.
    """

    loss: torch.Tensor  # KL divergence loss (scalar, has grad)
    teacher_logits: torch.Tensor  # For metrics/debugging (detached)
    student_logits: torch.Tensor  # For metrics/debugging (has grad)
    # Temperature path: mean per-token degradation temperature sampled.
    # Flow-matching path: mean time gap (t - s) between teacher and student.
    degradation_temperature: float
    teacher_entropy: float  # Mean entropy of teacher distribution (for monitoring)
    student_entropy: float  # Mean entropy of student distribution (for monitoring)
    agreement_rate: float  # Fraction where teacher and student argmax agree
class RecursiveMaskedLM(PreTrainedModel):
"""
Wraps any HF MLM with recursive soft-token refinement.
At each step:
1. Normalize logits -> probs
2. Compute soft embeddings: probs @ embedding_weight + mask_embedding
3. Forward through MLM
4. Accumulate weighted loss
"""
config_class = RecursiveMLMConfig
base_model_prefix = "mlm"
supports_gradient_checkpointing = True
def __init__(self, config: RecursiveMLMConfig, base_model: Optional[PreTrainedModel] = None):
    """Construct the recursive wrapper.

    Args:
        config: Recursive-MLM configuration. When ``base_model`` is omitted,
            ``config.base_model_config`` (a dict) is used to build a fresh base model.
        base_model: Optional already-built (typically pre-trained) MLM that is
            wrapped as-is.

    Raises:
        ValueError: if neither ``base_model`` nor ``config.base_model_config`` is given.
    """
    super().__init__(config)

    if base_model is None and config.base_model_config is None:
        raise ValueError("Need either base_model or config.base_model_config")

    if base_model is not None:
        # Wrap the supplied model directly. post_init() is deliberately skipped so
        # _init_weights() cannot overwrite the pre-trained parameters.
        self.mlm = base_model
        return

    base_cfg = config.base_model_config
    if base_cfg.get("model_type", "") == "llada":
        # LLaDA is not loadable through the Auto* factories; build it from the
        # bundled configuration/modeling modules instead.
        from .configuration_llada import LLaDAConfig
        from .modeling_llada import LLaDAModelLM

        self.mlm = LLaDAModelLM(LLaDAConfig.from_dict(base_cfg))
    else:
        self.mlm = AutoModelForMaskedLM.from_config(AutoConfig.for_model(**base_cfg))

    # Freshly constructed weights need the standard HF initialisation.
    self.post_init()
@classmethod
def from_mlm_pretrained(
    cls,
    mlm_name_or_path: str,
    num_recursions: int = 8,
    normalization: str = "softmax",
    loss_weight: str = "linear",
    mask_token_id: Optional[int] = None,
    temperature: float = 1.0,
    gradient_steps: Optional[int] = None,
    # === Convergence schedule parameters ===
    schedule: str = "linear",
    causal_strength: float = 1.0,
    # === Effect parameters ===
    temperature_max: float = 0.0,
    entropy_target_max: float = 0.0,
    entropy_floor_max: float = 0.0,
    smear_sigma_max: float = 0.0,
    noise_std_max: float = 0.0,
    iteration_rope_dim_fraction: float = 0.0,
    use_recursion_checkpointing: bool = True,
    # === Soft embedding method ===
    soft_embedding_method: str = "softmax",
    soft_embedding_ema_step: float = 1.0,
    # === Flow matching parameters ===
    flow_matching_enabled: bool = False,
    flow_matching_lambda: float = 0.5,
    flow_matching_t_distribution: str = "logit_normal",
    flow_matching_t_logit_mean: float = -0.4,
    flow_matching_t_logit_std: float = 1.0,
    flow_matching_t_min: float = 0.01,
    flow_matching_t_max: float = 0.99,
    flow_matching_mask_scale: bool = False,
    **model_kwargs,
) -> "RecursiveMaskedLM":
    """Load a pretrained MLM with ``AutoModelForMaskedLM`` and wrap it.

    ``model_kwargs`` are handed to ``AutoModelForMaskedLM.from_pretrained``.
    Every other argument is forwarded unchanged to ``from_base_model``, which
    documents each option. Models that ``AutoModelForMaskedLM`` cannot load
    (e.g. LLaDA) should be built separately and passed to ``from_base_model``.
    """
    recursion_options = dict(
        num_recursions=num_recursions,
        normalization=normalization,
        loss_weight=loss_weight,
        mask_token_id=mask_token_id,
        temperature=temperature,
        gradient_steps=gradient_steps,
        schedule=schedule,
        causal_strength=causal_strength,
        temperature_max=temperature_max,
        entropy_target_max=entropy_target_max,
        entropy_floor_max=entropy_floor_max,
        smear_sigma_max=smear_sigma_max,
        noise_std_max=noise_std_max,
        iteration_rope_dim_fraction=iteration_rope_dim_fraction,
        use_recursion_checkpointing=use_recursion_checkpointing,
        soft_embedding_method=soft_embedding_method,
        soft_embedding_ema_step=soft_embedding_ema_step,
        flow_matching_enabled=flow_matching_enabled,
        flow_matching_lambda=flow_matching_lambda,
        flow_matching_t_distribution=flow_matching_t_distribution,
        flow_matching_t_logit_mean=flow_matching_t_logit_mean,
        flow_matching_t_logit_std=flow_matching_t_logit_std,
        flow_matching_t_min=flow_matching_t_min,
        flow_matching_t_max=flow_matching_t_max,
        flow_matching_mask_scale=flow_matching_mask_scale,
    )
    pretrained = AutoModelForMaskedLM.from_pretrained(mlm_name_or_path, **model_kwargs)
    return cls.from_base_model(pretrained, **recursion_options)
@classmethod
def from_base_model(
    cls,
    base_model: PreTrainedModel,
    num_recursions: int = 8,
    normalization: str = "softmax",
    loss_weight: str = "linear",
    mask_token_id: Optional[int] = None,
    temperature: float = 1.0,
    gradient_steps: Optional[int] = None,
    # === Convergence schedule parameters ===
    schedule: str = "linear",
    causal_strength: float = 1.0,
    # === Effect parameters ===
    temperature_max: float = 0.0,
    entropy_target_max: float = 0.0,
    entropy_floor_max: float = 0.0,
    smear_sigma_max: float = 0.0,
    noise_std_max: float = 0.0,
    iteration_rope_dim_fraction: float = 0.0,
    use_recursion_checkpointing: bool = True,
    # === Soft embedding method ===
    soft_embedding_method: str = "softmax",
    soft_embedding_ema_step: float = 1.0,
    # === Flow matching parameters ===
    flow_matching_enabled: bool = False,
    flow_matching_lambda: float = 0.5,
    flow_matching_t_distribution: str = "logit_normal",
    flow_matching_t_logit_mean: float = -0.4,
    flow_matching_t_logit_std: float = 1.0,
    flow_matching_t_min: float = 0.01,
    flow_matching_t_max: float = 0.99,
    flow_matching_mask_scale: bool = False,
) -> "RecursiveMaskedLM":
    """Wrap an already-instantiated model for recursive refinement.

    Prefer this over ``from_mlm_pretrained`` for models that
    ``AutoModelForMaskedLM`` cannot load (e.g. LLaDA). Every keyword argument is
    copied into a ``RecursiveMLMConfig`` derived from ``base_model.config``.

    Args:
        base_model: The base MLM to wrap; its weights are used as-is.
        num_recursions: Number of recursive refinement steps.
        normalization: Logit normalization ("softmax", "stable_softmax" or "none").
        loss_weight: Per-step loss weighting ("last_1", "last_2", "linear", "uniform").
        mask_token_id: Token id of [MASK].
        temperature: Temperature for softmax normalization.
        gradient_steps: Number of final steps to backprop through.
        schedule: Convergence schedule type ("linear" or "causal").
        causal_strength: How much faster early positions converge (causal only).
        temperature_max: Max temperature boost for not-yet-converged positions.
        entropy_target_max: Target entropy at progress 0 (two-sided, recommended).
        entropy_floor_max: Minimum-entropy floor at progress 0 (one-sided).
        smear_sigma_max: Max Gaussian sigma for positional smearing.
        noise_std_max: Max std of Gaussian noise added to logits.
        iteration_rope_dim_fraction: Fraction of hidden dims used by iteration RoPE.
        use_recursion_checkpointing: Enable gradient checkpointing for iterations.
        soft_embedding_method: Logits -> mixing weights ("softmax", "l2_normalize", "none").
        soft_embedding_ema_step: EMA step (1.0 = no EMA, < 1.0 blends with the previous soft embeddings).
        flow_matching_enabled: Enable the CFM-inspired flow matching framework.
        flow_matching_lambda: Weight of the distillation KL loss relative to CE.
        flow_matching_t_distribution: Time sampling distribution ("logit_normal" or "uniform").
        flow_matching_t_logit_mean: Mean of the logit-normal distribution.
        flow_matching_t_logit_std: Std of the logit-normal distribution.
        flow_matching_t_min: Lower clamp for sampled times.
        flow_matching_t_max: Upper clamp for sampled times.
        flow_matching_mask_scale: Scale mask_emb by (1 - t) if True, add it unscaled if False.

    Returns:
        A ``RecursiveMaskedLM`` wrapping ``base_model``.
    """
    options = dict(
        num_recursions=num_recursions,
        normalization=normalization,
        loss_weight=loss_weight,
        mask_token_id=mask_token_id,
        temperature=temperature,
        gradient_steps=gradient_steps,
        schedule=schedule,
        causal_strength=causal_strength,
        temperature_max=temperature_max,
        entropy_target_max=entropy_target_max,
        entropy_floor_max=entropy_floor_max,
        smear_sigma_max=smear_sigma_max,
        noise_std_max=noise_std_max,
        iteration_rope_dim_fraction=iteration_rope_dim_fraction,
        use_recursion_checkpointing=use_recursion_checkpointing,
        soft_embedding_method=soft_embedding_method,
        soft_embedding_ema_step=soft_embedding_ema_step,
        flow_matching_enabled=flow_matching_enabled,
        flow_matching_lambda=flow_matching_lambda,
        flow_matching_t_distribution=flow_matching_t_distribution,
        flow_matching_t_logit_mean=flow_matching_t_logit_mean,
        flow_matching_t_logit_std=flow_matching_t_logit_std,
        flow_matching_t_min=flow_matching_t_min,
        flow_matching_t_max=flow_matching_t_max,
        flow_matching_mask_scale=flow_matching_mask_scale,
    )
    recursive_config = RecursiveMLMConfig.from_base_model_config(base_model.config, **options)
    return cls(recursive_config, base_model=base_model)
@property
def embed_weight(self) -> torch.Tensor:
    """Weight matrix of the wrapped MLM's input embedding layer (rows indexed by token id)."""
    input_embeddings = self.mlm.get_input_embeddings()
    return input_embeddings.weight
def get_input_embeddings(self):
    """Delegate to the wrapped MLM and return its input embedding module."""
    wrapped = self.mlm
    return wrapped.get_input_embeddings()
def set_input_embeddings(self, value):
    """Delegate to the wrapped MLM, replacing its input embedding module with ``value``."""
    wrapped = self.mlm
    wrapped.set_input_embeddings(value)
def get_output_embeddings(self):
    """Delegate to the wrapped MLM and return its output embedding module."""
    wrapped = self.mlm
    return wrapped.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
    """Delegate to the wrapped MLM, replacing its output embedding module."""
    wrapped = self.mlm
    wrapped.set_output_embeddings(new_embeddings)
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    """Enable gradient checkpointing in the wrapped MLM with recursion-safe defaults.

    ``use_reentrant`` defaults to False, which is required for:
    - Nested checkpoint calls (base model + recursion checkpointing)
    - Models with frozen parameters
    - Complex gradient flows through soft embeddings
    An explicit ``use_reentrant=True`` from the caller is respected but warned about.

    Args:
        gradient_checkpointing_kwargs: Optional dict forwarded to the wrapped
            model's ``gradient_checkpointing_enable``. It is copied, never mutated.
    """
    # Copy so a dict owned (and possibly reused) by the caller is left untouched.
    kwargs = dict(gradient_checkpointing_kwargs or {})
    kwargs.setdefault("use_reentrant", False)
    if kwargs["use_reentrant"]:
        warnings.warn(
            "use_reentrant=True was requested for gradient checkpointing; "
            "nested recursion checkpointing requires use_reentrant=False.",
            stacklevel=2,
        )
    self.mlm.gradient_checkpointing_enable(kwargs)
def gradient_checkpointing_disable(self):
    """Turn gradient checkpointing off by delegating to the wrapped MLM."""
    wrapped = self.mlm
    wrapped.gradient_checkpointing_disable()
def _single_iteration_checkpointable(
    self,
    soft_embeds: torch.Tensor,
    base_embeds: torch.Tensor,
    mask_pos: torch.Tensor,
    attention_mask: torch.Tensor,
    embed_weight: torch.Tensor,
    mask_emb: torch.Tensor,
    temperature: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Single differentiable iteration for checkpointing.

    Performs one iteration of recursive refinement while keeping gradient flow
    intact, so it can be wrapped by torch.utils.checkpoint.

    Args:
        soft_embeds: (B, L, H) - current soft embeddings
        base_embeds: (B, L, H) - original token embeddings
        mask_pos: (B, L) bool - which positions are masked
        attention_mask: (B, L) - attention mask for MLM
        embed_weight: (V, H) - embedding weight matrix
        mask_emb: (H,) - mask token embedding
        temperature: scalar tensor - softmax temperature
        position_ids: optional position ids forwarded to the MLM
    Returns:
        logits: (B, L, V) - output logits from this iteration
        next_soft_embeds: (B, L, H) - soft embeddings for next iteration; equal to
            base_embeds at unmasked positions, in base_embeds' dtype
    """
    # Blend: use soft_embeds at masked positions, base_embeds elsewhere
    inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)
    outputs = self.mlm(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
    )
    logits = outputs.logits
    # Compute soft embeddings for next iteration (DIFFERENTIABLE - no detach!)
    next_soft_embeds = base_embeds.clone()
    if mask_pos.any():
        masked_logits = logits[mask_pos]  # (num_masked, V)
        method = self.config.soft_embedding_method
        if method == "none":
            # No normalization - raw logits are the mixing weights
            weights = masked_logits
        elif method == "l2_normalize":
            # L2-normalized logits: avoids the softmax bottleneck for smoother gradients
            weights = F.normalize(masked_logits, p=2, dim=-1)
        else:
            # Default (any other value): temperature softmax
            weights = F.softmax(masked_logits / temperature, dim=-1)
        # Logits may come out in a wider dtype (e.g. fp32) than the embedding
        # table (e.g. bf16); matmul requires both operands to share a dtype.
        soft_emb = weights.to(embed_weight.dtype) @ embed_weight + mask_emb
        # Optional EMA blending with the previous iteration's soft embeddings
        ema_step = self.config.soft_embedding_ema_step
        if ema_step < 1.0:
            prev_soft_emb = soft_embeds[mask_pos]
            soft_emb = (1.0 - ema_step) * prev_soft_emb + ema_step * soft_emb
        # Masked index assignment requires matching source/destination dtypes.
        next_soft_embeds[mask_pos] = soft_emb.to(next_soft_embeds.dtype)
    return logits, next_soft_embeds
def _stable_softmax(self, logits: torch.Tensor, T: float = 1.0, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    """Temperature-scaled softmax using the max-subtraction trick.

    ``T`` is floored at ``eps`` to avoid dividing by zero, and the normaliser is
    clamped at ``eps`` as a last guard.
    """
    scaled = logits / max(T, eps)
    row_max = scaled.max(dim=dim, keepdim=True).values
    # After the shift every exponent is <= 0, so exp() cannot overflow.
    unnormalised = (scaled - row_max).exp()
    normaliser = unnormalised.sum(dim=dim, keepdim=True).clamp(min=eps)
    return unnormalised / normaliser
def normalize(self, logits: torch.Tensor) -> torch.Tensor:
    """Normalize logits into mixing weights. Shape: (B, L, V) -> (B, L, V).

    Uses ``config.normalization`` (case-insensitive) and ``config.temperature`` T:
      - "none": logits are returned unchanged
      - "softmax": softmax(logits / T) over the last dim
      - "stable_softmax": max-subtracted softmax with T floored at a small eps

    Raises:
        ValueError: for any other normalization name.
    """
    norm = self.config.normalization.lower()
    T = self.config.temperature
    if norm == "none":
        return logits
    if norm == "softmax":
        return torch.softmax(logits / T, dim=-1)
    if norm == "stable_softmax":
        return self._stable_softmax(logits, T=T, dim=-1)
    raise ValueError(f"Unknown normalization: {norm}")
def step_weight(self, t: int, T: int) -> float:
    """Return the loss weight for refinement step ``t`` (0-indexed) out of ``T``.

    Schemes selected by ``config.loss_weight``:
      - "linear":  (t + 1) / T, reaching 1.0 at the final step
      - "uniform": 1.0 for every step
      - "last_1":  only the final step is weighted
      - "last_2":  only the final two steps are weighted

    Raises:
        ValueError: for an unrecognised scheme.
    """
    scheme = self.config.loss_weight
    weight_fns = {
        "linear": lambda: (t + 1) / T,
        "uniform": lambda: 1.0,
        "last_1": lambda: 1.0 if t == T - 1 else 0.0,
        "last_2": lambda: 1.0 if T - t <= 2 else 0.0,
    }
    weight_fn = weight_fns.get(scheme)
    if weight_fn is None:
        raise ValueError(f"Unknown loss_weight: {scheme}")
    return weight_fn()
# ==================== CONVERGENCE SCHEDULE SYSTEM ====================
#
# The core idea: control WHEN each position is allowed to converge.
#
# Schedule types:
# - "linear": All positions converge at the same rate
# - "causal": Early positions converge first, late positions last
#
# Effects (mechanisms to enforce the schedule):
# - temperature: Raise temperature for positions not yet allowed to converge
# - entropy_floor: Force minimum entropy
# - entropy_target: Force exact entropy via bisection (ARChitects-style)
# - smear: Spread probability across neighboring positions
# - noise: Add Gaussian noise to logits
#
# Each effect uses per-position "convergence progress" (0=uncertain, 1=can converge)
def _compute_convergence_progress(
    self,
    iteration: int,
    total_iterations: int,
    seq_length: int,
    mask_positions: torch.Tensor,
    schedule: str = "linear",
    causal_strength: float = 1.0,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Compute per-position convergence progress based on schedule.

    Args:
        iteration: Current iteration (0-indexed)
        total_iterations: Total number of iterations
        seq_length: Full sequence length L
        mask_positions: Position indices of masked tokens (num_masked,)
        schedule: "linear" (same progress everywhere) or "causal"
            (earlier positions progress faster, clamped to [0, 1])
        causal_strength: How much faster early positions converge (causal only)
        device: Device of the result (None keeps the default / mask_positions' device)
        dtype: Dtype of the result (None keeps the default float dtype)
    Returns:
        progress: (num_masked,) tensor with values in [0, 1]
            0 = position should be maximally uncertain
            1 = position is allowed to fully converge
    Raises:
        ValueError: for an unknown schedule.
    """
    base_progress = iteration / max(total_iterations - 1, 1)
    if schedule == "linear":
        return torch.full(
            (mask_positions.shape[0],),
            base_progress,
            device=device,
            dtype=dtype,
        )
    if schedule == "causal":
        # Work in float32 so position ratios are exact, then honour the requested
        # device/dtype just like the linear branch does.
        position_factor = mask_positions.float() / max(seq_length - 1, 1)
        effective_progress = base_progress * (1.0 + causal_strength * (1.0 - position_factor))
        return effective_progress.clamp(0.0, 1.0).to(device=device, dtype=dtype)
    raise ValueError(f"Unknown schedule: {schedule}")
def _apply_temperature_effect(
    self,
    logits: torch.Tensor,
    progress: torch.Tensor,
    temperature_max: float,
) -> torch.Tensor:
    """
    Divide logits by a per-position temperature that decays with progress.

    temperature = 1 + temperature_max * (1 - progress): progress 0 gives the
    hottest (most uniform) distribution, progress 1 leaves logits unscaled.
    A non-positive ``temperature_max`` disables the effect.
    """
    if temperature_max <= 0:
        return logits
    per_position_temp = 1.0 + temperature_max * (1.0 - progress)
    return logits / per_position_temp[..., None]
def _apply_entropy_floor_effect(
    self,
    probs: torch.Tensor,
    progress: torch.Tensor,
    entropy_floor_max: float,
) -> torch.Tensor:
    """
    Soften positions whose entropy is below a progress-dependent floor.

    floor = entropy_floor_max * (1 - progress): low progress means a high floor,
    progress 1 means no floor. Positions below their floor are re-tempered by
    applying softmax to their log-probs with temperature
    clamp(floor / current_entropy, 1, 10). This is a heuristic and does not hit
    the floor exactly.
    NOTE: This is a ONE-SIDED constraint (floor only).

    Returns:
        Probabilities in the dtype of ``probs``; ``probs`` itself when nothing changes.
    """
    if entropy_floor_max <= 0:
        return probs
    entropy_floor = entropy_floor_max * (1.0 - progress)
    # log(p + eps) serves both the entropy and the "logits" re-tempered below.
    log_probs = torch.log(probs + 1e-10)
    current_entropy = -(probs * log_probs).sum(dim=-1)
    below_floor = current_entropy < entropy_floor
    if not below_floor.any():
        return probs
    ratio = (entropy_floor / (current_entropy + 1e-10)).clamp(1.0, 10.0)
    temperature = torch.where(below_floor, ratio, torch.ones_like(ratio))
    scaled_probs = torch.softmax(log_probs / temperature.unsqueeze(-1), dim=-1)
    # torch.where (rather than masked index assignment) works even when progress
    # and probs differ in dtype; cast back so the output matches probs.
    return torch.where(below_floor.unsqueeze(-1), scaled_probs.to(probs.dtype), probs)
def _find_temperature_for_target_entropy(
    self,
    logits: torch.Tensor,
    target_entropy: torch.Tensor,
    tol: float = 1e-3,
    max_iter: int = 32,
    T_low: float = 1e-6,
    T_high_init: float = 1.0,
    max_T: float = 100.0,
) -> torch.Tensor:
    """
    Find per-position temperatures whose softmax entropy matches a target.

    Uses bisection search, adapted from ARChitects' implementation. The search
    relies on entropy(softmax(logits / T)) increasing with T: the upper bound is
    doubled (capped at ``max_T``) until it brackets the target, then the bracket
    is bisected until hi - lo <= tol * max(mid, 1).

    Args:
        logits: Raw logits (num_positions, V)
        target_entropy: Target entropy per position (num_positions,) or scalar;
            clamped to [0, log V]
        tol: Entropy tolerance for convergence
        max_iter: Maximum bisection iterations
        T_low: Minimum temperature (near-greedy)
        T_high_init: Initial upper bound for search
        max_T: Maximum allowed temperature
    Returns:
        temperatures: (num_positions,) in the dtype/device of ``logits``:
            - ``T_low`` where the target is already met at ``T_low``;
            - the largest temperature searched (normally ``max_T``) where even
              that falls short of the target, i.e. the closest achievable entropy;
            - otherwise the midpoint of the final bisection bracket.
    """
    N, V = logits.shape
    device, dtype = logits.device, logits.dtype
    H_max = torch.log(torch.tensor(V, device=device, dtype=dtype))
    if target_entropy.dim() == 0:
        target = target_entropy.expand(N).clone()
    else:
        target = target_entropy.clone()
    # Entropy over V outcomes lies in [0, log V].
    target = target.clamp(0.0, H_max)

    def compute_entropy(logits_: torch.Tensor, temps: torch.Tensor) -> torch.Tensor:
        temps = temps.unsqueeze(-1).clamp(min=T_low)
        scaled = logits_ / temps
        scaled = scaled - scaled.max(dim=-1, keepdim=True).values
        probs = torch.softmax(scaled, dim=-1)
        log_probs = torch.log(probs + 1e-12)
        return -(probs * log_probs).sum(dim=-1)

    lo = torch.full((N,), T_low, device=device, dtype=dtype)
    hi = torch.full((N,), T_high_init, device=device, dtype=dtype)
    H_lo = compute_entropy(logits, lo)
    # Target already met at the near-greedy temperature.
    done_low = target <= (H_lo + tol)
    H_hi = compute_entropy(logits, hi)
    # Grow the upper bound (doubling, capped at max_T) until it brackets the target.
    needs_expansion = (H_hi < target - tol) & ~done_low
    for _ in range(100):
        if not needs_expansion.any():
            break
        hi[needs_expansion] = (hi[needs_expansion] * 2.0).clamp(max=max_T)
        H_hi[needs_expansion] = compute_entropy(
            logits[needs_expansion], hi[needs_expansion]
        )
        needs_expansion = (H_hi < target - tol) & ~done_low & (hi < max_T - 1e-6)
    can_bisect = ~done_low & (H_hi >= target - tol)
    # Targets the largest searched temperature still cannot reach: they are never
    # bracketed, so they must fall back to that temperature after bisection.
    saturated = ~done_low & ~can_bisect
    for _ in range(max_iter):
        if not can_bisect.any():
            break
        mid = (lo + hi) / 2.0
        H_mid = compute_entropy(logits, mid)
        too_low = (H_mid < target) & can_bisect
        lo[too_low] = mid[too_low]
        hi[~too_low & can_bisect] = mid[~too_low & can_bisect]
        converged = (hi - lo) <= tol * mid.clamp(min=1.0)
        can_bisect = can_bisect & ~converged
    temps = torch.zeros(N, device=device, dtype=dtype)
    temps[done_low] = T_low
    temps[~done_low] = (lo[~done_low] + hi[~done_low]) / 2.0
    # Without this, saturated positions would get (T_low + hi) / 2 — roughly half
    # the largest temperature tried — needlessly undershooting the target.
    temps[saturated] = hi[saturated]
    return temps
def _apply_target_entropy_effect(
    self,
    logits: torch.Tensor,
    progress: torch.Tensor,
    entropy_target_max: float,
    entropy_target_min: float = 0.0,
) -> torch.Tensor:
    """
    Re-temper each position so its softmax entropy matches a progress-dependent target.

    TWO-SIDED: the chosen temperature may raise or lower entropy. The target is
    interpolated linearly from ``entropy_target_max`` (progress 0) to
    ``entropy_target_min`` (progress 1). A non-positive ``entropy_target_max``
    disables the effect and returns a plain softmax.

    Args:
        logits: Raw logits (num_masked, V)
        progress: Per-position convergence progress (num_masked,)
        entropy_target_max: Target entropy at progress=0
        entropy_target_min: Target entropy at progress=1 (usually ~0)
    Returns:
        probs: Probabilities with entropy matching the targets
    """
    if entropy_target_max <= 0:
        return torch.softmax(logits, dim=-1)
    target_entropy = entropy_target_max * (1.0 - progress) + entropy_target_min * progress
    per_position_T = self._find_temperature_for_target_entropy(logits, target_entropy)
    # Guard against a zero temperature before dividing.
    divisor = per_position_T.clamp(min=1e-6)[:, None]
    return torch.softmax(logits / divisor, dim=-1)
def _apply_smear_effect(
    self,
    probs: torch.Tensor,
    mask_pos: torch.Tensor,
    progress_full: torch.Tensor,
    smear_sigma_max: float,
) -> torch.Tensor:
    """
    Blend masked positions' distributions with a Gaussian-weighted average of
    their neighbours' distributions.

    One kernel width is shared by the whole batch: the mean of
    ``smear_sigma_max * (1 - progress)`` over all masked positions. Per-position
    progress then sets how much smearing is mixed in (progress 0 -> fully
    smeared, progress 1 -> unchanged). Operates on the full (B, L, V) tensor
    because smearing mixes across positions; only masked positions change.

    Args:
        probs: (B, L, V) probabilities.
        mask_pos: (B, L) bool mask of positions to modify.
        progress_full: per-position progress, assumed (B, L) — TODO confirm with caller.
        smear_sigma_max: sigma at progress 0; <= 0 disables the effect.
    Returns:
        (B, L, V) tensor in the dtype of ``probs``; ``probs`` itself on a no-op.
    """
    if smear_sigma_max <= 0:
        return probs
    # Nothing to modify; also avoids mean() over an empty selection (NaN sigma).
    if not mask_pos.any():
        return probs
    seq_len = probs.shape[1]
    sigma_per_pos = smear_sigma_max * (1.0 - progress_full)
    avg_sigma = sigma_per_pos[mask_pos].mean().item()
    if avg_sigma < 0.1:
        # A kernel this narrow is effectively the identity.
        return probs
    # Build the kernel in at least fp32: bf16 cannot represent integer positions
    # exactly beyond 256, which would distort distances on long sequences.
    kernel_dtype = torch.promote_types(probs.dtype, torch.float32)
    positions = torch.arange(seq_len, device=probs.device, dtype=kernel_dtype)
    diff = positions.unsqueeze(0) - positions.unsqueeze(1)  # diff[i, j] = j - i
    kernel = torch.exp(-0.5 * (diff / avg_sigma) ** 2)
    kernel = (kernel / kernel.sum(dim=1, keepdim=True)).to(probs.dtype)
    smeared = torch.einsum("ij,bjv->biv", kernel, probs)
    smeared = smeared / smeared.sum(dim=-1, keepdim=True).clamp(min=1e-10)
    blend = progress_full.unsqueeze(-1)
    # Cast so the masked assignment below sees matching dtypes.
    result = (blend * probs + (1 - blend) * smeared).to(probs.dtype)
    output = probs.clone()
    output[mask_pos] = result[mask_pos]
    return output
def _apply_noise_effect(
    self,
    logits: torch.Tensor,
    progress: torch.Tensor,
    noise_std_max: float,
) -> torch.Tensor:
    """
    Perturb logits with Gaussian noise whose std shrinks as progress grows.

    std = noise_std_max * (1 - progress) per position; progress 1 adds no noise.
    A non-positive ``noise_std_max`` disables the effect.
    """
    if noise_std_max <= 0:
        return logits
    per_position_std = (noise_std_max * (1.0 - progress))[..., None]
    return logits + torch.randn_like(logits) * per_position_std
def _apply_iteration_rope(
    self,
    embeds: torch.Tensor,
    iteration: int,
    total_iterations: int,
    dim_fraction: float = 0.25,
    base: float = 10000.0,
) -> torch.Tensor:
    """
    Rotate the trailing hidden dims of ``embeds`` by an angle tied to iteration progress.

    Only the last ``rot_dim`` dims (``dim_fraction`` of the hidden size, rounded
    down to an even count) are rotated, in (even, odd) pairs, to avoid
    interfering with positional RoPE. Returns ``embeds`` unchanged when fewer
    than two dims would be rotated; otherwise a rotated copy.
    """
    if dim_fraction <= 0:
        return embeds
    rot_dim = int(embeds.shape[-1] * dim_fraction)
    rot_dim -= rot_dim % 2  # rotation acts on pairs of dims
    if rot_dim < 2:
        return embeds
    progress = iteration / max(total_iterations - 1, 1)
    exponents = torch.arange(0, rot_dim, 2, device=embeds.device, dtype=embeds.dtype) / rot_dim
    inv_freq = 1.0 / (base ** exponents)
    # 3.14159 approximates pi; kept as-is so rotations stay bit-compatible.
    angles = progress * inv_freq * 3.14159
    cos, sin = angles.cos(), angles.sin()
    even_idx = slice(-rot_dim, None, 2)
    odd_idx = slice(-rot_dim + 1, None, 2)
    x_even = embeds[..., even_idx]
    x_odd = embeds[..., odd_idx]
    # cos/sin broadcast over any leading dims of embeds.
    rotated = embeds.clone()
    rotated[..., even_idx] = x_even * cos - x_odd * sin
    rotated[..., odd_idx] = x_even * sin + x_odd * cos
    return rotated
# ==================== FLOW MATCHING ====================
def _sample_flow_matching_t(self, num_tokens: int, device: torch.device) -> torch.Tensor:
    """Draw one flow-matching time level per token.

    ``config.flow_matching_t_distribution`` selects the sampler:
      - "logit_normal": sigmoid(N(mean, std)) using the configured logit mean/std
      - "uniform": U(0, 1)
    Samples are clamped to [flow_matching_t_min, flow_matching_t_max].

    Returns:
        t: (num_tokens,) tensor of time levels in [t_min, t_max]
    Raises:
        ValueError: for an unknown distribution name.
    """
    cfg = self.config
    dist = cfg.flow_matching_t_distribution
    if dist == "uniform":
        t = torch.empty(num_tokens, device=device).uniform_(0, 1)
    elif dist == "logit_normal":
        z = torch.randn(num_tokens, device=device) * cfg.flow_matching_t_logit_std + cfg.flow_matching_t_logit_mean
        t = torch.sigmoid(z)
    else:
        raise ValueError(f"Unknown flow_matching_t_distribution: {dist}")
    return t.clamp(cfg.flow_matching_t_min, cfg.flow_matching_t_max)
def compute_flow_matching_distillation_loss(
    self,
    input_ids: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    flow_noise_embed: torch.Tensor,
    flow_t: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
) -> SelfDistillationOutput:
    """
    CFM flow matching distillation: teacher sees state at time t, student sees
    noisier state at time s < t on the same interpolation path.
    Both should predict the same endpoint (target token). The student must
    learn to refine from noisier inputs by matching the teacher's predictions.

    At each masked position the student input is
        (1 - s) * flow_noise_embed + s * embed(label) + mask_emb
    where mask_emb is scaled by (1 - s) when config.flow_matching_mask_scale is
    set, and s = flow_t * U(0, 1) is drawn independently per masked token. The
    student input and the teacher logits are detached; gradient flows only
    through the student's forward pass.

    Args:
        input_ids: Input with [MASK] tokens (config.mask_token_id) at positions to predict
        teacher_logits: Logits from the forward pass, indexed by the (B, L) mask (will be detached)
        labels: Target tokens at masked positions (-100 elsewhere)
        flow_noise_embed: (num_masked, H) noise embeddings from forward
        flow_t: (num_masked,) per-token time levels from forward
        attention_mask: Standard attention mask; all-ones in the embedding dtype if None
        position_ids: Position IDs (if needed by base model)
    Returns:
        SelfDistillationOutput: ``loss`` is KL(teacher || student) with
        reduction="batchmean" over masked tokens; the ``degradation_temperature``
        field holds the mean time gap (t - s). With no masked tokens a zero loss,
        placeholder logits, gap 0.0, entropies 0.0 and agreement 1.0 are returned.
    """
    mask_id = self.config.mask_token_id
    mask_pos = (input_ids == mask_id)  # (B, L)
    device = input_ids.device
    num_masked = mask_pos.sum().item()
    if num_masked == 0:
        # Zero loss built with requires_grad=True — presumably so callers can
        # call backward() unconditionally.
        zero = torch.tensor(0.0, device=device, requires_grad=True)
        dummy = torch.zeros(1, device=device)
        return SelfDistillationOutput(zero, dummy, dummy, 0.0, 0.0, 0.0, 1.0)
    teacher_logits = teacher_logits.detach()
    embed_weight = self.embed_weight
    mask_emb = embed_weight[mask_id]  # (H,)
    base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)
    # Target embeddings from labels
    # NOTE(review): assumes every masked position carries a real label; a -100
    # here would silently index embed_weight from the end. Confirm with caller.
    target_ids = labels[mask_pos]  # (num_masked,)
    target_embed = embed_weight[target_ids]  # (num_masked, H)
    # Sample student time s ~ U(0, t) per token
    s_per_token = flow_t * torch.rand(num_masked, device=device)  # (num_masked,)
    # Student state: same noise, earlier time (noisier)
    s_col = s_per_token.unsqueeze(-1).to(base_embeds.dtype)  # (num_masked, 1)
    student_interp = (1 - s_col) * flow_noise_embed + s_col * target_embed
    if self.config.flow_matching_mask_scale:
        student_masked_embeds = student_interp + (1 - s_col) * mask_emb
    else:
        student_masked_embeds = student_interp + mask_emb
    # Build full student input (detached — gradient only flows through student's forward)
    student_embeds = base_embeds.detach().clone()
    student_embeds[mask_pos] = student_masked_embeds.detach()
    # student_embeds already equals base_embeds at unmasked positions, so this
    # where() does not change any values.
    student_inputs = torch.where(
        mask_pos.unsqueeze(-1), student_embeds, base_embeds.detach()
    )
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)
    student_out = self.mlm(
        inputs_embeds=student_inputs,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
    )
    student_logits = student_out.logits  # (B, L, V) — has gradient
    # KL divergence loss on masked positions: KL(teacher || student)
    t_logits = teacher_logits[mask_pos]  # (num_masked, V)
    s_logits = student_logits[mask_pos]  # (num_masked, V)
    teacher_probs = F.softmax(t_logits, dim=-1)
    student_log_probs = F.log_softmax(s_logits, dim=-1)
    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction="batchmean",
    )
    # Diagnostic metrics (no grad): mean entropies, argmax agreement, mean t - s
    with torch.no_grad():
        teacher_log_probs = torch.log(teacher_probs + 1e-10)
        teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item()
        student_probs = F.softmax(s_logits.detach(), dim=-1)
        student_log_probs_det = torch.log(student_probs + 1e-10)
        student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item()
        agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item()
        mean_time_gap = (flow_t - s_per_token).mean().item()
    return SelfDistillationOutput(
        loss=kl_loss,
        teacher_logits=teacher_logits,
        student_logits=student_logits,
        degradation_temperature=mean_time_gap,
        teacher_entropy=teacher_entropy,
        student_entropy=student_entropy,
        agreement_rate=agreement,
    )
# ==================== SELF-DISTILLATION (legacy) ====================
def compute_self_distillation_loss(
    self,
    input_ids: torch.Tensor,
    teacher_logits: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    temperature_min: Optional[float] = None,
    temperature_max: Optional[float] = None,
    temperature_distribution: Optional[str] = None,
) -> SelfDistillationOutput:
    """
    CFM-style self-distillation: model's predictions should be consistent
    across different levels of input degradation.

    Process:
    1. Take teacher logits (from standard forward pass, DETACHED)
    2. Degrade: per-token random temperature → softer soft embeddings
       (softmax(teacher_logits / T) @ embed_weight + mask_emb)
    3. Student: forward pass from degraded (detached) embeddings → logits (has grad)
    4. Loss: KL(teacher || student) on masked positions, reduction="batchmean"

    Each masked token gets its own independently sampled degradation
    temperature, creating varied difficulty across the sequence.

    Args:
        input_ids: Input with [MASK] tokens (config.mask_token_id) at positions to predict
        teacher_logits: Pre-computed teacher logits (will be detached).
            Typically outputs.all_logits[0] or outputs.logits from standard forward.
        attention_mask: Standard attention mask; all-ones in the embedding dtype if None
        position_ids: Position IDs (if needed by base model)
        temperature_min: Min degradation temperature (default: config.self_distillation_temperature_min)
        temperature_max: Max degradation temperature (default: config.self_distillation_temperature_max)
        temperature_distribution: "log_uniform" or "uniform"
            (default: config.self_distillation_temperature_distribution)
    Returns:
        SelfDistillationOutput with loss, logits, mean sampled temperature and
        diagnostics. With no masked tokens: zero loss, placeholder logits,
        temperature 1.0, entropies 0.0, agreement 1.0.
    Raises:
        ValueError: for an unknown temperature distribution.
    """
    # Resolve defaults from config
    temperature_min = temperature_min if temperature_min is not None else self.config.self_distillation_temperature_min
    temperature_max = temperature_max if temperature_max is not None else self.config.self_distillation_temperature_max
    temperature_distribution = temperature_distribution if temperature_distribution is not None else self.config.self_distillation_temperature_distribution
    mask_id = self.config.mask_token_id
    mask_pos = (input_ids == mask_id)  # (B, L)
    device = input_ids.device
    num_masked = mask_pos.sum().item()
    # Handle degenerate case: no masked positions
    if num_masked == 0:
        zero = torch.tensor(0.0, device=device, requires_grad=True)
        dummy = torch.zeros(1, device=device)
        return SelfDistillationOutput(zero, dummy, dummy, 1.0, 0.0, 0.0, 1.0)
    # Ensure teacher logits are detached
    teacher_logits = teacher_logits.detach()
    embed_weight = self.embed_weight
    mask_emb = embed_weight[mask_id]  # (H,)
    base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)
    # ===== STEP 1: Sample per-token degradation temperatures =====
    # Each masked position gets its own temperature independently
    if temperature_distribution == "log_uniform":
        # NOTE(review): needs temperature_min > 0 and temperature_max > 0;
        # log() of a non-positive bound yields -inf/NaN — confirm config values.
        log_min = torch.tensor(temperature_min, device=device).log()
        log_max = torch.tensor(temperature_max, device=device).log()
        log_T = torch.empty(num_masked, device=device).uniform_(log_min.item(), log_max.item())
        T_per_token = log_T.exp()  # (num_masked,)
    elif temperature_distribution == "uniform":
        T_per_token = torch.empty(num_masked, device=device).uniform_(
            temperature_min, temperature_max
        )  # (num_masked,)
    else:
        raise ValueError(f"Unknown temperature distribution: {temperature_distribution}")
    T_mean = T_per_token.mean().item()
    # ===== STEP 2: Create degraded soft embeddings =====
    # Per-token temperature scaling: each position gets its own T.
    # Cast to the embedding dtype so the matmul operands match.
    masked_teacher_logits = teacher_logits[mask_pos]  # (num_masked, V)
    degraded_probs = F.softmax(masked_teacher_logits / T_per_token.unsqueeze(-1), dim=-1).to(embed_weight.dtype)
    degraded_soft = degraded_probs @ embed_weight + mask_emb
    degraded_soft_embeds = base_embeds.clone()
    degraded_soft_embeds[mask_pos] = degraded_soft
    degraded_soft_embeds = degraded_soft_embeds.detach()
    # ===== STEP 3: Student forward from degraded input =====
    student_inputs = torch.where(
        mask_pos.unsqueeze(-1), degraded_soft_embeds, base_embeds.detach()
    )
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)
    student_out = self.mlm(
        inputs_embeds=student_inputs,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
    )
    student_logits = student_out.logits  # (B, L, V) — has gradient!
    # ===== STEP 4: KL divergence loss on masked positions =====
    t_logits = teacher_logits[mask_pos]  # (num_masked, V)
    s_logits = student_logits[mask_pos]  # (num_masked, V)
    teacher_probs = F.softmax(t_logits, dim=-1)
    student_log_probs = F.log_softmax(s_logits, dim=-1)
    # KL(teacher || student) = sum teacher * (log_teacher - log_student)
    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction="batchmean",
    )
    # ===== STEP 5: Compute diagnostic metrics =====
    with torch.no_grad():
        teacher_log_probs = torch.log(teacher_probs + 1e-10)
        teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item()
        student_probs = F.softmax(s_logits.detach(), dim=-1)
        student_log_probs_det = torch.log(student_probs + 1e-10)
        student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item()
        agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item()
    return SelfDistillationOutput(
        loss=kl_loss,
        teacher_logits=teacher_logits,
        student_logits=student_logits,
        degradation_temperature=T_mean,
        teacher_entropy=teacher_entropy,
        student_entropy=student_entropy,
        agreement_rate=agreement,
    )
# ==================== MAIN SOFT EMBEDDING COMPUTATION ====================
@torch.no_grad()
def _compute_next_soft_embeds(
self,
logits: torch.Tensor,
mask_pos: torch.Tensor,
base_embeds: torch.Tensor,
prev_soft_embeds: Optional[torch.Tensor] = None,
iteration: int = 0,
total_iterations: int = 1,
# === Schedule parameters (default to config values) ===
schedule: Optional[str] = None,
causal_strength: Optional[float] = None,
# === Effect parameters (default to config values) ===
temperature_max: Optional[float] = None,
entropy_target_max: Optional[float] = None,
entropy_floor_max: Optional[float] = None,
smear_sigma_max: Optional[float] = None,
noise_std_max: Optional[float] = None,
iteration_rope_dim_fraction: Optional[float] = None,
) -> torch.Tensor:
"""
Compute soft embeddings from logits for the next iteration.
This function implements a unified "convergence schedule" system that controls
when each position is allowed to converge to a confident prediction.
Schedule Types:
"linear": All positions converge at the same rate (iteration-based only)
"causal": Early positions converge first, late positions last
Effects (mechanisms to enforce the schedule):
temperature_max: High temperature = more uniform distribution (one-sided)
entropy_target_max: Force EXACT entropy via bisection search (two-sided, recommended)
entropy_floor_max: Force MINIMUM entropy (one-sided, only prevents too confident)
smear_sigma_max: Spread probability across neighboring positions
noise_std_max: Add Gaussian noise to logits
All parameters default to their config values if not specified.
Args:
logits: Output logits from current iteration (B, L, V)
mask_pos: Boolean mask indicating which positions are masked (B, L)
base_embeds: Base token embeddings for non-masked positions (B, L, H)
iteration: Current iteration index (0-indexed)
total_iterations: Total number of iterations
Returns:
Soft embeddings for next iteration (B, L, H)
"""
# Use config values as defaults
schedule = schedule if schedule is not None else self.config.schedule
causal_strength = causal_strength if causal_strength is not None else self.config.causal_strength
temperature_max = temperature_max if temperature_max is not None else self.config.temperature_max
entropy_target_max = entropy_target_max if entropy_target_max is not None else self.config.entropy_target_max
entropy_floor_max = entropy_floor_max if entropy_floor_max is not None else self.config.entropy_floor_max
smear_sigma_max = smear_sigma_max if smear_sigma_max is not None else self.config.smear_sigma_max
noise_std_max = noise_std_max if noise_std_max is not None else self.config.noise_std_max
iteration_rope_dim_fraction = iteration_rope_dim_fraction if iteration_rope_dim_fraction is not None else self.config.iteration_rope_dim_fraction
soft_embeds = base_embeds.clone()
if not mask_pos.any():
return soft_embeds.detach()
B, L, V = logits.shape
device, dtype = logits.device, logits.dtype
# Check if any effects are enabled
has_effects = (
temperature_max > 0 or
entropy_target_max > 0 or
entropy_floor_max > 0 or
smear_sigma_max > 0 or
noise_std_max > 0 or
iteration_rope_dim_fraction > 0
)
if not has_effects:
# Simple path: no convergence schedule effects
masked_logits = logits[mask_pos]
embed_weight = self.embed_weight
# Convert logits to mixing weights based on soft_embedding_method
if self.config.soft_embedding_method == "none":
weights = masked_logits
elif self.config.soft_embedding_method == "l2_normalize":
weights = F.normalize(masked_logits, p=2, dim=-1)
else:
weights = self.normalize(masked_logits)
masked_soft = weights @ embed_weight
mask_emb = embed_weight[self.config.mask_token_id]
masked_soft = masked_soft + mask_emb
# Apply EMA blending with previous soft embeddings if enabled
ema_step = self.config.soft_embedding_ema_step
if ema_step < 1.0 and prev_soft_embeds is not None:
prev_masked_soft = prev_soft_embeds[mask_pos]
masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft
soft_embeds[mask_pos] = masked_soft
return soft_embeds.detach()
# ========== STEP 1: Compute per-position convergence progress ==========
batch_indices, position_indices = torch.where(mask_pos)
progress = self._compute_convergence_progress(
iteration=iteration,
total_iterations=total_iterations,
seq_length=L,
mask_positions=position_indices,
schedule=schedule,
causal_strength=causal_strength,
device=device,
dtype=dtype,
)
# Compute full (B, L) progress for smearing if needed
if smear_sigma_max > 0:
all_positions = torch.arange(L, device=device, dtype=dtype)
progress_full = self._compute_convergence_progress(
iteration=iteration,
total_iterations=total_iterations,
seq_length=L,
mask_positions=all_positions,
schedule=schedule,
causal_strength=causal_strength,
device=device,
dtype=dtype,
)
progress_full = progress_full.unsqueeze(0).expand(B, -1)
# ========== STEP 2: Apply smearing (needs full tensor) ==========
full_probs = self.normalize(logits)
if smear_sigma_max > 0:
full_probs = self._apply_smear_effect(
full_probs, mask_pos, progress_full, smear_sigma_max
)
# ========== STEP 3: Extract masked positions ==========
masked_logits = logits[mask_pos]
masked_probs = full_probs[mask_pos]
# ========== STEP 4: Apply temperature effect (on logits) ==========
if temperature_max > 0 and entropy_target_max <= 0:
masked_logits = self._apply_temperature_effect(
masked_logits, progress, temperature_max
)
masked_probs = torch.softmax(masked_logits, dim=-1)
# ========== STEP 5: Apply noise effect (on logits) ==========
if noise_std_max > 0:
masked_logits_noisy = self._apply_noise_effect(
torch.log(masked_probs + 1e-10), progress, noise_std_max
)
masked_probs = torch.softmax(masked_logits_noisy, dim=-1)
# ========== STEP 6: Apply entropy control ==========
if entropy_target_max > 0:
masked_probs = self._apply_target_entropy_effect(
masked_logits, progress, entropy_target_max
)
elif entropy_floor_max > 0:
masked_probs = self._apply_entropy_floor_effect(
masked_probs, progress, entropy_floor_max
)
# ========== STEP 7: Compute soft embeddings ==========
embed_weight = self.embed_weight
# Convert to mixing weights based on soft_embedding_method
if self.config.soft_embedding_method == "none":
# No normalization - use raw logits directly
weights = masked_logits
elif self.config.soft_embedding_method == "l2_normalize":
# L2 normalize bypasses all the softmax-based effects above
weights = F.normalize(masked_logits, p=2, dim=-1)
else:
weights = masked_probs
masked_soft = weights @ embed_weight
mask_emb = embed_weight[self.config.mask_token_id]
masked_soft = masked_soft + mask_emb
# ========== STEP 8: Apply iteration RoPE ==========
if iteration_rope_dim_fraction > 0:
masked_soft = self._apply_iteration_rope(
masked_soft, iteration, total_iterations, iteration_rope_dim_fraction
)
# ========== STEP 8.5: Apply EMA blending ==========
ema_step = self.config.soft_embedding_ema_step
if ema_step < 1.0 and prev_soft_embeds is not None:
prev_masked_soft = prev_soft_embeds[mask_pos]
masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft
# ========== STEP 9: Place back and return ==========
soft_embeds[mask_pos] = masked_soft
return soft_embeds.detach()
@torch.no_grad()
def _compute_iteration_metrics(
    self, logits: torch.Tensor, labels: torch.Tensor
) -> IterationMetrics:
    """
    Compute token-level AND sequence-level metrics for a single iteration.

    Only positions with ``labels != -100`` are scored. Work is done on CPU in
    float32 and reduced to Python floats, so no large tensors are retained.

    Token-level metrics (over scored tokens):
        - accuracy: fraction of tokens whose argmax equals the label
        - entropy: mean entropy of the predictive distribution
        - softmax_ce: mean cross-entropy
    Sequence-level metrics (over sequences with at least one scored token):
        - full_sequence_accuracy: fraction of sequences where ALL scored tokens are correct
        - min_sequence_confidence: mean over sequences of the minimum top-1
          probability among that sequence's scored tokens

    Args:
        logits: Prediction logits (B, L, V).
        labels: Target token ids (B, L); -100 marks ignored positions.

    Returns:
        IterationMetrics with float fields; all 0.0 when no position is scored.
    """
    # Move to CPU to avoid GPU OOM - metrics are for monitoring only
    logits = logits.detach().cpu().float().contiguous()  # float32 is sufficient for metrics
    target_labels = labels.detach().cpu().contiguous()
    mask = target_labels != -100  # (B, L)

    num_tokens = mask.sum()
    if num_tokens == 0:
        return IterationMetrics(
            accuracy=0.0,
            entropy=0.0,
            softmax_ce=0.0,
            full_sequence_accuracy=0.0,
            min_sequence_confidence=0.0,
        )

    predictions = logits.argmax(dim=-1)  # (B, L)
    correct = (predictions == target_labels) & mask

    # ===== TOKEN-LEVEL METRICS =====
    accuracy = (correct.sum() / num_tokens).item()

    valid_logits = logits[mask]  # (N, V)
    valid_labels = target_labels[mask]  # (N,)
    # log_softmax for numerical stability; reused below for top-1 confidence
    log_probs = F.log_softmax(valid_logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1).mean().item()
    softmax_ce = F.cross_entropy(valid_logits, valid_labels, reduction="mean").item()

    # ===== SEQUENCE-LEVEL METRICS =====
    # num_tokens > 0 guarantees at least one sequence has scored tokens.
    sequences_with_tokens = mask.any(dim=1)  # (B,)
    num_valid_sequences = sequences_with_tokens.sum()

    num_correct_per_seq = correct.sum(dim=1)  # (B,)
    num_tokens_per_seq = mask.sum(dim=1)  # (B,)
    all_correct = (num_correct_per_seq == num_tokens_per_seq) & sequences_with_tokens
    full_seq_accuracy = (all_correct.sum() / num_valid_sequences).item()

    # Top-1 probability of scored tokens only: avoids materializing a full
    # (B, L, V) softmax. Unscored slots are +inf so they never win the min.
    top1_confidence = log_probs.max(dim=-1).values.exp()  # (N,)
    per_token_conf = torch.full(mask.shape, float("inf"))
    per_token_conf[mask] = top1_confidence
    min_per_seq = per_token_conf.min(dim=1).values  # (B,)
    min_seq_conf = min_per_seq[sequences_with_tokens].mean().item()

    return IterationMetrics(
        accuracy=accuracy,
        entropy=entropy,
        softmax_ce=softmax_ce,
        full_sequence_accuracy=full_seq_accuracy,
        min_sequence_confidence=min_seq_conf,
    )
def _single_iteration(
self,
t: int,
T: int,
soft_embeds: torch.Tensor,
base_embeds: torch.Tensor,
mask_pos: torch.Tensor,
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.Tensor],
compute_metrics: bool,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[IterationMetrics]]:
"""
Execute a single iteration of recursive refinement.
Args:
t: Current iteration index (0 to T-1)
T: Total number of iterations
soft_embeds: Soft embeddings for mask positions
base_embeds: Base token embeddings from input_ids
mask_pos: Boolean mask of [MASK] positions (B, L)
attention_mask: Attention mask for MLM
labels: Target labels for loss computation
compute_metrics: Whether to compute iteration metrics
Returns:
logits: Output logits from MLM (B, L, V)
weighted_loss: Loss weighted by step_weight(t, T), or None if no labels
metrics: IterationMetrics, or None if not requested
"""
# Blend soft embeddings (at mask positions) with base embeddings (at non-mask positions)
inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)
# Forward through base MLM
outputs = self.mlm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
labels=labels,
return_dict=True,
**kwargs,
)
# Compute weighted loss for this iteration
weighted_loss = outputs.loss
if labels is not None:
if weighted_loss is None:
# Base model doesn't compute loss (e.g., LLaDA) - compute it ourselves
# Only compute loss on MASKED positions (MDLM training)
masked_logits = outputs.logits[mask_pos] # (num_masked, V)
masked_labels = labels[mask_pos] # (num_masked,)
loss_fct = CrossEntropyLoss() # -100 index = padding token
weighted_loss = loss_fct(masked_logits, masked_labels)
weighted_loss *= self.step_weight(t, T)
# Compute iteration metrics if requested
metrics = None
if compute_metrics and labels is not None:
metrics = self._compute_iteration_metrics(outputs.logits, labels)
return outputs.logits, weighted_loss, metrics
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    num_recursions: Optional[int] = None,
    compute_iteration_metrics: bool = False,
    use_recursion_checkpointing: Optional[bool] = None,
    # Parameters for single-iteration training mode (DEPRECATED)
    prev_soft_embeds: Optional[torch.Tensor] = None,
    run_set_iteration: Optional[int] = None,
    # === Convergence schedule parameters (None = use config defaults) ===
    schedule: Optional[str] = None,
    causal_strength: Optional[float] = None,
    # === Effect parameters (None = use config defaults) ===
    temperature_max: Optional[float] = None,
    entropy_target_max: Optional[float] = None,
    entropy_floor_max: Optional[float] = None,
    smear_sigma_max: Optional[float] = None,
    noise_std_max: Optional[float] = None,
    iteration_rope_dim_fraction: Optional[float] = None,
    **kwargs,
) -> RecursiveMaskedLMOutput:
    """
    Forward with recursive refinement.

    Supports three modes:
    1. Checkpointed mode (default): Run all T recursions with gradient checkpointing.
       Gradients flow through the entire chain; activations recomputed during backward.
    2. Non-checkpointed mode (use_recursion_checkpointing=False): Store all activations.
       Faster backward but higher memory.
    3. Single-iteration mode (DEPRECATED - run_set_iteration is not None): Run only one
       iteration. Use use_recursion_checkpointing=True instead.

    Outputs:
        Modes 1/2 return ``loss=None``; while training they also return
        ``all_logits`` (one (B, L, V) tensor per iteration) so the trainer can
        apply its own weighting/normalization. Mode 3 returns the iteration loss
        weighted by ``step_weight(t, T)`` and divided by the sum of step weights,
        plus ``next_soft_embeds`` for the following call.

    Loss Weighting (config.loss_weight, used through step_weight in mode 3):
        "last_1": Only final iteration loss (enables learning convergence behavior)
        "last_2": Last 2 iterations
        "linear": All iterations, linearly weighted (default)
        "uniform": All iterations, uniformly weighted

    Recursion Checkpointing:
        use_recursion_checkpointing: Enable gradient checkpointing for iterations.
            True = checkpoint each iteration, recompute during backward (default).
            False = store all activations (higher memory, faster backward).
            Checkpointing is only used while ``self.training``.

    Convergence Schedule Parameters:
        None = use config value. These overrides are forwarded to
        ``_compute_next_soft_embeds`` only in the deprecated single-iteration
        mode; modes 1/2 do not pass them on, so a warning is emitted when any
        is set there.
        schedule: "linear" or "causal" - controls when positions can converge
        causal_strength: How much faster early positions converge (causal only)
        temperature_max: Max temperature boost for uncertain positions
        entropy_target_max: Target entropy at progress=0 (two-sided, recommended)
        entropy_floor_max: Min entropy floor (one-sided)
        smear_sigma_max: Max Gaussian sigma for position smearing
        noise_std_max: Max std of Gaussian noise on logits
        iteration_rope_dim_fraction: Fraction of dims for iteration RoPE

    Raises:
        ValueError: if config.mask_token_id is None, or (mode 3) if
            prev_soft_embeds is missing for run_set_iteration > 0.
    """
    B, L = input_ids.shape
    mask_id = self.config.mask_token_id
    if mask_id is None:
        raise ValueError("mask_token_id must be set")

    # Resolve config default for recursion checkpointing
    use_recursion_checkpointing = (
        use_recursion_checkpointing
        if use_recursion_checkpointing is not None
        else self.config.use_recursion_checkpointing
    )

    mask_pos = (input_ids == mask_id)  # (B, L)
    base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)
    T = num_recursions or self.config.num_recursions

    # Bundle schedule kwargs to pass to _compute_next_soft_embeds
    schedule_kwargs = dict(
        schedule=schedule,
        causal_strength=causal_strength,
        temperature_max=temperature_max,
        entropy_target_max=entropy_target_max,
        entropy_floor_max=entropy_floor_max,
        smear_sigma_max=smear_sigma_max,
        noise_std_max=noise_std_max,
        iteration_rope_dim_fraction=iteration_rope_dim_fraction,
    )

    # ===== SINGLE ITERATION MODE (DEPRECATED) =====
    if run_set_iteration is not None:
        warnings.warn(
            "run_set_iteration is deprecated. Use use_recursion_checkpointing=True instead, "
            "which provides proper gradient flow through all iterations.",
            DeprecationWarning,
            stacklevel=2,
        )
        t = run_set_iteration

        # Get soft embeddings for this iteration
        if t == 0:
            # t=0: Uniform prior = average embedding (equivalent to softmax(zeros) @ embed_weight)
            # computed via embed_weight.mean() rather than creating large zero tensors
            soft_embeds = base_embeds.clone()
            if mask_pos.any():
                avg_embed = self.embed_weight.mean(dim=0)  # (H,) - mean over all V tokens
                mask_emb = self.embed_weight[mask_id]
                soft_embeds[mask_pos] = avg_embed + mask_emb
        else:
            if prev_soft_embeds is None:
                raise ValueError(f"prev_soft_embeds must be provided for iteration {t}")
            soft_embeds = prev_soft_embeds

        logits, weighted_loss, metrics = self._single_iteration(
            t, T, soft_embeds, base_embeds, mask_pos,
            attention_mask, labels, compute_iteration_metrics,
            position_ids=position_ids, **kwargs
        )

        # Normalize loss by total weight sum
        weight_sum = sum(self.step_weight(i, T) for i in range(T))
        loss = weighted_loss / weight_sum if weighted_loss is not None else None

        # Compute soft embeddings for next iteration (if not last). The current
        # soft embeddings are passed as prev_soft_embeds so that the configured
        # EMA blending (config.soft_embedding_ema_step < 1) actually applies.
        next_soft_embeds = None
        if t < T - 1:
            next_soft_embeds = self._compute_next_soft_embeds(
                logits, mask_pos, base_embeds,
                prev_soft_embeds=soft_embeds,
                iteration=t,
                total_iterations=T,
                **schedule_kwargs,
            )

        return RecursiveMaskedLMOutput(
            loss=loss,
            logits=logits,
            next_soft_embeds=next_soft_embeds,
            iteration_metrics={t: metrics} if metrics is not None else None,
        )

    # ===== CHECKPOINTED MODE (gradient flow through all iterations) =====
    # _single_iteration_checkpointable receives none of the schedule/effect
    # overrides, so tell the caller instead of silently dropping them.
    overridden = [name for name, value in schedule_kwargs.items() if value is not None]
    if overridden:
        warnings.warn(
            f"Schedule/effect overrides {overridden} are only applied in the deprecated "
            "run_set_iteration mode; the recursive forward path ignores them.",
            stacklevel=2,
        )

    embed_weight = self.embed_weight
    mask_emb = embed_weight[mask_id]  # (H,)

    # Temperature must be a tensor for checkpointing (checkpoint requires tensor inputs)
    temperature = torch.tensor(
        self.config.temperature,
        device=input_ids.device,
        dtype=base_embeds.dtype,
    )

    # Ensure attention_mask is a tensor (required for checkpointing)
    if attention_mask is None:
        attention_mask = torch.ones(B, L, device=input_ids.device, dtype=base_embeds.dtype)

    # Initialize soft embeddings for masked positions
    soft_embeds = base_embeds.clone()
    flow_noise_embed = None
    flow_t_per_token = None
    if self.config.flow_matching_enabled and self.training and labels is not None and mask_pos.any():
        # Flow matching: interpolate between random noise and target on the simplex
        num_masked = mask_pos.sum().item()
        V = embed_weight.shape[0]
        device = input_ids.device

        # Sample per-token time levels (logit-normal by default)
        flow_t_per_token = self._sample_flow_matching_t(num_masked, device)

        # Random noise embedding: sample on simplex, project to H-dim
        z = torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype)
        p_noise = F.softmax(z * self.config.flow_matching_noise_scale, dim=-1).to(base_embeds.dtype)
        flow_noise_embed = p_noise @ embed_weight  # (num_masked, H) — compact

        # Target embedding from labels (original token IDs at masked positions).
        # NOTE(review): a -100 label at a masked position would index
        # embed_weight from the end — confirm labels are always valid ids here.
        target_ids = labels[mask_pos]
        target_embed = embed_weight[target_ids]  # (num_masked, H)

        # Interpolate in embedding space
        t_col = flow_t_per_token.unsqueeze(-1).to(base_embeds.dtype)  # (num_masked, 1)
        interp_embed = (1 - t_col) * flow_noise_embed + t_col * target_embed

        # Add mask signal (binary or scaled)
        if self.config.flow_matching_mask_scale:
            soft_embeds[mask_pos] = interp_embed + (1 - t_col) * mask_emb
        else:
            soft_embeds[mask_pos] = interp_embed + mask_emb
    elif mask_pos.any():
        # Standard uniform prior (average embedding + mask signal)
        avg_embed = embed_weight.mean(dim=0)  # (H,)
        soft_embeds[mask_pos] = avg_embed + mask_emb

    iteration_metrics = {} if compute_iteration_metrics and labels is not None else None

    # Main recursion loop with optional checkpointing
    all_logits = []
    for t in range(T):
        if self.training and use_recursion_checkpointing:
            # Checkpointing: activations recomputed during backward, which keeps
            # gradient flow through the chain while saving memory
            logits, soft_embeds = torch_checkpoint(
                self._single_iteration_checkpointable,
                soft_embeds,
                base_embeds,
                mask_pos,
                attention_mask,
                embed_weight,
                mask_emb,
                temperature,
                position_ids,
                use_reentrant=False,  # Critical for nested checkpointing!
            )
        else:
            # No checkpointing: store all activations (inference or explicit disable)
            logits, soft_embeds = self._single_iteration_checkpointable(
                soft_embeds,
                base_embeds,
                mask_pos,
                attention_mask,
                embed_weight,
                mask_emb,
                temperature,
                position_ids,
            )

        all_logits.append(logits)

        # Compute iteration metrics if requested (no grad needed)
        if iteration_metrics is not None and labels is not None:
            with torch.no_grad():
                iteration_metrics[t] = self._compute_iteration_metrics(logits, labels)

    # Return all logits for trainer to compute loss with proper normalization
    # Trainer handles: timestep-based weighting, iteration weighting, batch/sequence/token normalization
    return RecursiveMaskedLMOutput(
        loss=None,  # Let trainer compute loss
        logits=logits,  # Final logits for inference/metrics
        all_logits=all_logits if self.training else None,  # Only needed during training
        iteration_metrics=iteration_metrics or None,
        flow_noise_embed=flow_noise_embed,  # For flow matching distillation
        flow_t=flow_t_per_token,  # For flow matching distillation
    )
@torch.no_grad()
def _generate_flow_map(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.Tensor],
num_steps: int,
) -> torch.Tensor:
"""Fill in mask positions using the CFM flow map update rule.
Starts from a random point on the probability simplex and iteratively
moves toward the model's predictions using the flow map step rule.
Args:
input_ids: Input with [MASK] tokens at positions to fill
attention_mask: Attention mask
position_ids: Position IDs
num_steps: Number of flow map steps (finer = better, 1 step = greedy)
Returns:
Tensor with [MASK] positions filled with predicted tokens
"""
mask_pos = (input_ids == self.config.mask_token_id)
num_masked = mask_pos.sum().item()
if num_masked == 0:
return input_ids.clone()
device = input_ids.device
V = self.embed_weight.shape[0]
embed_weight = self.embed_weight
mask_emb = embed_weight[self.config.mask_token_id]
base_embeds = self.get_input_embeddings()(input_ids)
# Start from random simplex point
noise_scale = self.config.flow_matching_noise_scale
p = F.softmax(torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype) * noise_scale, dim=-1).to(base_embeds.dtype)
times = torch.linspace(0, 1, num_steps + 1, device=device)
for i in range(num_steps):
t_now = times[i]
t_next = times[i + 1]
step_size = (t_next - t_now) / (1 - t_now)
# Mask signal (binary or scaled)
if self.config.flow_matching_mask_scale:
mask_signal = (1 - t_now) * mask_emb
else:
mask_signal = mask_emb
# Project current state to embedding space
embed = p @ embed_weight + mask_signal
soft_embeds = base_embeds.clone()
soft_embeds[mask_pos] = embed
inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)
outputs = self.mlm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
)
pi = F.softmax(outputs.logits[mask_pos], dim=-1).to(p.dtype)
# Flow map update: move toward model's prediction
p = p + step_size * (pi - p)
# Fix floating point drift off the simplex
p = p.clamp(min=0)
p = p / p.sum(dim=-1, keepdim=True)
result = input_ids.clone()
result[mask_pos] = p.argmax(dim=-1)
return result
@torch.no_grad()
def generate(
    self,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    num_recursions: Optional[int] = None,
    # === Convergence schedule parameters (None = use config defaults) ===
    schedule: Optional[str] = None,
    causal_strength: Optional[float] = None,
    # === Effect parameters (None = use config defaults) ===
    temperature_max: Optional[float] = None,
    entropy_target_max: Optional[float] = None,
    entropy_floor_max: Optional[float] = None,
    smear_sigma_max: Optional[float] = None,
    noise_std_max: Optional[float] = None,
    iteration_rope_dim_fraction: Optional[float] = None,
) -> torch.Tensor:
    """Fill in mask positions via iterative refinement.

    When config.flow_matching_enabled, uses the CFM flow map update rule
    (``_generate_flow_map``); the schedule/effect arguments are not used on that
    path. Otherwise runs ``forward`` for the requested number of recursions and
    takes the argmax of the final logits at [MASK] positions.

    Args:
        input_ids: Input token IDs with [MASK] tokens at positions to fill
        attention_mask: Attention mask
        position_ids: Position IDs passed through to the model
        num_recursions: Override number of recursions/steps (default: config value)
        schedule: "linear" or "causal" convergence schedule
        causal_strength: How much faster early positions converge (causal only)
        temperature_max: Max temperature boost for uncertain positions
        entropy_target_max: Target entropy at progress=0 (two-sided)
        entropy_floor_max: Min entropy floor (one-sided)
        smear_sigma_max: Max Gaussian sigma for position smearing
        noise_std_max: Max std of Gaussian noise on logits
        iteration_rope_dim_fraction: Fraction of dims for iteration RoPE

    Returns:
        Tensor with [MASK] positions filled with predicted tokens (a new tensor;
        input_ids is not modified)
    """
    num_steps = num_recursions or self.config.num_recursions

    if self.config.flow_matching_enabled:
        return self._generate_flow_map(
            input_ids, attention_mask, position_ids, num_steps
        )

    out = self.forward(
        input_ids,
        attention_mask,
        position_ids=position_ids,
        num_recursions=num_steps,
        schedule=schedule,
        causal_strength=causal_strength,
        temperature_max=temperature_max,
        entropy_target_max=entropy_target_max,
        entropy_floor_max=entropy_floor_max,
        smear_sigma_max=smear_sigma_max,
        noise_std_max=noise_std_max,
        iteration_rope_dim_fraction=iteration_rope_dim_fraction,
    )

    mask_pos = (input_ids == self.config.mask_token_id)
    result = input_ids.clone()
    # Select the masked rows first so argmax runs only where it is needed.
    result[mask_pos] = out.logits[mask_pos].argmax(dim=-1)
    return result