from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import NamedTuple, Optional

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from transformers import AutoConfig, AutoModelForMaskedLM, PreTrainedModel
from transformers.modeling_outputs import MaskedLMOutput
from transformers.utils import ModelOutput

from .configuration_recursive import RecursiveMLMConfig


@dataclass
class IterationMetrics(ModelOutput):
    """Metrics for a single iteration of recursive refinement."""

    accuracy: Optional[float] = None
    entropy: Optional[float] = None
    softmax_ce: Optional[float] = None
    full_sequence_accuracy: Optional[float] = None
    min_sequence_confidence: Optional[float] = None


@dataclass
class RecursiveMaskedLMOutput(MaskedLMOutput):
    """MaskedLMOutput extended with per-iteration state from recursive refinement."""

    iteration_metrics: Optional[dict[int, IterationMetrics]] = None  # Maps iteration index to metrics
    next_soft_embeds: Optional[torch.Tensor] = None  # For caching between training steps
    all_logits: Optional[list[torch.Tensor]] = None  # All T iterations' logits for trainer loss computation
    # Flow matching state (for distillation — compact H-dim, not V-dim)
    flow_noise_embed: Optional[torch.Tensor] = None  # (num_masked, H) noise embedding
    flow_t: Optional[torch.Tensor] = None  # (num_masked,) per-token time levels


class SelfDistillationOutput(NamedTuple):
    """Output from self-distillation forward pass."""

    loss: torch.Tensor  # KL divergence loss (scalar, has grad)
    teacher_logits: torch.Tensor  # For metrics/debugging (detached)
    student_logits: torch.Tensor  # For metrics/debugging (has grad)
    degradation_temperature: float  # Mean per-token temperature sampled
    teacher_entropy: float  # Entropy of teacher distribution (for monitoring)
    student_entropy: float  # Entropy of student distribution (for monitoring)
    agreement_rate: float  # Fraction where teacher and student argmax agree


class RecursiveMaskedLM(PreTrainedModel):
    """
    Wraps
    any HF MLM with recursive soft-token refinement. At each step:

    1. Normalize logits -> probs
    2. Compute soft embeddings: probs @ embedding_weight + mask_embedding
    3. Forward through MLM
    4. Accumulate weighted loss
    """

    config_class = RecursiveMLMConfig
    base_model_prefix = "mlm"
    supports_gradient_checkpointing = True

    def __init__(self, config: RecursiveMLMConfig, base_model: Optional[PreTrainedModel] = None):
        """Build the wrapper, either around a provided model or from config.

        Args:
            config: Wrapper config; must carry ``base_model_config`` when no
                ``base_model`` instance is supplied.
            base_model: Optional already-loaded MLM to wrap as-is.

        Raises:
            ValueError: If neither ``base_model`` nor ``config.base_model_config``
                is available.
        """
        super().__init__(config)
        if base_model is not None:
            # Pre-trained model provided - assign directly WITHOUT calling post_init()
            # to avoid reinitializing the pre-trained weights via _init_weights()
            self.mlm = base_model
        elif config.base_model_config is not None:
            model_type = config.base_model_config.get("model_type", "")
            if model_type == "llada":
                # LLaDA is constructed explicitly — presumably it is not registered
                # with AutoModelForMaskedLM (verify against the llada modules).
                from .configuration_llada import LLaDAConfig
                from .modeling_llada import LLaDAModelLM
                base_config = LLaDAConfig.from_dict(config.base_model_config)
                self.mlm = LLaDAModelLM(base_config)
            else:
                base_config = AutoConfig.for_model(**config.base_model_config)
                self.mlm = AutoModelForMaskedLM.from_config(base_config)
            # Only call post_init() for freshly created models (needs weight init)
            self.post_init()
        else:
            raise ValueError("Need either base_model or config.base_model_config")

    @classmethod
    def from_mlm_pretrained(
        cls,
        mlm_name_or_path: str,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # === Convergence schedule parameters ===
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # === Effect parameters ===
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # === Soft embedding method ===
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # === Flow matching parameters ===
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_mask_scale: bool = False,
        **model_kwargs,
    ) -> "RecursiveMaskedLM":
        """Load a pretrained MLM and wrap it for recursive refinement.

        ``model_kwargs`` are forwarded verbatim to
        ``AutoModelForMaskedLM.from_pretrained``; all other keywords are
        forwarded to :meth:`from_base_model` (see its docstring for meanings).
        """
        base_model = AutoModelForMaskedLM.from_pretrained(mlm_name_or_path, **model_kwargs)
        return cls.from_base_model(
            base_model,
            num_recursions=num_recursions,
            normalization=normalization,
            loss_weight=loss_weight,
            mask_token_id=mask_token_id,
            temperature=temperature,
            gradient_steps=gradient_steps,
            schedule=schedule,
            causal_strength=causal_strength,
            temperature_max=temperature_max,
            entropy_target_max=entropy_target_max,
            entropy_floor_max=entropy_floor_max,
            smear_sigma_max=smear_sigma_max,
            noise_std_max=noise_std_max,
            iteration_rope_dim_fraction=iteration_rope_dim_fraction,
            use_recursion_checkpointing=use_recursion_checkpointing,
            soft_embedding_method=soft_embedding_method,
            soft_embedding_ema_step=soft_embedding_ema_step,
            flow_matching_enabled=flow_matching_enabled,
            flow_matching_lambda=flow_matching_lambda,
            flow_matching_t_distribution=flow_matching_t_distribution,
            flow_matching_t_logit_mean=flow_matching_t_logit_mean,
            flow_matching_t_logit_std=flow_matching_t_logit_std,
            flow_matching_t_min=flow_matching_t_min,
            flow_matching_t_max=flow_matching_t_max,
            flow_matching_mask_scale=flow_matching_mask_scale,
        )

    @classmethod
    def from_base_model(
        cls,
        base_model: PreTrainedModel,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # === Convergence schedule parameters ===
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # === Effect parameters ===
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # === Soft embedding method ===
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # === Flow matching parameters ===
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_mask_scale: bool = False,
    ) -> "RecursiveMaskedLM":
        """Wrap an existing model for recursive refinement.

        Use this for models not loadable via AutoModelForMaskedLM (e.g., LLaDA).

        Args:
            base_model: The base MLM model to wrap
            num_recursions: Number of recursive refinement steps
            normalization: Normalization method for logits (softmax, stable_softmax)
            loss_weight: Loss weighting scheme (last_1, last_2, linear, uniform)
            mask_token_id: Token ID for [MASK]
            temperature: Temperature for softmax normalization
            gradient_steps: Number of final steps to backprop through
            schedule: Convergence schedule type ("linear" or "causal")
            causal_strength: How much faster early positions converge (causal only)
            temperature_max: Max temperature boost for uncertain positions
            entropy_target_max: Target entropy at progress=0 (two-sided, recommended)
            entropy_floor_max: Min entropy floor (one-sided)
            smear_sigma_max: Max Gaussian sigma for position smearing
            noise_std_max: Max std of Gaussian noise on logits
            iteration_rope_dim_fraction: Fraction of dims for iteration RoPE
            use_recursion_checkpointing: Enable gradient checkpointing for iterations
            soft_embedding_method: How to convert logits to soft embeddings
            soft_embedding_ema_step: EMA step size (1.0 = no EMA, <1.0 = blend with previous)
            flow_matching_enabled: Enable CFM-inspired flow matching framework
            flow_matching_lambda: Weight of distillation KL loss relative to CE
            flow_matching_t_distribution: Time sampling distribution ("logit_normal" or "uniform")
            flow_matching_t_logit_mean: Mean of logit-normal distribution
            flow_matching_t_logit_std: Std of logit-normal distribution
            flow_matching_t_min: Minimum time value (clamp)
            flow_matching_t_max: Maximum time value (clamp)
            flow_matching_mask_scale: Scale mask_emb by (1-t) if True, binary if False
        """
        # Fold every knob into the wrapper config so the model round-trips
        # through save_pretrained/from_pretrained.
        config = RecursiveMLMConfig.from_base_model_config(
            base_model.config,
            num_recursions=num_recursions,
            normalization=normalization,
            loss_weight=loss_weight,
            mask_token_id=mask_token_id,
            temperature=temperature,
            gradient_steps=gradient_steps,
            schedule=schedule,
            causal_strength=causal_strength,
            temperature_max=temperature_max,
            entropy_target_max=entropy_target_max,
            entropy_floor_max=entropy_floor_max,
            smear_sigma_max=smear_sigma_max,
            noise_std_max=noise_std_max,
            iteration_rope_dim_fraction=iteration_rope_dim_fraction,
            use_recursion_checkpointing=use_recursion_checkpointing,
            soft_embedding_method=soft_embedding_method,
            soft_embedding_ema_step=soft_embedding_ema_step,
            flow_matching_enabled=flow_matching_enabled,
            flow_matching_lambda=flow_matching_lambda,
            flow_matching_t_distribution=flow_matching_t_distribution,
            flow_matching_t_logit_mean=flow_matching_t_logit_mean,
            flow_matching_t_logit_std=flow_matching_t_logit_std,
            flow_matching_t_min=flow_matching_t_min,
            flow_matching_t_max=flow_matching_t_max,
            flow_matching_mask_scale=flow_matching_mask_scale,
        )
        return cls(config, base_model=base_model)

    @property
    def embed_weight(self) -> torch.Tensor:
        """(V, H) input-embedding weight matrix of the wrapped MLM."""
        return self.mlm.get_input_embeddings().weight

    def get_input_embeddings(self):
        """Delegate to the wrapped MLM."""
        return self.mlm.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Delegate to the wrapped MLM."""
        self.mlm.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Delegate to the wrapped MLM."""
        return self.mlm.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Delegate to the wrapped MLM."""
        self.mlm.set_output_embeddings(new_embeddings)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""Enable gradient checkpointing with correct settings for recursion. Forces use_reentrant=False which is required for: - Nested checkpoint calls (base model + recursion checkpointing) - Models with frozen parameters - Complex gradient flows through soft embeddings """ if gradient_checkpointing_kwargs is None: gradient_checkpointing_kwargs = {} # Force use_reentrant=False for nested checkpointing compatibility gradient_checkpointing_kwargs.setdefault("use_reentrant", False) self.mlm.gradient_checkpointing_enable(gradient_checkpointing_kwargs) def gradient_checkpointing_disable(self): """Disable gradient checkpointing in the underlying MLM.""" self.mlm.gradient_checkpointing_disable() def _single_iteration_checkpointable( self, soft_embeds: torch.Tensor, base_embeds: torch.Tensor, mask_pos: torch.Tensor, attention_mask: torch.Tensor, embed_weight: torch.Tensor, mask_emb: torch.Tensor, temperature: torch.Tensor, position_ids: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Single differentiable iteration for checkpointing. This method performs one iteration of recursive refinement in a way that maintains gradient flow and is compatible with torch.utils.checkpoint. 
Args: soft_embeds: (B, L, H) - current soft embeddings base_embeds: (B, L, H) - original token embeddings mask_pos: (B, L) bool - which positions are masked attention_mask: (B, L) - attention mask for MLM embed_weight: (V, H) - embedding weight matrix mask_emb: (H,) - mask token embedding temperature: scalar tensor - softmax temperature Returns: logits: (B, L, V) - output logits from this iteration next_soft_embeds: (B, L, H) - soft embeddings for next iteration """ # Blend: use soft_embeds at masked positions, base_embeds elsewhere inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds) # Forward through base MLM outputs = self.mlm( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, return_dict=True, ) logits = outputs.logits # Compute soft embeddings for next iteration (DIFFERENTIABLE - no detach!) next_soft_embeds = base_embeds.clone() if mask_pos.any(): masked_logits = logits[mask_pos] # (num_masked, V) # Convert logits to mixing weights based on soft_embedding_method if self.config.soft_embedding_method == "none": # No normalization - use raw logits directly weights = masked_logits # Differentiable! elif self.config.soft_embedding_method == "l2_normalize": # L2 normalize logits - removes softmax bottleneck for smoother gradients weights = F.normalize(masked_logits, p=2, dim=-1) # Differentiable! else: # Default: softmax normalization weights = F.softmax(masked_logits / temperature, dim=-1) # Differentiable! soft_emb = weights @ embed_weight + mask_emb # Differentiable! 
# Apply EMA blending with previous soft embeddings if enabled ema_step = self.config.soft_embedding_ema_step if ema_step < 1.0: prev_soft_emb = soft_embeds[mask_pos] # Previous iteration's soft embeddings soft_emb = (1.0 - ema_step) * prev_soft_emb + ema_step * soft_emb next_soft_embeds[mask_pos] = soft_emb return logits, next_soft_embeds def _stable_softmax(self, logits: torch.Tensor, T: float = 1.0, dim: int = -1, eps: float = 1e-12) -> torch.Tensor: """Numerically stable softmax with temperature T > 0.""" z = logits / max(T, eps) z = z - z.max(dim=dim, keepdim=True).values # subtract max z = torch.exp(z) # safe since z <= 0 z_sum = z.sum(dim=dim, keepdim=True) return z / z_sum.clamp(min=eps) def normalize(self, logits: torch.Tensor) -> torch.Tensor: """Normalize logits -> mixing weights. Shape: (B, L, V) -> (B, L, V)""" norm = self.config.normalization.lower() T = self.config.temperature V = logits.shape[-1] if norm == "none": return logits if norm == "softmax": return torch.softmax(logits / T, dim=-1) if norm == "stable_softmax": return self._stable_softmax(logits, T=T, dim=-1) raise ValueError(f"Unknown normalization: {norm}") def step_weight(self, t: int, T: int) -> float: """Loss weight for step t of T.""" lw = self.config.loss_weight if lw == "linear": return (t + 1) / T if lw == "uniform": return 1.0 if lw == "last_1": return 1.0 if t == T - 1 else 0.0 if lw == "last_2": return 1.0 if T - t <= 2 else 0.0 raise ValueError(f"Unknown loss_weight: {lw}") # ==================== CONVERGENCE SCHEDULE SYSTEM ==================== # # The core idea: control WHEN each position is allowed to converge. 
# # Schedule types: # - "linear": All positions converge at the same rate # - "causal": Early positions converge first, late positions last # # Effects (mechanisms to enforce the schedule): # - temperature: Raise temperature for positions not yet allowed to converge # - entropy_floor: Force minimum entropy # - entropy_target: Force exact entropy via bisection (ARChitects-style) # - smear: Spread probability across neighboring positions # - noise: Add Gaussian noise to logits # # Each effect uses per-position "convergence progress" (0=uncertain, 1=can converge) def _compute_convergence_progress( self, iteration: int, total_iterations: int, seq_length: int, mask_positions: torch.Tensor, schedule: str = "linear", causal_strength: float = 1.0, device: torch.device = None, dtype: torch.dtype = None, ) -> torch.Tensor: """ Compute per-position convergence progress based on schedule. Args: iteration: Current iteration (0-indexed) total_iterations: Total number of iterations seq_length: Full sequence length L mask_positions: Position indices of masked tokens (num_masked,) schedule: "linear" or "causal" causal_strength: How much faster early positions converge (for causal schedule) Returns: progress: (num_masked,) tensor with values in [0, 1] 0 = position should be maximally uncertain 1 = position is allowed to fully converge """ base_progress = iteration / max(total_iterations - 1, 1) if schedule == "linear": return torch.full( (mask_positions.shape[0],), base_progress, device=device, dtype=dtype ) elif schedule == "causal": position_factor = mask_positions.float() / max(seq_length - 1, 1) effective_progress = base_progress * (1.0 + causal_strength * (1.0 - position_factor)) return effective_progress.clamp(0.0, 1.0) else: raise ValueError(f"Unknown schedule: {schedule}") def _apply_temperature_effect( self, logits: torch.Tensor, progress: torch.Tensor, temperature_max: float, ) -> torch.Tensor: """ Apply per-position temperature scaling based on convergence progress. 
        Low progress = high temperature (uncertain), high progress = temperature 1.0.
        """
        # temperature_max <= 0 disables the effect entirely.
        if temperature_max <= 0:
            return logits
        # Linear ramp: T = 1 + temperature_max at progress 0, T = 1 at progress 1.
        temperature = 1.0 + temperature_max * (1.0 - progress)
        temperature = temperature.unsqueeze(-1)
        return logits / temperature

    def _apply_entropy_floor_effect(
        self,
        probs: torch.Tensor,
        progress: torch.Tensor,
        entropy_floor_max: float,
    ) -> torch.Tensor:
        """
        Ensure minimum entropy based on convergence progress.
        Low progress = high entropy floor, high progress = no floor.

        NOTE: This is a ONE-SIDED constraint (floor only).
        """
        if entropy_floor_max <= 0:
            return probs
        # Floor decays linearly with progress.
        entropy_floor = entropy_floor_max * (1.0 - progress)
        log_probs = torch.log(probs + 1e-10)
        current_entropy = -(probs * log_probs).sum(dim=-1)
        below_floor = current_entropy < entropy_floor
        if not below_floor.any():
            return probs
        # Re-soften only the offending positions by heating their log-probs.
        logits = torch.log(probs + 1e-10)
        # Heuristic temperature: ratio of desired to current entropy, clamped.
        target_ratio = entropy_floor / (current_entropy + 1e-10)
        temperature = torch.ones_like(current_entropy)
        temperature[below_floor] = target_ratio[below_floor].clamp(1.0, 10.0)
        scaled_probs = torch.softmax(logits / temperature.unsqueeze(-1), dim=-1)
        # Keep untouched positions byte-for-byte identical.
        result = probs.clone()
        result[below_floor] = scaled_probs[below_floor]
        return result

    def _find_temperature_for_target_entropy(
        self,
        logits: torch.Tensor,
        target_entropy: torch.Tensor,
        tol: float = 1e-3,
        max_iter: int = 32,
        T_low: float = 1e-6,
        T_high_init: float = 1.0,
        max_T: float = 100.0,
    ) -> torch.Tensor:
        """
        Find per-position temperatures that achieve exactly the target entropy.

        Uses bisection search, adapted from ARChitects' implementation.
        Args:
            logits: Raw logits (num_positions, V)
            target_entropy: Target entropy per position (num_positions,) or scalar
            tol: Entropy tolerance for convergence
            max_iter: Maximum bisection iterations
            T_low: Minimum temperature (near-greedy)
            T_high_init: Initial upper bound for search
            max_T: Maximum allowed temperature

        Returns:
            temperatures: (num_positions,) temperatures that achieve target entropy
        """
        N, V = logits.shape
        device, dtype = logits.device, logits.dtype
        # Entropy is bounded above by log(V) (uniform distribution).
        H_max = torch.log(torch.tensor(V, device=device, dtype=dtype))
        if target_entropy.dim() == 0:
            target = target_entropy.expand(N).clone()
        else:
            target = target_entropy.clone()
        target = target.clamp(0.0, H_max)

        def compute_entropy(logits_: torch.Tensor, temps: torch.Tensor) -> torch.Tensor:
            # Entropy of softmax(logits / T) per row, with max-subtraction for stability.
            temps = temps.unsqueeze(-1).clamp(min=T_low)
            scaled = logits_ / temps
            scaled = scaled - scaled.max(dim=-1, keepdim=True).values
            probs = torch.softmax(scaled, dim=-1)
            log_probs = torch.log(probs + 1e-12)
            return -(probs * log_probs).sum(dim=-1)

        lo = torch.full((N,), T_low, device=device, dtype=dtype)
        hi = torch.full((N,), T_high_init, device=device, dtype=dtype)
        # Positions whose entropy at T_low already meets the target need no search.
        H_lo = compute_entropy(logits, lo)
        done_low = target <= (H_lo + tol)
        # Grow the upper bracket until its entropy exceeds the target (or max_T is hit).
        H_hi = compute_entropy(logits, hi)
        needs_expansion = (H_hi < target - tol) & ~done_low
        for _ in range(100):
            if not needs_expansion.any():
                break
            hi[needs_expansion] = (hi[needs_expansion] * 2.0).clamp(max=max_T)
            H_hi[needs_expansion] = compute_entropy(
                logits[needs_expansion], hi[needs_expansion]
            )
            needs_expansion = (H_hi < target - tol) & ~done_low & (hi < max_T - 1e-6)
        # Bisect only where a valid bracket [lo, hi] exists.
        can_bisect = ~done_low & (H_hi >= target - tol)
        for _ in range(max_iter):
            if not can_bisect.any():
                break
            mid = (lo + hi) / 2.0
            H_mid = compute_entropy(logits, mid)
            # Entropy increases with temperature, so a too-low entropy moves `lo` up.
            too_low = (H_mid < target) & can_bisect
            lo[too_low] = mid[too_low]
            hi[~too_low & can_bisect] = mid[~too_low & can_bisect]
            converged = (hi - lo) <= tol * mid.clamp(min=1.0)
            can_bisect = can_bisect & ~converged
        temps = torch.zeros(N, device=device, dtype=dtype)
        temps[done_low] = T_low
        temps[~done_low] = (lo[~done_low] + hi[~done_low]) / 2.0
        return temps

    def _apply_target_entropy_effect(
        self,
        logits: torch.Tensor,
        progress: torch.Tensor,
        entropy_target_max: float,
        entropy_target_min: float = 0.0,
    ) -> torch.Tensor:
        """
        Adjust temperature to achieve EXACTLY the target entropy per position.
        This is a TWO-SIDED constraint: both raises and lowers entropy as needed.

        Args:
            logits: Raw logits (num_masked, V)
            progress: Per-position convergence progress (num_masked,)
            entropy_target_max: Target entropy at progress=0
            entropy_target_min: Target entropy at progress=1 (usually ~0)

        Returns:
            probs: Probabilities with entropy matching targets
        """
        if entropy_target_max <= 0:
            return torch.softmax(logits, dim=-1)
        # Linear interpolation of the target between the two endpoints.
        target_entropy = entropy_target_max * (1.0 - progress) + entropy_target_min * progress
        temps = self._find_temperature_for_target_entropy(logits, target_entropy)
        temps = temps.unsqueeze(-1).clamp(min=1e-6)
        return torch.softmax(logits / temps, dim=-1)

    def _apply_smear_effect(
        self,
        probs: torch.Tensor,
        mask_pos: torch.Tensor,
        progress_full: torch.Tensor,
        smear_sigma_max: float,
    ) -> torch.Tensor:
        """
        Apply positional smearing with per-position sigma based on progress.
        Low progress = high smearing, high progress = no smearing.

        Note: This operates on full (B, L, V) tensor because smearing mixes
        across positions.
        """
        if smear_sigma_max <= 0:
            return probs
        B, L, V = probs.shape
        sigma_per_pos = smear_sigma_max * (1.0 - progress_full)
        # Single shared sigma (mean over masked positions) keeps the kernel O(L^2), not per-position.
        avg_sigma = sigma_per_pos[mask_pos].mean().item()
        if avg_sigma < 0.1:
            # Effectively no smearing — skip the kernel construction.
            return probs
        # Gaussian kernel over position distance, row-normalized.
        positions = torch.arange(L, device=probs.device, dtype=probs.dtype)
        diff = positions.unsqueeze(0) - positions.unsqueeze(1)
        kernel = torch.exp(-0.5 * (diff / avg_sigma) ** 2)
        kernel = kernel / kernel.sum(dim=1, keepdim=True)
        smeared = torch.einsum('ij,bjv->biv', kernel, probs)
        # Renormalize over vocabulary after mixing across positions.
        smeared = smeared / smeared.sum(dim=-1, keepdim=True).clamp(min=1e-10)
        # blend=progress: converged positions keep their own distribution.
        blend = progress_full.unsqueeze(-1)
        result = blend * probs + (1 - blend) * smeared
        # Only masked positions are replaced; unmasked rows stay untouched.
        output = probs.clone()
        output[mask_pos] = result[mask_pos]
        return output

    def _apply_noise_effect(
        self,
        logits: torch.Tensor,
        progress: torch.Tensor,
        noise_std_max: float,
    ) -> torch.Tensor:
        """
        Add Gaussian noise to logits based on convergence progress.
        Low progress = high noise, high progress = no noise.
        """
        if noise_std_max <= 0:
            return logits
        # Std decays linearly with progress, per position.
        noise_std = noise_std_max * (1.0 - progress)
        noise_std = noise_std.unsqueeze(-1)
        noise = torch.randn_like(logits) * noise_std
        return logits + noise

    def _apply_iteration_rope(
        self,
        embeds: torch.Tensor,
        iteration: int,
        total_iterations: int,
        dim_fraction: float = 0.25,
        base: float = 10000.0,
    ) -> torch.Tensor:
        """
        Apply rotary embedding based on iteration progress.
        Uses a subset of dimensions to avoid interfering with position RoPE.
""" if dim_fraction <= 0: return embeds H = embeds.shape[-1] rot_dim = int(H * dim_fraction) rot_dim = rot_dim - (rot_dim % 2) if rot_dim < 2: return embeds progress = iteration / max(total_iterations - 1, 1) inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, device=embeds.device, dtype=embeds.dtype) / rot_dim)) angles = progress * inv_freq * 3.14159 cos, sin = torch.cos(angles), torch.sin(angles) if embeds.dim() == 2: cos, sin = cos.unsqueeze(0), sin.unsqueeze(0) elif embeds.dim() == 3: cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) embeds_out = embeds.clone() x1, x2 = embeds[..., -rot_dim::2], embeds[..., -rot_dim+1::2] embeds_out[..., -rot_dim::2] = x1 * cos - x2 * sin embeds_out[..., -rot_dim+1::2] = x1 * sin + x2 * cos return embeds_out # ==================== FLOW MATCHING ==================== def _sample_flow_matching_t(self, num_tokens: int, device: torch.device) -> torch.Tensor: """Sample per-token time levels for flow matching. Returns: t: (num_tokens,) tensor of time levels in [t_min, t_max] """ dist = self.config.flow_matching_t_distribution if dist == "logit_normal": z = torch.randn(num_tokens, device=device) z = z * self.config.flow_matching_t_logit_std + self.config.flow_matching_t_logit_mean t = torch.sigmoid(z) elif dist == "uniform": t = torch.empty(num_tokens, device=device).uniform_(0, 1) else: raise ValueError(f"Unknown flow_matching_t_distribution: {dist}") return t.clamp(self.config.flow_matching_t_min, self.config.flow_matching_t_max) def compute_flow_matching_distillation_loss( self, input_ids: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor, flow_noise_embed: torch.Tensor, flow_t: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, ) -> SelfDistillationOutput: """ CFM flow matching distillation: teacher sees state at time t, student sees noisier state at time s < t on the same interpolation path. 
        Both should predict the same endpoint (target token). The student must
        learn to refine from noisier inputs by matching the teacher's predictions.

        Args:
            input_ids: Input with [MASK] tokens at positions to predict
            teacher_logits: Logits from the forward pass (will be detached)
            labels: Target tokens at masked positions (-100 elsewhere)
            flow_noise_embed: (num_masked, H) noise embeddings from forward
            flow_t: (num_masked,) per-token time levels from forward
            attention_mask: Standard attention mask
            position_ids: Position IDs (if needed by base model)

        Returns:
            SelfDistillationOutput with loss, logits, time gap, and diagnostics
        """
        mask_id = self.config.mask_token_id
        mask_pos = (input_ids == mask_id)  # (B, L)
        device = input_ids.device
        num_masked = mask_pos.sum().item()
        if num_masked == 0:
            # No masked tokens: return a zero loss that still participates in autograd.
            zero = torch.tensor(0.0, device=device, requires_grad=True)
            dummy = torch.zeros(1, device=device)
            return SelfDistillationOutput(zero, dummy, dummy, 0.0, 0.0, 0.0, 1.0)
        teacher_logits = teacher_logits.detach()
        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]  # (H,)
        base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)
        # Target embeddings from labels
        target_ids = labels[mask_pos]  # (num_masked,)
        target_embed = embed_weight[target_ids]  # (num_masked, H)
        # Sample student time s ~ U(0, t) per token
        s_per_token = flow_t * torch.rand(num_masked, device=device)  # (num_masked,)
        # Student state: same noise, earlier time (noisier)
        s_col = s_per_token.unsqueeze(-1).to(base_embeds.dtype)  # (num_masked, 1)
        student_interp = (1 - s_col) * flow_noise_embed + s_col * target_embed
        if self.config.flow_matching_mask_scale:
            # Mask embedding fades out as s -> 1.
            student_masked_embeds = student_interp + (1 - s_col) * mask_emb
        else:
            # Binary: mask embedding always added at masked positions.
            student_masked_embeds = student_interp + mask_emb
        # Build full student input (detached — gradient only flows through student's forward)
        student_embeds = base_embeds.detach().clone()
        student_embeds[mask_pos] = student_masked_embeds.detach()
        student_inputs = torch.where(
            mask_pos.unsqueeze(-1),
            student_embeds,
            base_embeds.detach()
        )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)
        student_out = self.mlm(
            inputs_embeds=student_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        student_logits = student_out.logits  # (B, L, V) — has gradient
        # KL divergence loss on masked positions
        t_logits = teacher_logits[mask_pos]  # (num_masked, V)
        s_logits = student_logits[mask_pos]  # (num_masked, V)
        teacher_probs = F.softmax(t_logits, dim=-1)
        student_log_probs = F.log_softmax(s_logits, dim=-1)
        kl_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction="batchmean",
        )
        # Diagnostic metrics
        with torch.no_grad():
            teacher_log_probs = torch.log(teacher_probs + 1e-10)
            teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item()
            student_probs = F.softmax(s_logits.detach(), dim=-1)
            student_log_probs_det = torch.log(student_probs + 1e-10)
            student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item()
            agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item()
        mean_time_gap = (flow_t - s_per_token).mean().item()
        return SelfDistillationOutput(
            loss=kl_loss,
            teacher_logits=teacher_logits,
            student_logits=student_logits,
            # NOTE: this NamedTuple field is reused to carry the mean t-s gap here.
            degradation_temperature=mean_time_gap,
            teacher_entropy=teacher_entropy,
            student_entropy=student_entropy,
            agreement_rate=agreement,
        )

    # ==================== SELF-DISTILLATION (legacy) ====================

    def compute_self_distillation_loss(
        self,
        input_ids: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        temperature_min: Optional[float] = None,
        temperature_max: Optional[float] = None,
        temperature_distribution: Optional[str] = None,
    ) -> SelfDistillationOutput:
        """
        CFM-style self-distillation: model's predictions should be consistent
        across different levels of input degradation.

        Process:
        1.
        Take teacher logits (from standard forward pass, DETACHED)
        2. Degrade: per-token random temperature → softer soft embeddings
        3. Student: forward pass from degraded embeddings → logits (has grad)
        4. Loss: KL(teacher || student) on masked positions

        Each masked token gets its own independently sampled degradation
        temperature, creating varied difficulty across the sequence.

        Args:
            input_ids: Input with [MASK] tokens at positions to predict
            teacher_logits: Pre-computed teacher logits (will be detached).
                Typically outputs.all_logits[0] or outputs.logits from standard forward.
            attention_mask: Standard attention mask
            position_ids: Position IDs (if needed by base model)
            temperature_min: Min degradation temperature (default: config value)
            temperature_max: Max degradation temperature (default: config value)
            temperature_distribution: How to sample T (default: config value)

        Returns:
            SelfDistillationOutput with loss, logits, temperature, and diagnostics
        """
        # Resolve defaults from config
        temperature_min = temperature_min if temperature_min is not None else self.config.self_distillation_temperature_min
        temperature_max = temperature_max if temperature_max is not None else self.config.self_distillation_temperature_max
        temperature_distribution = temperature_distribution if temperature_distribution is not None else self.config.self_distillation_temperature_distribution
        mask_id = self.config.mask_token_id
        mask_pos = (input_ids == mask_id)  # (B, L)
        device = input_ids.device
        num_masked = mask_pos.sum().item()
        # Handle degenerate case: no masked positions
        if num_masked == 0:
            zero = torch.tensor(0.0, device=device, requires_grad=True)
            dummy = torch.zeros(1, device=device)
            return SelfDistillationOutput(zero, dummy, dummy, 1.0, 0.0, 0.0, 1.0)
        # Ensure teacher logits are detached
        teacher_logits = teacher_logits.detach()
        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]  # (H,)
        base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)
        # ===== STEP 1: Sample per-token degradation temperatures =====
        # Each masked position gets its own temperature independently
        if temperature_distribution == "log_uniform":
            # Uniform in log-space, then exponentiate.
            log_min = torch.tensor(temperature_min, device=device).log()
            log_max = torch.tensor(temperature_max, device=device).log()
            log_T = torch.empty(num_masked, device=device).uniform_(log_min.item(), log_max.item())
            T_per_token = log_T.exp()  # (num_masked,)
        elif temperature_distribution == "uniform":
            T_per_token = torch.empty(num_masked, device=device).uniform_(
                temperature_min, temperature_max
            )  # (num_masked,)
        else:
            raise ValueError(f"Unknown temperature distribution: {temperature_distribution}")
        T_mean = T_per_token.mean().item()
        # ===== STEP 2: Create degraded soft embeddings =====
        # Per-token temperature scaling: each position gets its own T
        masked_teacher_logits = teacher_logits[mask_pos]  # (num_masked, V)
        degraded_probs = F.softmax(masked_teacher_logits / T_per_token.unsqueeze(-1), dim=-1).to(embed_weight.dtype)
        degraded_soft = degraded_probs @ embed_weight + mask_emb
        degraded_soft_embeds = base_embeds.clone()
        degraded_soft_embeds[mask_pos] = degraded_soft
        # Detach: gradient must flow only through the student's forward pass.
        degraded_soft_embeds = degraded_soft_embeds.detach()
        # ===== STEP 3: Student forward from degraded input =====
        student_inputs = torch.where(
            mask_pos.unsqueeze(-1),
            degraded_soft_embeds,
            base_embeds.detach()
        )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)
        student_out = self.mlm(
            inputs_embeds=student_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        student_logits = student_out.logits  # (B, L, V) — has gradient!
# ===== STEP 4: KL divergence loss on masked positions ===== t_logits = teacher_logits[mask_pos] # (num_masked, V) s_logits = student_logits[mask_pos] # (num_masked, V) teacher_probs = F.softmax(t_logits, dim=-1) student_log_probs = F.log_softmax(s_logits, dim=-1) # KL(teacher || student) = sum teacher * (log_teacher - log_student) kl_loss = F.kl_div( student_log_probs, teacher_probs, reduction="batchmean", ) # ===== STEP 5: Compute diagnostic metrics ===== with torch.no_grad(): teacher_log_probs = torch.log(teacher_probs + 1e-10) teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item() student_probs = F.softmax(s_logits.detach(), dim=-1) student_log_probs_det = torch.log(student_probs + 1e-10) student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item() agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item() return SelfDistillationOutput( loss=kl_loss, teacher_logits=teacher_logits, student_logits=student_logits, degradation_temperature=T_mean, teacher_entropy=teacher_entropy, student_entropy=student_entropy, agreement_rate=agreement, ) # ==================== MAIN SOFT EMBEDDING COMPUTATION ==================== @torch.no_grad() def _compute_next_soft_embeds( self, logits: torch.Tensor, mask_pos: torch.Tensor, base_embeds: torch.Tensor, prev_soft_embeds: Optional[torch.Tensor] = None, iteration: int = 0, total_iterations: int = 1, # === Schedule parameters (default to config values) === schedule: Optional[str] = None, causal_strength: Optional[float] = None, # === Effect parameters (default to config values) === temperature_max: Optional[float] = None, entropy_target_max: Optional[float] = None, entropy_floor_max: Optional[float] = None, smear_sigma_max: Optional[float] = None, noise_std_max: Optional[float] = None, iteration_rope_dim_fraction: Optional[float] = None, ) -> torch.Tensor: """ Compute soft embeddings from logits for the next iteration. 
        This function implements a unified "convergence schedule" system that
        controls when each position is allowed to converge to a confident prediction.

        Schedule Types:
            "linear": All positions converge at the same rate (iteration-based only)
            "causal": Early positions converge first, late positions last

        Effects (mechanisms to enforce the schedule):
            temperature_max: High temperature = more uniform distribution (one-sided)
            entropy_target_max: Force EXACT entropy via bisection search (two-sided, recommended)
            entropy_floor_max: Force MINIMUM entropy (one-sided, only prevents too confident)
            smear_sigma_max: Spread probability across neighboring positions
            noise_std_max: Add Gaussian noise to logits

        All parameters default to their config values if not specified.

        Args:
            logits: Output logits from current iteration (B, L, V)
            mask_pos: Boolean mask indicating which positions are masked (B, L)
            base_embeds: Base token embeddings for non-masked positions (B, L, H)
            iteration: Current iteration index (0-indexed)
            total_iterations: Total number of iterations

        Returns:
            Soft embeddings for next iteration (B, L, H)
        """
        # Use config values as defaults
        schedule = schedule if schedule is not None else self.config.schedule
        causal_strength = causal_strength if causal_strength is not None else self.config.causal_strength
        temperature_max = temperature_max if temperature_max is not None else self.config.temperature_max
        entropy_target_max = entropy_target_max if entropy_target_max is not None else self.config.entropy_target_max
        entropy_floor_max = entropy_floor_max if entropy_floor_max is not None else self.config.entropy_floor_max
        smear_sigma_max = smear_sigma_max if smear_sigma_max is not None else self.config.smear_sigma_max
        noise_std_max = noise_std_max if noise_std_max is not None else self.config.noise_std_max
        iteration_rope_dim_fraction = iteration_rope_dim_fraction if iteration_rope_dim_fraction is not None else self.config.iteration_rope_dim_fraction

        soft_embeds = base_embeds.clone()
        if not mask_pos.any():
            # Nothing to refine; return base embeddings unchanged
            return soft_embeds.detach()

        B, L, V = logits.shape
        device, dtype = logits.device, logits.dtype

        # Check if any effects are enabled
        has_effects = (
            temperature_max > 0
            or entropy_target_max > 0
            or entropy_floor_max > 0
            or smear_sigma_max > 0
            or noise_std_max > 0
            or iteration_rope_dim_fraction > 0
        )

        if not has_effects:
            # Simple path: no convergence schedule effects
            masked_logits = logits[mask_pos]
            embed_weight = self.embed_weight
            # Convert logits to mixing weights based on soft_embedding_method
            if self.config.soft_embedding_method == "none":
                weights = masked_logits
            elif self.config.soft_embedding_method == "l2_normalize":
                weights = F.normalize(masked_logits, p=2, dim=-1)
            else:
                weights = self.normalize(masked_logits)
            masked_soft = weights @ embed_weight
            mask_emb = embed_weight[self.config.mask_token_id]
            masked_soft = masked_soft + mask_emb
            # Apply EMA blending with previous soft embeddings if enabled
            ema_step = self.config.soft_embedding_ema_step
            if ema_step < 1.0 and prev_soft_embeds is not None:
                prev_masked_soft = prev_soft_embeds[mask_pos]
                masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft
            soft_embeds[mask_pos] = masked_soft
            return soft_embeds.detach()

        # ========== STEP 1: Compute per-position convergence progress ==========
        batch_indices, position_indices = torch.where(mask_pos)
        progress = self._compute_convergence_progress(
            iteration=iteration,
            total_iterations=total_iterations,
            seq_length=L,
            mask_positions=position_indices,
            schedule=schedule,
            causal_strength=causal_strength,
            device=device,
            dtype=dtype,
        )

        # Compute full (B, L) progress for smearing if needed
        if smear_sigma_max > 0:
            all_positions = torch.arange(L, device=device, dtype=dtype)
            progress_full = self._compute_convergence_progress(
                iteration=iteration,
                total_iterations=total_iterations,
                seq_length=L,
                mask_positions=all_positions,
                schedule=schedule,
                causal_strength=causal_strength,
                device=device,
                dtype=dtype,
            )
            progress_full = progress_full.unsqueeze(0).expand(B, -1)

        # ========== STEP 2: Apply smearing (needs full tensor) ==========
        full_probs = self.normalize(logits)
        if smear_sigma_max > 0:
            full_probs = self._apply_smear_effect(
                full_probs, mask_pos, progress_full, smear_sigma_max
            )

        # ========== STEP 3: Extract masked positions ==========
        masked_logits = logits[mask_pos]
        masked_probs = full_probs[mask_pos]

        # ========== STEP 4: Apply temperature effect (on logits) ==========
        # Skipped when entropy_target_max is active: the target-entropy effect
        # supersedes plain temperature scaling (see STEP 6).
        if temperature_max > 0 and entropy_target_max <= 0:
            masked_logits = self._apply_temperature_effect(
                masked_logits, progress, temperature_max
            )
            masked_probs = torch.softmax(masked_logits, dim=-1)

        # ========== STEP 5: Apply noise effect (on logits) ==========
        if noise_std_max > 0:
            masked_logits_noisy = self._apply_noise_effect(
                torch.log(masked_probs + 1e-10), progress, noise_std_max
            )
            masked_probs = torch.softmax(masked_logits_noisy, dim=-1)

        # ========== STEP 6: Apply entropy control ==========
        if entropy_target_max > 0:
            masked_probs = self._apply_target_entropy_effect(
                masked_logits, progress, entropy_target_max
            )
        elif entropy_floor_max > 0:
            masked_probs = self._apply_entropy_floor_effect(
                masked_probs, progress, entropy_floor_max
            )

        # ========== STEP 7: Compute soft embeddings ==========
        embed_weight = self.embed_weight
        # Convert to mixing weights based on soft_embedding_method
        if self.config.soft_embedding_method == "none":
            # No normalization - use raw logits directly
            weights = masked_logits
        elif self.config.soft_embedding_method == "l2_normalize":
            # L2 normalize bypasses all the softmax-based effects above
            weights = F.normalize(masked_logits, p=2, dim=-1)
        else:
            weights = masked_probs
        masked_soft = weights @ embed_weight
        mask_emb = embed_weight[self.config.mask_token_id]
        masked_soft = masked_soft + mask_emb

        # ========== STEP 8: Apply iteration RoPE ==========
        if iteration_rope_dim_fraction > 0:
            masked_soft = self._apply_iteration_rope(
                masked_soft, iteration, total_iterations, iteration_rope_dim_fraction
            )

        # ========== STEP 8.5: Apply EMA blending ==========
        ema_step = self.config.soft_embedding_ema_step
        if ema_step < 1.0 and prev_soft_embeds is not None:
            prev_masked_soft = prev_soft_embeds[mask_pos]
            masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft

        # ========== STEP 9: Place back and return ==========
        soft_embeds[mask_pos] = masked_soft
        return soft_embeds.detach()

    @torch.no_grad()
    def _compute_iteration_metrics(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> IterationMetrics:
        """
        Compute token-level AND sequence-level metrics for a single iteration.
        Returns scalars only - no large tensor storage.

        Token-level metrics:
            - accuracy: fraction of correct token predictions
            - entropy: average entropy per token
            - softmax_ce: cross-entropy loss per token

        Sequence-level metrics:
            - full_sequence_accuracy: fraction of sequences where ALL tokens are correct
            - min_sequence_confidence: mean of minimum top-1 confidence per sequence
        """
        B = logits.shape[0]
        # Move to CPU to avoid GPU OOM - metrics are for monitoring only
        logits = logits.detach().cpu().float()  # float32 is sufficient for metrics
        target_labels = labels.detach().cpu().contiguous()
        mask = target_labels != -100  # -100 = ignored (non-target) positions
        if mask.sum() == 0:
            return IterationMetrics(
                accuracy=0.0,
                entropy=0.0,
                softmax_ce=0.0,
                full_sequence_accuracy=0.0,
                min_sequence_confidence=0.0,
            )
        logits = logits.contiguous()
        predictions = logits.argmax(dim=-1)
        correct = (predictions == target_labels) & mask

        # ===== TOKEN-LEVEL METRICS =====
        # Token accuracy
        accuracy = (correct.sum() / mask.sum()).item()
        # Extract valid tokens for entropy/CE
        valid_logits = logits[mask]
        valid_labels = target_labels[mask]
        # Entropy (using log_softmax for numerical stability)
        log_probs = torch.nn.functional.log_softmax(valid_logits, dim=-1)
        probs = torch.exp(log_probs)
        entropy = -(probs * log_probs).sum(dim=-1).mean().item()
        # Cross-entropy
        softmax_ce = torch.nn.functional.cross_entropy(
            valid_logits, valid_labels, reduction="mean"
        ).item()

        # ===== SEQUENCE-LEVEL METRICS =====
        # Check which
        # sequences have valid tokens
        sequences_with_tokens = mask.any(dim=1)  # (B,)
        num_valid_sequences = sequences_with_tokens.sum().item()
        if num_valid_sequences == 0:
            return IterationMetrics(
                accuracy=accuracy,
                entropy=entropy,
                softmax_ce=softmax_ce,
                full_sequence_accuracy=0.0,
                min_sequence_confidence=0.0,
            )
        # Full sequence accuracy: all tokens in sequence must be correct
        num_correct_per_seq = correct.sum(dim=1)  # (B,)
        num_tokens_per_seq = mask.sum(dim=1)  # (B,)
        all_correct = (num_correct_per_seq == num_tokens_per_seq) & sequences_with_tokens
        full_seq_accuracy = (all_correct.sum() / num_valid_sequences).item()
        # Min sequence confidence: minimum top-1 probability within each sequence
        probs_full = torch.softmax(logits, dim=-1)  # (B, L, V) - already float32
        top1_confidence = probs_full.max(dim=-1).values  # (B, L)
        min_confidences = []
        for i in range(B):
            if sequences_with_tokens[i]:
                seq_confidences = top1_confidence[i][mask[i]]  # (num_tokens_in_seq,)
                min_confidences.append(seq_confidences.min().item())
        min_seq_conf = sum(min_confidences) / len(min_confidences) if min_confidences else 0.0
        return IterationMetrics(
            accuracy=accuracy,
            entropy=entropy,
            softmax_ce=softmax_ce,
            full_sequence_accuracy=full_seq_accuracy,
            min_sequence_confidence=min_seq_conf,
        )

    def _single_iteration(
        self,
        t: int,
        T: int,
        soft_embeds: torch.Tensor,
        base_embeds: torch.Tensor,
        mask_pos: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        labels: Optional[torch.Tensor],
        compute_metrics: bool,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[IterationMetrics]]:
        """
        Execute a single iteration of recursive refinement.

        Args:
            t: Current iteration index (0 to T-1)
            T: Total number of iterations
            soft_embeds: Soft embeddings for mask positions
            base_embeds: Base token embeddings from input_ids
            mask_pos: Boolean mask of [MASK] positions (B, L)
            attention_mask: Attention mask for MLM
            labels: Target labels for loss computation
            compute_metrics: Whether to compute iteration metrics

        Returns:
            logits: Output logits from MLM (B, L, V)
            weighted_loss: Loss weighted by step_weight(t, T), or None if no labels
            metrics: IterationMetrics, or None if not requested
        """
        # Blend soft embeddings (at mask positions) with base embeddings (at non-mask positions)
        inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)
        # Forward through base MLM
        outputs = self.mlm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
            return_dict=True,
            **kwargs,
        )
        # Compute weighted loss for this iteration
        weighted_loss = outputs.loss
        if labels is not None:
            if weighted_loss is None:
                # Base model doesn't compute loss (e.g., LLaDA) - compute it ourselves
                # Only compute loss on MASKED positions (MDLM training)
                masked_logits = outputs.logits[mask_pos]  # (num_masked, V)
                masked_labels = labels[mask_pos]  # (num_masked,)
                loss_fct = CrossEntropyLoss()  # -100 index = padding token
                weighted_loss = loss_fct(masked_logits, masked_labels)
            weighted_loss *= self.step_weight(t, T)
        # Compute iteration metrics if requested
        metrics = None
        if compute_metrics and labels is not None:
            metrics = self._compute_iteration_metrics(outputs.logits, labels)
        return outputs.logits, weighted_loss, metrics

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        num_recursions: Optional[int] = None,
        compute_iteration_metrics: bool = False,
        use_recursion_checkpointing: Optional[bool] = None,
        # Parameters for single-iteration training mode (DEPRECATED)
        prev_soft_embeds:
            Optional[torch.Tensor] = None,
        run_set_iteration: Optional[int] = None,
        # === Convergence schedule parameters (None = use config defaults) ===
        schedule: Optional[str] = None,
        causal_strength: Optional[float] = None,
        # === Effect parameters (None = use config defaults) ===
        temperature_max: Optional[float] = None,
        entropy_target_max: Optional[float] = None,
        entropy_floor_max: Optional[float] = None,
        smear_sigma_max: Optional[float] = None,
        noise_std_max: Optional[float] = None,
        iteration_rope_dim_fraction: Optional[float] = None,
        **kwargs,
    ) -> RecursiveMaskedLMOutput:
        """
        Forward with recursive refinement.

        Supports three modes:
        1. Checkpointed mode (default): Run all T recursions with gradient checkpointing.
           Gradients flow through the entire chain; activations recomputed during backward.
        2. Non-checkpointed mode (use_recursion_checkpointing=False): Store all activations.
           Faster backward but higher memory.
        3. Single-iteration mode (DEPRECATED - run_set_iteration is not None):
           Run only one iteration. Use use_recursion_checkpointing=True instead.

        Loss Weighting (config.loss_weight):
            "last_1": Only final iteration loss (enables learning convergence behavior)
            "last_2": Last 2 iterations
            "linear": All iterations, linearly weighted (default)
            "uniform": All iterations, uniformly weighted

        Recursion Checkpointing:
            use_recursion_checkpointing: Enable gradient checkpointing for iterations.
                True = checkpoint each iteration, recompute during backward (default).
                False = store all activations (higher memory, faster backward).

        Convergence Schedule Parameters:
            All schedule/effect parameters default to their config values if not specified.
            Pass explicit values to override config for this forward pass.

            schedule: "linear" or "causal" - controls when positions can converge
            causal_strength: How much faster early positions converge (causal only)
            temperature_max: Max temperature boost for uncertain positions
            entropy_target_max: Target entropy at progress=0 (two-sided, recommended)
            entropy_floor_max: Min entropy floor (one-sided)
            smear_sigma_max: Max Gaussian sigma for position smearing
            noise_std_max: Max std of Gaussian noise on logits
            iteration_rope_dim_fraction: Fraction of dims for iteration RoPE
        """
        B, L = input_ids.shape
        V = self.embed_weight.shape[0]
        mask_id = self.config.mask_token_id
        if mask_id is None:
            raise ValueError("mask_token_id must be set")
        # Resolve config default for recursion checkpointing
        use_recursion_checkpointing = (
            use_recursion_checkpointing
            if use_recursion_checkpointing is not None
            else self.config.use_recursion_checkpointing
        )
        mask_pos = (input_ids == mask_id)  # (B, L)
        base_embeds = self.get_input_embeddings()(input_ids)  # (B, L, H)
        T = num_recursions or self.config.num_recursions
        weight_sum = sum(self.step_weight(i, T) for i in range(T))
        # Bundle schedule kwargs to pass to _compute_next_soft_embeds
        schedule_kwargs = dict(
            schedule=schedule,
            causal_strength=causal_strength,
            temperature_max=temperature_max,
            entropy_target_max=entropy_target_max,
            entropy_floor_max=entropy_floor_max,
            smear_sigma_max=smear_sigma_max,
            noise_std_max=noise_std_max,
            iteration_rope_dim_fraction=iteration_rope_dim_fraction,
        )

        # ===== SINGLE ITERATION MODE (DEPRECATED) =====
        if run_set_iteration is not None:
            warnings.warn(
                "run_set_iteration is deprecated. Use use_recursion_checkpointing=True instead, "
                "which provides proper gradient flow through all iterations.",
                DeprecationWarning,
                stacklevel=2,
            )
            t = run_set_iteration
            # Get soft embeddings for this iteration
            if t == 0:
                # t=0: Uniform prior = average embedding (equivalent to softmax(zeros) @ embed_weight)
                # We compute this efficiently via embed_weight.mean() rather than creating large zero tensors
                soft_embeds = base_embeds.clone()
                if mask_pos.any():
                    avg_embed = self.embed_weight.mean(dim=0)  # (H,) - mean over all V tokens
                    mask_emb = self.embed_weight[mask_id]
                    soft_embeds[mask_pos] = avg_embed + mask_emb
            else:
                if prev_soft_embeds is None:
                    raise ValueError(f"prev_soft_embeds must be provided for iteration {t}")
                soft_embeds = prev_soft_embeds
            logits, weighted_loss, metrics = self._single_iteration(
                t, T, soft_embeds, base_embeds, mask_pos, attention_mask, labels,
                compute_iteration_metrics, position_ids=position_ids, **kwargs
            )
            # Normalize loss by total weight sum
            loss = weighted_loss / weight_sum if weighted_loss is not None else None
            # Compute soft embeddings for next iteration (if not last)
            next_soft_embeds = None
            if t < T - 1:
                next_soft_embeds = self._compute_next_soft_embeds(
                    logits, mask_pos, base_embeds,
                    iteration=t, total_iterations=T,
                    **schedule_kwargs,
                )
            return RecursiveMaskedLMOutput(
                loss=loss,
                logits=logits,
                next_soft_embeds=next_soft_embeds,
                iteration_metrics={t: metrics} if metrics is not None else None,
            )

        # ===== CHECKPOINTED MODE (gradient flow through all iterations) =====
        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]  # (H,)
        # Temperature must be a tensor for checkpointing (checkpoint requires tensor inputs)
        temperature = torch.tensor(
            self.config.temperature,
            device=input_ids.device,
            dtype=base_embeds.dtype,
        )
        # Ensure attention_mask is a tensor (required for checkpointing)
        if attention_mask is None:
            attention_mask = torch.ones(B, L, device=input_ids.device, dtype=base_embeds.dtype)
        # Initialize soft embeddings for masked positions
        soft_embeds = base_embeds.clone()
        flow_noise_embed = None
        flow_t_per_token = None
        if self.config.flow_matching_enabled and self.training and labels is not None and mask_pos.any():
            # Flow matching: interpolate between random noise and target on the simplex
            num_masked = mask_pos.sum().item()
            V = embed_weight.shape[0]
            device = input_ids.device
            # Sample per-token time levels (logit-normal by default)
            flow_t_per_token = self._sample_flow_matching_t(num_masked, device)
            # Random noise embedding: sample on simplex, project to H-dim
            z = torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype)
            p_noise = F.softmax(z * self.config.flow_matching_noise_scale, dim=-1).to(base_embeds.dtype)
            flow_noise_embed = p_noise @ embed_weight  # (num_masked, H) — compact
            # Target embedding from labels
            target_ids = labels[mask_pos]  # original token IDs at masked positions
            target_embed = embed_weight[target_ids]  # (num_masked, H)
            # Interpolate in embedding space
            t_col = flow_t_per_token.unsqueeze(-1).to(base_embeds.dtype)  # (num_masked, 1)
            interp_embed = (1 - t_col) * flow_noise_embed + t_col * target_embed
            # Add mask signal (binary or scaled)
            if self.config.flow_matching_mask_scale:
                soft_embeds[mask_pos] = interp_embed + (1 - t_col) * mask_emb
            else:
                soft_embeds[mask_pos] = interp_embed + mask_emb
        elif mask_pos.any():
            # Standard uniform prior (average embedding + mask signal)
            avg_embed = embed_weight.mean(dim=0)  # (H,)
            soft_embeds[mask_pos] = avg_embed + mask_emb

        iteration_metrics = {} if compute_iteration_metrics and labels is not None else None
        # Main recursion loop with optional checkpointing
        all_logits = []
        for t in range(T):
            if self.training and use_recursion_checkpointing:
                # Use checkpointing: activations recomputed during backward
                # This maintains gradient flow while saving memory
                logits, soft_embeds = torch_checkpoint(
                    self._single_iteration_checkpointable,
                    soft_embeds,
                    base_embeds,
                    mask_pos,
                    attention_mask,
                    embed_weight,
                    mask_emb,
                    temperature,
                    position_ids,
                    use_reentrant=False,  # Critical for nested checkpointing!
                )
            else:
                # No checkpointing: store all activations (inference or explicit disable)
                logits, soft_embeds = self._single_iteration_checkpointable(
                    soft_embeds,
                    base_embeds,
                    mask_pos,
                    attention_mask,
                    embed_weight,
                    mask_emb,
                    temperature,
                    position_ids,
                )
            all_logits.append(logits)
            # Compute iteration metrics if requested (no grad needed)
            if iteration_metrics is not None and labels is not None:
                with torch.no_grad():
                    iteration_metrics[t] = self._compute_iteration_metrics(logits, labels)

        # Return all logits for trainer to compute loss with proper normalization
        # Trainer handles: timestep-based weighting, iteration weighting, batch/sequence/token normalization
        return RecursiveMaskedLMOutput(
            loss=None,  # Let trainer compute loss
            logits=logits,  # Final logits for inference/metrics
            all_logits=all_logits if self.training else None,  # Only needed during training
            iteration_metrics=iteration_metrics or None,
            flow_noise_embed=flow_noise_embed,  # For flow matching distillation
            flow_t=flow_t_per_token,  # For flow matching distillation
        )

    @torch.no_grad()
    def _generate_flow_map(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor],
        num_steps: int,
    ) -> torch.Tensor:
        """Fill in mask positions using the CFM flow map update rule.

        Starts from a random point on the probability simplex and iteratively
        moves toward the model's predictions using the flow map step rule.
        Args:
            input_ids: Input with [MASK] tokens at positions to fill
            attention_mask: Attention mask
            position_ids: Position IDs
            num_steps: Number of flow map steps (finer = better, 1 step = greedy)

        Returns:
            Tensor with [MASK] positions filled with predicted tokens
        """
        mask_pos = (input_ids == self.config.mask_token_id)
        num_masked = mask_pos.sum().item()
        if num_masked == 0:
            # Nothing to fill; return a copy so callers can mutate safely
            return input_ids.clone()
        device = input_ids.device
        V = self.embed_weight.shape[0]
        embed_weight = self.embed_weight
        mask_emb = embed_weight[self.config.mask_token_id]
        base_embeds = self.get_input_embeddings()(input_ids)
        # Start from random simplex point
        noise_scale = self.config.flow_matching_noise_scale
        p = F.softmax(torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype) * noise_scale, dim=-1).to(base_embeds.dtype)
        times = torch.linspace(0, 1, num_steps + 1, device=device)
        for i in range(num_steps):
            t_now = times[i]
            t_next = times[i + 1]
            # Flow map step size; t_now < 1 for all loop iterations, so no div-by-zero
            step_size = (t_next - t_now) / (1 - t_now)
            # Mask signal (binary or scaled)
            if self.config.flow_matching_mask_scale:
                mask_signal = (1 - t_now) * mask_emb
            else:
                mask_signal = mask_emb
            # Project current state to embedding space
            embed = p @ embed_weight + mask_signal
            soft_embeds = base_embeds.clone()
            soft_embeds[mask_pos] = embed
            inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)
            outputs = self.mlm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True,
            )
            pi = F.softmax(outputs.logits[mask_pos], dim=-1).to(p.dtype)
            # Flow map update: move toward model's prediction
            p = p + step_size * (pi - p)
            # Fix floating point drift off the simplex
            p = p.clamp(min=0)
            p = p / p.sum(dim=-1, keepdim=True)
        result = input_ids.clone()
        result[mask_pos] = p.argmax(dim=-1)
        return result

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        num_recursions: Optional[int] = None,
        # === Convergence schedule parameters (None = use config defaults) ===
        schedule: Optional[str] = None,
        causal_strength: Optional[float] = None,
        # === Effect parameters (None = use config defaults) ===
        temperature_max: Optional[float] = None,
        entropy_target_max: Optional[float] = None,
        entropy_floor_max: Optional[float] = None,
        smear_sigma_max: Optional[float] = None,
        noise_std_max: Optional[float] = None,
        iteration_rope_dim_fraction: Optional[float] = None,
    ) -> torch.Tensor:
        """Fill in mask positions via iterative refinement.

        When flow_matching_enabled, uses the CFM flow map update rule.
        Otherwise, uses standard recursive soft-token refinement.

        Args:
            input_ids: Input token IDs with [MASK] tokens at positions to fill
            attention_mask: Attention mask
            position_ids: Position IDs (if needed by base model)
            num_recursions: Override number of recursions/steps (default: config value)
            schedule: "linear" or "causal" convergence schedule
            causal_strength: How much faster early positions converge (causal only)
            temperature_max: Max temperature boost for uncertain positions
            entropy_target_max: Target entropy at progress=0 (two-sided)
            entropy_floor_max: Min entropy floor (one-sided)
            smear_sigma_max: Max Gaussian sigma for position smearing
            noise_std_max: Max std of Gaussian noise on logits
            iteration_rope_dim_fraction: Fraction of dims for iteration RoPE

        Returns:
            Tensor with [MASK] positions filled with predicted tokens
        """
        num_steps = num_recursions or self.config.num_recursions
        if self.config.flow_matching_enabled:
            return self._generate_flow_map(
                input_ids, attention_mask, position_ids, num_steps
            )
        out = self.forward(
            input_ids,
            attention_mask,
            position_ids=position_ids,
            num_recursions=num_steps,
            schedule=schedule,
            causal_strength=causal_strength,
            temperature_max=temperature_max,
            entropy_target_max=entropy_target_max,
            entropy_floor_max=entropy_floor_max,
            smear_sigma_max=smear_sigma_max,
            noise_std_max=noise_std_max,
            iteration_rope_dim_fraction=iteration_rope_dim_fraction,
        )
        result = input_ids.clone()
        mask_pos = (input_ids == self.config.mask_token_id)
        result[mask_pos] = out.logits.argmax(dim=-1)[mask_pos]
        return result