| | from __future__ import annotations |
| | import warnings |
| | from dataclasses import dataclass |
| | from typing import NamedTuple, Optional |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch.nn import CrossEntropyLoss |
| | from torch.utils.checkpoint import checkpoint as torch_checkpoint |
| | from transformers import AutoConfig, AutoModelForMaskedLM, PreTrainedModel |
| | from transformers.modeling_outputs import MaskedLMOutput |
| | from transformers.utils import ModelOutput |
| |
|
| | from .configuration_recursive import RecursiveMLMConfig |
| |
|
| |
|
@dataclass
class IterationMetrics(ModelOutput):
    """Metrics for a single iteration of recursive refinement.

    All fields are optional scalar diagnostics. They are populated by code
    outside this block, so exact semantics are inferred from field names —
    verify against the producing loop.
    """
    # Token-level accuracy for this iteration (presumably over masked positions).
    accuracy: Optional[float] = None
    # Mean predictive entropy of the iteration's distribution.
    entropy: Optional[float] = None
    # Cross-entropy of the softmax predictions.
    softmax_ce: Optional[float] = None
    # Fraction of sequences predicted entirely correctly.
    full_sequence_accuracy: Optional[float] = None
    # Minimum per-sequence prediction confidence.
    min_sequence_confidence: Optional[float] = None
| |
|
| |
|
@dataclass
class RecursiveMaskedLMOutput(MaskedLMOutput):
    """MaskedLMOutput extended with recursive-refinement extras.

    Inherits loss/logits/hidden_states/attentions from MaskedLMOutput.
    """
    # Per-iteration diagnostics, keyed by iteration index.
    iteration_metrics: Optional[dict[int, IterationMetrics]] = None
    # Soft embeddings that would seed the next refinement iteration.
    next_soft_embeds: Optional[torch.Tensor] = None
    # Logits from every iteration, in iteration order.
    all_logits: Optional[list[torch.Tensor]] = None

    # Flow-matching bookkeeping from the forward pass: the noise embeddings
    # and per-token time levels later consumed by
    # compute_flow_matching_distillation_loss.
    flow_noise_embed: Optional[torch.Tensor] = None
    flow_t: Optional[torch.Tensor] = None
| |
|
| |
|
class SelfDistillationOutput(NamedTuple):
    """Output from self-distillation forward pass."""
    # Differentiable distillation loss (KL divergence).
    loss: torch.Tensor
    # Detached teacher logits used as the distillation target.
    teacher_logits: torch.Tensor
    # Student logits (gradient-carrying).
    student_logits: torch.Tensor
    # Degradation strength: mean temperature, or mean time gap when produced
    # by the flow-matching path.
    degradation_temperature: float
    # Mean teacher predictive entropy (diagnostic).
    teacher_entropy: float
    # Mean student predictive entropy (diagnostic).
    student_entropy: float
    # Fraction of positions where teacher/student argmax agree.
    agreement_rate: float
| |
|
| |
|
class RecursiveMaskedLM(PreTrainedModel):
    """
    Wraps any HF MLM with recursive soft-token refinement.

    At each step:
    1. Normalize logits -> probs
    2. Compute soft embeddings: probs @ embedding_weight + mask_embedding
    3. Forward through MLM
    4. Accumulate weighted loss
    """
    config_class = RecursiveMLMConfig
    # HF convention: attribute name under which the wrapped base model lives.
    base_model_prefix = "mlm"
    supports_gradient_checkpointing = True
| |
|
| | def __init__(self, config: RecursiveMLMConfig, base_model: Optional[PreTrainedModel] = None): |
| | super().__init__(config) |
| |
|
| | if base_model is not None: |
| | |
| | |
| | self.mlm = base_model |
| | elif config.base_model_config is not None: |
| | model_type = config.base_model_config.get("model_type", "") |
| | if model_type == "llada": |
| | from .configuration_llada import LLaDAConfig |
| | from .modeling_llada import LLaDAModelLM |
| | base_config = LLaDAConfig.from_dict(config.base_model_config) |
| | self.mlm = LLaDAModelLM(base_config) |
| | else: |
| | base_config = AutoConfig.for_model(**config.base_model_config) |
| | self.mlm = AutoModelForMaskedLM.from_config(base_config) |
| | |
| | self.post_init() |
| | else: |
| | raise ValueError("Need either base_model or config.base_model_config") |
| |
|
| | @classmethod |
| | def from_mlm_pretrained( |
| | cls, |
| | mlm_name_or_path: str, |
| | num_recursions: int = 8, |
| | normalization: str = "softmax", |
| | loss_weight: str = "linear", |
| | mask_token_id: Optional[int] = None, |
| | temperature: float = 1.0, |
| | gradient_steps: Optional[int] = None, |
| | |
| | schedule: str = "linear", |
| | causal_strength: float = 1.0, |
| | |
| | temperature_max: float = 0.0, |
| | entropy_target_max: float = 0.0, |
| | entropy_floor_max: float = 0.0, |
| | smear_sigma_max: float = 0.0, |
| | noise_std_max: float = 0.0, |
| | iteration_rope_dim_fraction: float = 0.0, |
| | use_recursion_checkpointing: bool = True, |
| | |
| | soft_embedding_method: str = "softmax", |
| | soft_embedding_ema_step: float = 1.0, |
| | |
| | flow_matching_enabled: bool = False, |
| | flow_matching_lambda: float = 0.5, |
| | flow_matching_t_distribution: str = "logit_normal", |
| | flow_matching_t_logit_mean: float = -0.4, |
| | flow_matching_t_logit_std: float = 1.0, |
| | flow_matching_t_min: float = 0.01, |
| | flow_matching_t_max: float = 0.99, |
| | flow_matching_mask_scale: bool = False, |
| | **model_kwargs, |
| | ) -> "RecursiveMaskedLM": |
| | """Load a pretrained MLM and wrap it for recursive refinement.""" |
| | base_model = AutoModelForMaskedLM.from_pretrained(mlm_name_or_path, **model_kwargs) |
| | return cls.from_base_model( |
| | base_model, |
| | num_recursions=num_recursions, |
| | normalization=normalization, |
| | loss_weight=loss_weight, |
| | mask_token_id=mask_token_id, |
| | temperature=temperature, |
| | gradient_steps=gradient_steps, |
| | schedule=schedule, |
| | causal_strength=causal_strength, |
| | temperature_max=temperature_max, |
| | entropy_target_max=entropy_target_max, |
| | entropy_floor_max=entropy_floor_max, |
| | smear_sigma_max=smear_sigma_max, |
| | noise_std_max=noise_std_max, |
| | iteration_rope_dim_fraction=iteration_rope_dim_fraction, |
| | use_recursion_checkpointing=use_recursion_checkpointing, |
| | soft_embedding_method=soft_embedding_method, |
| | soft_embedding_ema_step=soft_embedding_ema_step, |
| | flow_matching_enabled=flow_matching_enabled, |
| | flow_matching_lambda=flow_matching_lambda, |
| | flow_matching_t_distribution=flow_matching_t_distribution, |
| | flow_matching_t_logit_mean=flow_matching_t_logit_mean, |
| | flow_matching_t_logit_std=flow_matching_t_logit_std, |
| | flow_matching_t_min=flow_matching_t_min, |
| | flow_matching_t_max=flow_matching_t_max, |
| | flow_matching_mask_scale=flow_matching_mask_scale, |
| | ) |
| |
|
| | @classmethod |
| | def from_base_model( |
| | cls, |
| | base_model: PreTrainedModel, |
| | num_recursions: int = 8, |
| | normalization: str = "softmax", |
| | loss_weight: str = "linear", |
| | mask_token_id: Optional[int] = None, |
| | temperature: float = 1.0, |
| | gradient_steps: Optional[int] = None, |
| | |
| | schedule: str = "linear", |
| | causal_strength: float = 1.0, |
| | |
| | temperature_max: float = 0.0, |
| | entropy_target_max: float = 0.0, |
| | entropy_floor_max: float = 0.0, |
| | smear_sigma_max: float = 0.0, |
| | noise_std_max: float = 0.0, |
| | iteration_rope_dim_fraction: float = 0.0, |
| | use_recursion_checkpointing: bool = True, |
| | |
| | soft_embedding_method: str = "softmax", |
| | soft_embedding_ema_step: float = 1.0, |
| | |
| | flow_matching_enabled: bool = False, |
| | flow_matching_lambda: float = 0.5, |
| | flow_matching_t_distribution: str = "logit_normal", |
| | flow_matching_t_logit_mean: float = -0.4, |
| | flow_matching_t_logit_std: float = 1.0, |
| | flow_matching_t_min: float = 0.01, |
| | flow_matching_t_max: float = 0.99, |
| | flow_matching_mask_scale: bool = False, |
| | ) -> "RecursiveMaskedLM": |
| | """Wrap an existing model for recursive refinement. |
| | |
| | Use this for models not loadable via AutoModelForMaskedLM (e.g., LLaDA). |
| | |
| | Args: |
| | base_model: The base MLM model to wrap |
| | num_recursions: Number of recursive refinement steps |
| | normalization: Normalization method for logits (softmax, stable_softmax) |
| | loss_weight: Loss weighting scheme (last_1, last_2, linear, uniform) |
| | mask_token_id: Token ID for [MASK] |
| | temperature: Temperature for softmax normalization |
| | gradient_steps: Number of final steps to backprop through |
| | schedule: Convergence schedule type ("linear" or "causal") |
| | causal_strength: How much faster early positions converge (causal only) |
| | temperature_max: Max temperature boost for uncertain positions |
| | entropy_target_max: Target entropy at progress=0 (two-sided, recommended) |
| | entropy_floor_max: Min entropy floor (one-sided) |
| | smear_sigma_max: Max Gaussian sigma for position smearing |
| | noise_std_max: Max std of Gaussian noise on logits |
| | iteration_rope_dim_fraction: Fraction of dims for iteration RoPE |
| | use_recursion_checkpointing: Enable gradient checkpointing for iterations |
| | soft_embedding_method: How to convert logits to soft embeddings |
| | soft_embedding_ema_step: EMA step size (1.0 = no EMA, <1.0 = blend with previous) |
| | flow_matching_enabled: Enable CFM-inspired flow matching framework |
| | flow_matching_lambda: Weight of distillation KL loss relative to CE |
| | flow_matching_t_distribution: Time sampling distribution ("logit_normal" or "uniform") |
| | flow_matching_t_logit_mean: Mean of logit-normal distribution |
| | flow_matching_t_logit_std: Std of logit-normal distribution |
| | flow_matching_t_min: Minimum time value (clamp) |
| | flow_matching_t_max: Maximum time value (clamp) |
| | flow_matching_mask_scale: Scale mask_emb by (1-t) if True, binary if False |
| | """ |
| | config = RecursiveMLMConfig.from_base_model_config( |
| | base_model.config, |
| | num_recursions=num_recursions, |
| | normalization=normalization, |
| | loss_weight=loss_weight, |
| | mask_token_id=mask_token_id, |
| | temperature=temperature, |
| | gradient_steps=gradient_steps, |
| | schedule=schedule, |
| | causal_strength=causal_strength, |
| | temperature_max=temperature_max, |
| | entropy_target_max=entropy_target_max, |
| | entropy_floor_max=entropy_floor_max, |
| | smear_sigma_max=smear_sigma_max, |
| | noise_std_max=noise_std_max, |
| | iteration_rope_dim_fraction=iteration_rope_dim_fraction, |
| | use_recursion_checkpointing=use_recursion_checkpointing, |
| | soft_embedding_method=soft_embedding_method, |
| | soft_embedding_ema_step=soft_embedding_ema_step, |
| | flow_matching_enabled=flow_matching_enabled, |
| | flow_matching_lambda=flow_matching_lambda, |
| | flow_matching_t_distribution=flow_matching_t_distribution, |
| | flow_matching_t_logit_mean=flow_matching_t_logit_mean, |
| | flow_matching_t_logit_std=flow_matching_t_logit_std, |
| | flow_matching_t_min=flow_matching_t_min, |
| | flow_matching_t_max=flow_matching_t_max, |
| | flow_matching_mask_scale=flow_matching_mask_scale, |
| | ) |
| | return cls(config, base_model=base_model) |
| |
|
| | @property |
| | def embed_weight(self) -> torch.Tensor: |
| | return self.mlm.get_input_embeddings().weight |
| |
|
    def get_input_embeddings(self):
        """Delegate to the wrapped MLM (standard HF PreTrainedModel hook)."""
        return self.mlm.get_input_embeddings()
| |
|
    def set_input_embeddings(self, value):
        """Delegate to the wrapped MLM (standard HF PreTrainedModel hook)."""
        self.mlm.set_input_embeddings(value)
| |
|
    def get_output_embeddings(self):
        """Delegate to the wrapped MLM (standard HF PreTrainedModel hook)."""
        return self.mlm.get_output_embeddings()
| |
|
    def set_output_embeddings(self, new_embeddings):
        """Delegate to the wrapped MLM (standard HF PreTrainedModel hook)."""
        self.mlm.set_output_embeddings(new_embeddings)
| |
|
| | def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): |
| | """Enable gradient checkpointing with correct settings for recursion. |
| | |
| | Forces use_reentrant=False which is required for: |
| | - Nested checkpoint calls (base model + recursion checkpointing) |
| | - Models with frozen parameters |
| | - Complex gradient flows through soft embeddings |
| | """ |
| | if gradient_checkpointing_kwargs is None: |
| | gradient_checkpointing_kwargs = {} |
| | |
| | gradient_checkpointing_kwargs.setdefault("use_reentrant", False) |
| | self.mlm.gradient_checkpointing_enable(gradient_checkpointing_kwargs) |
| |
|
    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing in the underlying MLM (simple delegate)."""
        self.mlm.gradient_checkpointing_disable()
| |
|
    def _single_iteration_checkpointable(
        self,
        soft_embeds: torch.Tensor,
        base_embeds: torch.Tensor,
        mask_pos: torch.Tensor,
        attention_mask: torch.Tensor,
        embed_weight: torch.Tensor,
        mask_emb: torch.Tensor,
        temperature: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Single differentiable iteration for checkpointing.

        This method performs one iteration of recursive refinement in a way that
        maintains gradient flow and is compatible with torch.utils.checkpoint.

        Args:
            soft_embeds: (B, L, H) - current soft embeddings
            base_embeds: (B, L, H) - original token embeddings
            mask_pos: (B, L) bool - which positions are masked
            attention_mask: (B, L) - attention mask for MLM
            embed_weight: (V, H) - embedding weight matrix
            mask_emb: (H,) - mask token embedding
            temperature: scalar tensor - softmax temperature
            position_ids: optional position IDs forwarded to the base model

        Returns:
            logits: (B, L, V) - output logits from this iteration
            next_soft_embeds: (B, L, H) - soft embeddings for next iteration
        """
        # Masked positions take the current soft embeddings; all other
        # positions keep their original token embeddings.
        inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds)

        outputs = self.mlm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        logits = outputs.logits

        # Rebuild next-iteration soft embeddings; only masked positions are
        # updated from this iteration's logits.
        next_soft_embeds = base_embeds.clone()
        if mask_pos.any():
            masked_logits = logits[mask_pos]

            # Convert logits to vocabulary mixing weights per the configured method.
            if self.config.soft_embedding_method == "none":
                # Raw logits used directly as (unnormalized) mixing weights.
                weights = masked_logits
            elif self.config.soft_embedding_method == "l2_normalize":
                weights = F.normalize(masked_logits, p=2, dim=-1)
            else:
                # Default path: temperature softmax over the vocabulary.
                weights = F.softmax(masked_logits / temperature, dim=-1)

            # Soft embedding = expected token embedding, shifted by the mask
            # embedding so masked positions remain identifiable to the model.
            soft_emb = weights @ embed_weight + mask_emb

            # Optional EMA blend with the previous iteration's soft embeddings
            # (ema_step == 1.0 means "replace entirely").
            ema_step = self.config.soft_embedding_ema_step
            if ema_step < 1.0:
                prev_soft_emb = soft_embeds[mask_pos]
                soft_emb = (1.0 - ema_step) * prev_soft_emb + ema_step * soft_emb

            next_soft_embeds[mask_pos] = soft_emb

        return logits, next_soft_embeds
| |
|
| | def _stable_softmax(self, logits: torch.Tensor, T: float = 1.0, dim: int = -1, eps: float = 1e-12) -> torch.Tensor: |
| | """Numerically stable softmax with temperature T > 0.""" |
| | z = logits / max(T, eps) |
| | z = z - z.max(dim=dim, keepdim=True).values |
| | z = torch.exp(z) |
| | z_sum = z.sum(dim=dim, keepdim=True) |
| | return z / z_sum.clamp(min=eps) |
| |
|
| | def normalize(self, logits: torch.Tensor) -> torch.Tensor: |
| | """Normalize logits -> mixing weights. Shape: (B, L, V) -> (B, L, V)""" |
| | norm = self.config.normalization.lower() |
| | T = self.config.temperature |
| | V = logits.shape[-1] |
| |
|
| | if norm == "none": |
| | return logits |
| |
|
| | if norm == "softmax": |
| | return torch.softmax(logits / T, dim=-1) |
| |
|
| | if norm == "stable_softmax": |
| | return self._stable_softmax(logits, T=T, dim=-1) |
| |
|
| | raise ValueError(f"Unknown normalization: {norm}") |
| |
|
| | def step_weight(self, t: int, T: int) -> float: |
| | """Loss weight for step t of T.""" |
| | lw = self.config.loss_weight |
| | if lw == "linear": |
| | return (t + 1) / T |
| | if lw == "uniform": |
| | return 1.0 |
| | if lw == "last_1": |
| | return 1.0 if t == T - 1 else 0.0 |
| | if lw == "last_2": |
| | return 1.0 if T - t <= 2 else 0.0 |
| | raise ValueError(f"Unknown loss_weight: {lw}") |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | def _compute_convergence_progress( |
| | self, |
| | iteration: int, |
| | total_iterations: int, |
| | seq_length: int, |
| | mask_positions: torch.Tensor, |
| | schedule: str = "linear", |
| | causal_strength: float = 1.0, |
| | device: torch.device = None, |
| | dtype: torch.dtype = None, |
| | ) -> torch.Tensor: |
| | """ |
| | Compute per-position convergence progress based on schedule. |
| | |
| | Args: |
| | iteration: Current iteration (0-indexed) |
| | total_iterations: Total number of iterations |
| | seq_length: Full sequence length L |
| | mask_positions: Position indices of masked tokens (num_masked,) |
| | schedule: "linear" or "causal" |
| | causal_strength: How much faster early positions converge (for causal schedule) |
| | |
| | Returns: |
| | progress: (num_masked,) tensor with values in [0, 1] |
| | 0 = position should be maximally uncertain |
| | 1 = position is allowed to fully converge |
| | """ |
| | base_progress = iteration / max(total_iterations - 1, 1) |
| |
|
| | if schedule == "linear": |
| | return torch.full( |
| | (mask_positions.shape[0],), |
| | base_progress, |
| | device=device, |
| | dtype=dtype |
| | ) |
| |
|
| | elif schedule == "causal": |
| | position_factor = mask_positions.float() / max(seq_length - 1, 1) |
| | effective_progress = base_progress * (1.0 + causal_strength * (1.0 - position_factor)) |
| | return effective_progress.clamp(0.0, 1.0) |
| |
|
| | else: |
| | raise ValueError(f"Unknown schedule: {schedule}") |
| |
|
| | def _apply_temperature_effect( |
| | self, |
| | logits: torch.Tensor, |
| | progress: torch.Tensor, |
| | temperature_max: float, |
| | ) -> torch.Tensor: |
| | """ |
| | Apply per-position temperature scaling based on convergence progress. |
| | Low progress = high temperature (uncertain), high progress = temperature 1.0. |
| | """ |
| | if temperature_max <= 0: |
| | return logits |
| |
|
| | temperature = 1.0 + temperature_max * (1.0 - progress) |
| | temperature = temperature.unsqueeze(-1) |
| |
|
| | return logits / temperature |
| |
|
| | def _apply_entropy_floor_effect( |
| | self, |
| | probs: torch.Tensor, |
| | progress: torch.Tensor, |
| | entropy_floor_max: float, |
| | ) -> torch.Tensor: |
| | """ |
| | Ensure minimum entropy based on convergence progress. |
| | Low progress = high entropy floor, high progress = no floor. |
| | |
| | NOTE: This is a ONE-SIDED constraint (floor only). |
| | """ |
| | if entropy_floor_max <= 0: |
| | return probs |
| |
|
| | entropy_floor = entropy_floor_max * (1.0 - progress) |
| |
|
| | log_probs = torch.log(probs + 1e-10) |
| | current_entropy = -(probs * log_probs).sum(dim=-1) |
| |
|
| | below_floor = current_entropy < entropy_floor |
| |
|
| | if not below_floor.any(): |
| | return probs |
| |
|
| | logits = torch.log(probs + 1e-10) |
| |
|
| | target_ratio = entropy_floor / (current_entropy + 1e-10) |
| | temperature = torch.ones_like(current_entropy) |
| | temperature[below_floor] = target_ratio[below_floor].clamp(1.0, 10.0) |
| |
|
| | scaled_probs = torch.softmax(logits / temperature.unsqueeze(-1), dim=-1) |
| |
|
| | result = probs.clone() |
| | result[below_floor] = scaled_probs[below_floor] |
| | return result |
| |
|
    def _find_temperature_for_target_entropy(
        self,
        logits: torch.Tensor,
        target_entropy: torch.Tensor,
        tol: float = 1e-3,
        max_iter: int = 32,
        T_low: float = 1e-6,
        T_high_init: float = 1.0,
        max_T: float = 100.0,
    ) -> torch.Tensor:
        """
        Find per-position temperatures that achieve exactly the target entropy.
        Uses bisection search, adapted from ARChitects' implementation.

        Args:
            logits: Raw logits (num_positions, V)
            target_entropy: Target entropy per position (num_positions,) or scalar
            tol: Entropy tolerance for convergence
            max_iter: Maximum bisection iterations
            T_low: Minimum temperature (near-greedy)
            T_high_init: Initial upper bound for search
            max_T: Maximum allowed temperature

        Returns:
            temperatures: (num_positions,) temperatures that achieve target entropy
        """
        N, V = logits.shape
        device, dtype = logits.device, logits.dtype
        # Entropy is bounded above by log(V) (uniform distribution).
        H_max = torch.log(torch.tensor(V, device=device, dtype=dtype))

        # Broadcast a scalar target to all positions; clamp into the
        # achievable [0, log V] range.
        if target_entropy.dim() == 0:
            target = target_entropy.expand(N).clone()
        else:
            target = target_entropy.clone()
        target = target.clamp(0.0, H_max)

        def compute_entropy(logits_: torch.Tensor, temps: torch.Tensor) -> torch.Tensor:
            # Per-row entropy of softmax(logits / T), with max-shift for stability.
            temps = temps.unsqueeze(-1).clamp(min=T_low)
            scaled = logits_ / temps
            scaled = scaled - scaled.max(dim=-1, keepdim=True).values
            probs = torch.softmax(scaled, dim=-1)
            log_probs = torch.log(probs + 1e-12)
            return -(probs * log_probs).sum(dim=-1)

        lo = torch.full((N,), T_low, device=device, dtype=dtype)
        hi = torch.full((N,), T_high_init, device=device, dtype=dtype)

        H_lo = compute_entropy(logits, lo)

        # Positions whose target is at/below the minimum achievable entropy
        # take the near-greedy temperature directly — no search needed.
        done_low = target <= (H_lo + tol)

        H_hi = compute_entropy(logits, hi)
        needs_expansion = (H_hi < target - tol) & ~done_low

        # Geometrically double the upper bound until it brackets the target
        # entropy (or the max_T cap is reached).
        for _ in range(100):
            if not needs_expansion.any():
                break
            hi[needs_expansion] = (hi[needs_expansion] * 2.0).clamp(max=max_T)
            H_hi[needs_expansion] = compute_entropy(
                logits[needs_expansion], hi[needs_expansion]
            )
            needs_expansion = (H_hi < target - tol) & ~done_low & (hi < max_T - 1e-6)

        # Only positions with a valid bracket participate in bisection.
        can_bisect = ~done_low & (H_hi >= target - tol)

        # Standard bisection: entropy increases monotonically with temperature.
        for _ in range(max_iter):
            if not can_bisect.any():
                break

            mid = (lo + hi) / 2.0
            H_mid = compute_entropy(logits, mid)

            too_low = (H_mid < target) & can_bisect
            lo[too_low] = mid[too_low]
            hi[~too_low & can_bisect] = mid[~too_low & can_bisect]

            # Relative-width convergence test (absolute below T=1).
            converged = (hi - lo) <= tol * mid.clamp(min=1.0)
            can_bisect = can_bisect & ~converged

        # Final answer: midpoint of the bracket, or T_low for the
        # already-at-minimum positions.
        temps = torch.zeros(N, device=device, dtype=dtype)
        temps[done_low] = T_low
        temps[~done_low] = (lo[~done_low] + hi[~done_low]) / 2.0

        return temps
| |
|
| | def _apply_target_entropy_effect( |
| | self, |
| | logits: torch.Tensor, |
| | progress: torch.Tensor, |
| | entropy_target_max: float, |
| | entropy_target_min: float = 0.0, |
| | ) -> torch.Tensor: |
| | """ |
| | Adjust temperature to achieve EXACTLY the target entropy per position. |
| | This is a TWO-SIDED constraint: both raises and lowers entropy as needed. |
| | |
| | Args: |
| | logits: Raw logits (num_masked, V) |
| | progress: Per-position convergence progress (num_masked,) |
| | entropy_target_max: Target entropy at progress=0 |
| | entropy_target_min: Target entropy at progress=1 (usually ~0) |
| | |
| | Returns: |
| | probs: Probabilities with entropy matching targets |
| | """ |
| | if entropy_target_max <= 0: |
| | return torch.softmax(logits, dim=-1) |
| |
|
| | target_entropy = entropy_target_max * (1.0 - progress) + entropy_target_min * progress |
| |
|
| | temps = self._find_temperature_for_target_entropy(logits, target_entropy) |
| |
|
| | temps = temps.unsqueeze(-1).clamp(min=1e-6) |
| | return torch.softmax(logits / temps, dim=-1) |
| |
|
| | def _apply_smear_effect( |
| | self, |
| | probs: torch.Tensor, |
| | mask_pos: torch.Tensor, |
| | progress_full: torch.Tensor, |
| | smear_sigma_max: float, |
| | ) -> torch.Tensor: |
| | """ |
| | Apply positional smearing with per-position sigma based on progress. |
| | Low progress = high smearing, high progress = no smearing. |
| | |
| | Note: This operates on full (B, L, V) tensor because smearing mixes across positions. |
| | """ |
| | if smear_sigma_max <= 0: |
| | return probs |
| |
|
| | B, L, V = probs.shape |
| |
|
| | sigma_per_pos = smear_sigma_max * (1.0 - progress_full) |
| |
|
| | avg_sigma = sigma_per_pos[mask_pos].mean().item() |
| |
|
| | if avg_sigma < 0.1: |
| | return probs |
| |
|
| | positions = torch.arange(L, device=probs.device, dtype=probs.dtype) |
| | diff = positions.unsqueeze(0) - positions.unsqueeze(1) |
| | kernel = torch.exp(-0.5 * (diff / avg_sigma) ** 2) |
| | kernel = kernel / kernel.sum(dim=1, keepdim=True) |
| |
|
| | smeared = torch.einsum('ij,bjv->biv', kernel, probs) |
| | smeared = smeared / smeared.sum(dim=-1, keepdim=True).clamp(min=1e-10) |
| |
|
| | blend = progress_full.unsqueeze(-1) |
| | result = blend * probs + (1 - blend) * smeared |
| |
|
| | output = probs.clone() |
| | output[mask_pos] = result[mask_pos] |
| | return output |
| |
|
| | def _apply_noise_effect( |
| | self, |
| | logits: torch.Tensor, |
| | progress: torch.Tensor, |
| | noise_std_max: float, |
| | ) -> torch.Tensor: |
| | """ |
| | Add Gaussian noise to logits based on convergence progress. |
| | Low progress = high noise, high progress = no noise. |
| | """ |
| | if noise_std_max <= 0: |
| | return logits |
| |
|
| | noise_std = noise_std_max * (1.0 - progress) |
| | noise_std = noise_std.unsqueeze(-1) |
| |
|
| | noise = torch.randn_like(logits) * noise_std |
| | return logits + noise |
| |
|
| | def _apply_iteration_rope( |
| | self, |
| | embeds: torch.Tensor, |
| | iteration: int, |
| | total_iterations: int, |
| | dim_fraction: float = 0.25, |
| | base: float = 10000.0, |
| | ) -> torch.Tensor: |
| | """ |
| | Apply rotary embedding based on iteration progress. |
| | Uses a subset of dimensions to avoid interfering with position RoPE. |
| | """ |
| | if dim_fraction <= 0: |
| | return embeds |
| |
|
| | H = embeds.shape[-1] |
| | rot_dim = int(H * dim_fraction) |
| | rot_dim = rot_dim - (rot_dim % 2) |
| |
|
| | if rot_dim < 2: |
| | return embeds |
| |
|
| | progress = iteration / max(total_iterations - 1, 1) |
| |
|
| | inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, device=embeds.device, dtype=embeds.dtype) / rot_dim)) |
| | angles = progress * inv_freq * 3.14159 |
| | cos, sin = torch.cos(angles), torch.sin(angles) |
| |
|
| | if embeds.dim() == 2: |
| | cos, sin = cos.unsqueeze(0), sin.unsqueeze(0) |
| | elif embeds.dim() == 3: |
| | cos = cos.unsqueeze(0).unsqueeze(0) |
| | sin = sin.unsqueeze(0).unsqueeze(0) |
| |
|
| | embeds_out = embeds.clone() |
| | x1, x2 = embeds[..., -rot_dim::2], embeds[..., -rot_dim+1::2] |
| | embeds_out[..., -rot_dim::2] = x1 * cos - x2 * sin |
| | embeds_out[..., -rot_dim+1::2] = x1 * sin + x2 * cos |
| |
|
| | return embeds_out |
| |
|
| | |
| |
|
| | def _sample_flow_matching_t(self, num_tokens: int, device: torch.device) -> torch.Tensor: |
| | """Sample per-token time levels for flow matching. |
| | |
| | Returns: |
| | t: (num_tokens,) tensor of time levels in [t_min, t_max] |
| | """ |
| | dist = self.config.flow_matching_t_distribution |
| | if dist == "logit_normal": |
| | z = torch.randn(num_tokens, device=device) |
| | z = z * self.config.flow_matching_t_logit_std + self.config.flow_matching_t_logit_mean |
| | t = torch.sigmoid(z) |
| | elif dist == "uniform": |
| | t = torch.empty(num_tokens, device=device).uniform_(0, 1) |
| | else: |
| | raise ValueError(f"Unknown flow_matching_t_distribution: {dist}") |
| | return t.clamp(self.config.flow_matching_t_min, self.config.flow_matching_t_max) |
| |
|
    def compute_flow_matching_distillation_loss(
        self,
        input_ids: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        flow_noise_embed: torch.Tensor,
        flow_t: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> SelfDistillationOutput:
        """
        CFM flow matching distillation: teacher sees state at time t, student sees
        noisier state at time s < t on the same interpolation path.

        Both should predict the same endpoint (target token). The student must
        learn to refine from noisier inputs by matching the teacher's predictions.

        Args:
            input_ids: Input with [MASK] tokens at positions to predict
            teacher_logits: Logits from the forward pass (will be detached)
            labels: Target tokens at masked positions (-100 elsewhere)
            flow_noise_embed: (num_masked, H) noise embeddings from forward
            flow_t: (num_masked,) per-token time levels from forward
            attention_mask: Standard attention mask
            position_ids: Position IDs (if needed by base model)

        Returns:
            SelfDistillationOutput with loss, logits, time gap, and diagnostics
        """
        mask_id = self.config.mask_token_id
        mask_pos = (input_ids == mask_id)
        device = input_ids.device
        num_masked = mask_pos.sum().item()

        # No masked tokens: return a zero loss that still participates in
        # autograd, plus placeholder diagnostics.
        if num_masked == 0:
            zero = torch.tensor(0.0, device=device, requires_grad=True)
            dummy = torch.zeros(1, device=device)
            return SelfDistillationOutput(zero, dummy, dummy, 0.0, 0.0, 0.0, 1.0)

        # Teacher is a fixed target — no gradient flows into it.
        teacher_logits = teacher_logits.detach()

        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]
        base_embeds = self.get_input_embeddings()(input_ids)

        # Endpoint of the interpolation path: the ground-truth token embeddings.
        target_ids = labels[mask_pos]
        target_embed = embed_weight[target_ids]

        # Student time s is uniform in [0, t) per token — strictly noisier
        # than the teacher's state at time t.
        s_per_token = flow_t * torch.rand(num_masked, device=device)

        # Linear interpolation between noise (s=0) and target (s=1).
        s_col = s_per_token.unsqueeze(-1).to(base_embeds.dtype)
        student_interp = (1 - s_col) * flow_noise_embed + s_col * target_embed

        # Mask-embedding marker: either annealed with (1-s) or added at full
        # strength, per config.
        if self.config.flow_matching_mask_scale:
            student_masked_embeds = student_interp + (1 - s_col) * mask_emb
        else:
            student_masked_embeds = student_interp + mask_emb

        # Student inputs are fully detached: gradients reach the model only
        # through its weights, not through the constructed embeddings.
        student_embeds = base_embeds.detach().clone()
        student_embeds[mask_pos] = student_masked_embeds.detach()

        student_inputs = torch.where(
            mask_pos.unsqueeze(-1), student_embeds, base_embeds.detach()
        )

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)

        student_out = self.mlm(
            inputs_embeds=student_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        student_logits = student_out.logits

        # Distillation loss only over masked positions.
        t_logits = teacher_logits[mask_pos]
        s_logits = student_logits[mask_pos]

        teacher_probs = F.softmax(t_logits, dim=-1)
        student_log_probs = F.log_softmax(s_logits, dim=-1)

        # KL(teacher || student): kl_div expects log-probs first, target probs second.
        kl_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction="batchmean",
        )

        # Diagnostics (no gradient): entropies and argmax agreement.
        with torch.no_grad():
            teacher_log_probs = torch.log(teacher_probs + 1e-10)
            teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item()

            student_probs = F.softmax(s_logits.detach(), dim=-1)
            student_log_probs_det = torch.log(student_probs + 1e-10)
            student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item()

            agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item()

        # Mean teacher-student time gap, reported in the temperature slot.
        mean_time_gap = (flow_t - s_per_token).mean().item()

        return SelfDistillationOutput(
            loss=kl_loss,
            teacher_logits=teacher_logits,
            student_logits=student_logits,
            degradation_temperature=mean_time_gap,
            teacher_entropy=teacher_entropy,
            student_entropy=student_entropy,
            agreement_rate=agreement,
        )
| |
|
| | |
| |
|
    def compute_self_distillation_loss(
        self,
        input_ids: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        temperature_min: Optional[float] = None,
        temperature_max: Optional[float] = None,
        temperature_distribution: Optional[str] = None,
    ) -> SelfDistillationOutput:
        """
        CFM-style self-distillation: model's predictions should be consistent
        across different levels of input degradation.

        Process:
            1. Take teacher logits (from standard forward pass, DETACHED)
            2. Degrade: per-token random temperature → softer soft embeddings
            3. Student: forward pass from degraded embeddings → logits (has grad)
            4. Loss: KL(teacher || student) on masked positions

        Each masked token gets its own independently sampled degradation
        temperature, creating varied difficulty across the sequence.

        Args:
            input_ids: Input with [MASK] tokens at positions to predict
            teacher_logits: Pre-computed teacher logits (will be detached).
                Typically outputs.all_logits[0] or outputs.logits from standard forward.
            attention_mask: Standard attention mask
            position_ids: Position IDs (if needed by base model)
            temperature_min: Min degradation temperature (default: config value)
            temperature_max: Max degradation temperature (default: config value)
            temperature_distribution: How to sample T (default: config value)

        Returns:
            SelfDistillationOutput with loss, logits, temperature, and diagnostics

        Raises:
            ValueError: if temperature_distribution is not "log_uniform" or "uniform".
        """
        # Fall back to config-level defaults for any argument left as None.
        temperature_min = temperature_min if temperature_min is not None else self.config.self_distillation_temperature_min
        temperature_max = temperature_max if temperature_max is not None else self.config.self_distillation_temperature_max
        temperature_distribution = temperature_distribution if temperature_distribution is not None else self.config.self_distillation_temperature_distribution

        mask_id = self.config.mask_token_id
        mask_pos = (input_ids == mask_id)  # (B, L) bool mask of [MASK] positions
        device = input_ids.device
        num_masked = mask_pos.sum().item()

        # Degenerate case: nothing masked. Return a zero loss that still
        # participates in autograd (requires_grad=True) plus dummy logits,
        # so callers don't need a special case.
        if num_masked == 0:
            zero = torch.tensor(0.0, device=device, requires_grad=True)
            dummy = torch.zeros(1, device=device)
            return SelfDistillationOutput(zero, dummy, dummy, 1.0, 0.0, 0.0, 1.0)

        # The teacher side must never receive gradients.
        teacher_logits = teacher_logits.detach()

        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]
        base_embeds = self.get_input_embeddings()(input_ids)

        # Sample one degradation temperature per masked token.
        if temperature_distribution == "log_uniform":
            # Uniform in log-space: biases samples toward the lower end of the range.
            log_min = torch.tensor(temperature_min, device=device).log()
            log_max = torch.tensor(temperature_max, device=device).log()
            log_T = torch.empty(num_masked, device=device).uniform_(log_min.item(), log_max.item())
            T_per_token = log_T.exp()
        elif temperature_distribution == "uniform":
            T_per_token = torch.empty(num_masked, device=device).uniform_(
                temperature_min, temperature_max
            )
        else:
            raise ValueError(f"Unknown temperature distribution: {temperature_distribution}")

        # Scalar diagnostic reported as degradation_temperature.
        T_mean = T_per_token.mean().item()

        # Build degraded soft embeddings for the masked positions:
        # temperature-softened teacher distribution mixed over the embedding
        # matrix, plus the [MASK] embedding as a positional signal.
        masked_teacher_logits = teacher_logits[mask_pos]
        degraded_probs = F.softmax(masked_teacher_logits / T_per_token.unsqueeze(-1), dim=-1).to(embed_weight.dtype)
        degraded_soft = degraded_probs @ embed_weight + mask_emb

        degraded_soft_embeds = base_embeds.clone()
        degraded_soft_embeds[mask_pos] = degraded_soft
        # Detach: the student learns from the degraded input, not through it.
        degraded_soft_embeds = degraded_soft_embeds.detach()

        # Masked slots get the degraded soft embedding; all other slots keep
        # the (detached) hard token embeddings.
        student_inputs = torch.where(
            mask_pos.unsqueeze(-1), degraded_soft_embeds, base_embeds.detach()
        )

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=base_embeds.dtype)

        # Student forward pass — this is the only grad-carrying path.
        student_out = self.mlm(
            inputs_embeds=student_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        student_logits = student_out.logits

        # Restrict the distillation loss to masked positions only.
        t_logits = teacher_logits[mask_pos]
        s_logits = student_logits[mask_pos]

        # NOTE: the KL target uses the teacher at temperature 1 (not the
        # degraded distribution) — the student must recover the clean teacher.
        teacher_probs = F.softmax(t_logits, dim=-1)
        student_log_probs = F.log_softmax(s_logits, dim=-1)

        # batchmean divides by the number of masked tokens (rows), giving a
        # per-token mean KL.
        kl_loss = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction="batchmean",
        )

        # Diagnostics only — no gradients needed.
        with torch.no_grad():
            teacher_log_probs = torch.log(teacher_probs + 1e-10)
            teacher_entropy = -(teacher_probs * teacher_log_probs).sum(dim=-1).mean().item()

            student_probs = F.softmax(s_logits.detach(), dim=-1)
            student_log_probs_det = torch.log(student_probs + 1e-10)
            student_entropy = -(student_probs * student_log_probs_det).sum(dim=-1).mean().item()

            # Fraction of masked positions where teacher and student argmax agree.
            agreement = (t_logits.argmax(dim=-1) == s_logits.detach().argmax(dim=-1)).float().mean().item()

        return SelfDistillationOutput(
            loss=kl_loss,
            teacher_logits=teacher_logits,
            student_logits=student_logits,
            degradation_temperature=T_mean,
            teacher_entropy=teacher_entropy,
            student_entropy=student_entropy,
            agreement_rate=agreement,
        )
| |
|
| | |
| |
|
    @torch.no_grad()
    def _compute_next_soft_embeds(
        self,
        logits: torch.Tensor,
        mask_pos: torch.Tensor,
        base_embeds: torch.Tensor,
        prev_soft_embeds: Optional[torch.Tensor] = None,
        iteration: int = 0,
        total_iterations: int = 1,
        # Convergence schedule (None -> config value)
        schedule: Optional[str] = None,
        causal_strength: Optional[float] = None,
        # Effect strengths (None -> config value; <= 0 disables the effect)
        temperature_max: Optional[float] = None,
        entropy_target_max: Optional[float] = None,
        entropy_floor_max: Optional[float] = None,
        smear_sigma_max: Optional[float] = None,
        noise_std_max: Optional[float] = None,
        iteration_rope_dim_fraction: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Compute soft embeddings from logits for the next iteration.

        This function implements a unified "convergence schedule" system that controls
        when each position is allowed to converge to a confident prediction.

        Schedule Types:
            "linear": All positions converge at the same rate (iteration-based only)
            "causal": Early positions converge first, late positions last

        Effects (mechanisms to enforce the schedule):
            temperature_max: High temperature = more uniform distribution (one-sided)
            entropy_target_max: Force EXACT entropy via bisection search (two-sided, recommended)
            entropy_floor_max: Force MINIMUM entropy (one-sided, only prevents too confident)
            smear_sigma_max: Spread probability across neighboring positions
            noise_std_max: Add Gaussian noise to logits

        All parameters default to their config values if not specified.

        Args:
            logits: Output logits from current iteration (B, L, V)
            mask_pos: Boolean mask indicating which positions are masked (B, L)
            base_embeds: Base token embeddings for non-masked positions (B, L, H)
            prev_soft_embeds: Previous iteration's soft embeddings, used for the
                EMA blend when config.soft_embedding_ema_step < 1.0
            iteration: Current iteration index (0-indexed)
            total_iterations: Total number of iterations

        Returns:
            Soft embeddings for next iteration (B, L, H), always detached.
        """
        # Resolve every None argument from config.
        schedule = schedule if schedule is not None else self.config.schedule
        causal_strength = causal_strength if causal_strength is not None else self.config.causal_strength
        temperature_max = temperature_max if temperature_max is not None else self.config.temperature_max
        entropy_target_max = entropy_target_max if entropy_target_max is not None else self.config.entropy_target_max
        entropy_floor_max = entropy_floor_max if entropy_floor_max is not None else self.config.entropy_floor_max
        smear_sigma_max = smear_sigma_max if smear_sigma_max is not None else self.config.smear_sigma_max
        noise_std_max = noise_std_max if noise_std_max is not None else self.config.noise_std_max
        iteration_rope_dim_fraction = iteration_rope_dim_fraction if iteration_rope_dim_fraction is not None else self.config.iteration_rope_dim_fraction

        # Non-masked positions always keep their original embeddings.
        soft_embeds = base_embeds.clone()

        if not mask_pos.any():
            return soft_embeds.detach()

        B, L, V = logits.shape
        device, dtype = logits.device, logits.dtype

        # Any effect enabled? If not, take the cheap path below.
        has_effects = (
            temperature_max > 0 or
            entropy_target_max > 0 or
            entropy_floor_max > 0 or
            smear_sigma_max > 0 or
            noise_std_max > 0 or
            iteration_rope_dim_fraction > 0
        )

        if not has_effects:
            # ---- Fast path: plain logits -> weights -> soft embedding ----
            masked_logits = logits[mask_pos]
            embed_weight = self.embed_weight

            # How logits become mixing weights over the embedding matrix.
            if self.config.soft_embedding_method == "none":
                weights = masked_logits
            elif self.config.soft_embedding_method == "l2_normalize":
                weights = F.normalize(masked_logits, p=2, dim=-1)
            else:
                # Default: self.normalize (project-defined normalization — see its def).
                weights = self.normalize(masked_logits)

            # Soft token mixture plus the [MASK] embedding as a marker signal.
            masked_soft = weights @ embed_weight
            mask_emb = embed_weight[self.config.mask_token_id]
            masked_soft = masked_soft + mask_emb

            # Optional EMA blend with the previous iteration's soft embeddings.
            ema_step = self.config.soft_embedding_ema_step
            if ema_step < 1.0 and prev_soft_embeds is not None:
                prev_masked_soft = prev_soft_embeds[mask_pos]
                masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft

            soft_embeds[mask_pos] = masked_soft
            return soft_embeds.detach()

        # ---- Effects path ----
        # batch_indices is currently unused; only the position indices feed
        # the convergence-progress computation.
        batch_indices, position_indices = torch.where(mask_pos)

        # Per-masked-token convergence progress in [0, 1] under the schedule.
        progress = self._compute_convergence_progress(
            iteration=iteration,
            total_iterations=total_iterations,
            seq_length=L,
            mask_positions=position_indices,
            schedule=schedule,
            causal_strength=causal_strength,
            device=device,
            dtype=dtype,
        )

        # Smearing needs progress for EVERY position (not just masked ones),
        # since probability mass is spread across neighbors.
        if smear_sigma_max > 0:
            all_positions = torch.arange(L, device=device, dtype=dtype)
            progress_full = self._compute_convergence_progress(
                iteration=iteration,
                total_iterations=total_iterations,
                seq_length=L,
                mask_positions=all_positions,
                schedule=schedule,
                causal_strength=causal_strength,
                device=device,
                dtype=dtype,
            )
            progress_full = progress_full.unsqueeze(0).expand(B, -1)

        full_probs = self.normalize(logits)

        if smear_sigma_max > 0:
            full_probs = self._apply_smear_effect(
                full_probs, mask_pos, progress_full, smear_sigma_max
            )

        masked_logits = logits[mask_pos]
        masked_probs = full_probs[mask_pos]

        # Temperature effect only applies when the (stronger, two-sided)
        # entropy-target effect is disabled.
        if temperature_max > 0 and entropy_target_max <= 0:
            masked_logits = self._apply_temperature_effect(
                masked_logits, progress, temperature_max
            )
            masked_probs = torch.softmax(masked_logits, dim=-1)

        # Gaussian noise is applied in log-space, then re-normalized.
        if noise_std_max > 0:
            masked_logits_noisy = self._apply_noise_effect(
                torch.log(masked_probs + 1e-10), progress, noise_std_max
            )
            masked_probs = torch.softmax(masked_logits_noisy, dim=-1)

        # Entropy shaping: exact target (two-sided) takes precedence over floor.
        if entropy_target_max > 0:
            masked_probs = self._apply_target_entropy_effect(
                masked_logits, progress, entropy_target_max
            )
        elif entropy_floor_max > 0:
            masked_probs = self._apply_entropy_floor_effect(
                masked_probs, progress, entropy_floor_max
            )

        embed_weight = self.embed_weight

        # Same weighting choices as the fast path, but the default uses the
        # effect-shaped probabilities.
        if self.config.soft_embedding_method == "none":
            # NOTE(review): raw (possibly temperature-scaled) logits as weights.
            weights = masked_logits
        elif self.config.soft_embedding_method == "l2_normalize":
            # NOTE(review): noise/entropy effects are bypassed in this branch too.
            weights = F.normalize(masked_logits, p=2, dim=-1)
        else:
            weights = masked_probs

        masked_soft = weights @ embed_weight
        mask_emb = embed_weight[self.config.mask_token_id]
        masked_soft = masked_soft + mask_emb

        # Optionally rotate a fraction of dims to encode the iteration index.
        if iteration_rope_dim_fraction > 0:
            masked_soft = self._apply_iteration_rope(
                masked_soft, iteration, total_iterations, iteration_rope_dim_fraction
            )

        # Optional EMA blend with the previous iteration's soft embeddings.
        ema_step = self.config.soft_embedding_ema_step
        if ema_step < 1.0 and prev_soft_embeds is not None:
            prev_masked_soft = prev_soft_embeds[mask_pos]
            masked_soft = (1.0 - ema_step) * prev_masked_soft + ema_step * masked_soft

        soft_embeds[mask_pos] = masked_soft

        return soft_embeds.detach()
| |
|
| | @torch.no_grad() |
| | def _compute_iteration_metrics( |
| | self, logits: torch.Tensor, labels: torch.Tensor |
| | ) -> IterationMetrics: |
| | """ |
| | Compute token-level AND sequence-level metrics for a single iteration. |
| | Returns scalars only - no large tensor storage. |
| | |
| | Token-level metrics: |
| | - accuracy: fraction of correct token predictions |
| | - entropy: average entropy per token |
| | - softmax_ce: cross-entropy loss per token |
| | |
| | Sequence-level metrics: |
| | - full_sequence_accuracy: fraction of sequences where ALL tokens are correct |
| | - min_sequence_confidence: mean of minimum top-1 confidence per sequence |
| | """ |
| | B = logits.shape[0] |
| |
|
| | |
| | logits = logits.detach().cpu().float() |
| | target_labels = labels.detach().cpu().contiguous() |
| | mask = target_labels != -100 |
| |
|
| | if mask.sum() == 0: |
| | return IterationMetrics( |
| | accuracy=0.0, |
| | entropy=0.0, |
| | softmax_ce=0.0, |
| | full_sequence_accuracy=0.0, |
| | min_sequence_confidence=0.0, |
| | ) |
| |
|
| | logits = logits.contiguous() |
| | predictions = logits.argmax(dim=-1) |
| | correct = (predictions == target_labels) & mask |
| |
|
| | |
| |
|
| | |
| | accuracy = (correct.sum() / mask.sum()).item() |
| |
|
| | |
| | valid_logits = logits[mask] |
| | valid_labels = target_labels[mask] |
| |
|
| | |
| | log_probs = torch.nn.functional.log_softmax(valid_logits, dim=-1) |
| | probs = torch.exp(log_probs) |
| | entropy = -(probs * log_probs).sum(dim=-1).mean().item() |
| |
|
| | |
| | softmax_ce = torch.nn.functional.cross_entropy( |
| | valid_logits, valid_labels, reduction="mean" |
| | ).item() |
| |
|
| | |
| |
|
| | |
| | sequences_with_tokens = mask.any(dim=1) |
| | num_valid_sequences = sequences_with_tokens.sum().item() |
| |
|
| | if num_valid_sequences == 0: |
| | return IterationMetrics( |
| | accuracy=accuracy, |
| | entropy=entropy, |
| | softmax_ce=softmax_ce, |
| | full_sequence_accuracy=0.0, |
| | min_sequence_confidence=0.0, |
| | ) |
| |
|
| | |
| | num_correct_per_seq = correct.sum(dim=1) |
| | num_tokens_per_seq = mask.sum(dim=1) |
| | all_correct = (num_correct_per_seq == num_tokens_per_seq) & sequences_with_tokens |
| | full_seq_accuracy = (all_correct.sum() / num_valid_sequences).item() |
| |
|
| | |
| | probs_full = torch.softmax(logits, dim=-1) |
| | top1_confidence = probs_full.max(dim=-1).values |
| |
|
| | min_confidences = [] |
| | for i in range(B): |
| | if sequences_with_tokens[i]: |
| | seq_confidences = top1_confidence[i][mask[i]] |
| | min_confidences.append(seq_confidences.min().item()) |
| |
|
| | min_seq_conf = sum(min_confidences) / len(min_confidences) if min_confidences else 0.0 |
| |
|
| | return IterationMetrics( |
| | accuracy=accuracy, |
| | entropy=entropy, |
| | softmax_ce=softmax_ce, |
| | full_sequence_accuracy=full_seq_accuracy, |
| | min_sequence_confidence=min_seq_conf, |
| | ) |
| |
|
| | def _single_iteration( |
| | self, |
| | t: int, |
| | T: int, |
| | soft_embeds: torch.Tensor, |
| | base_embeds: torch.Tensor, |
| | mask_pos: torch.Tensor, |
| | attention_mask: Optional[torch.Tensor], |
| | labels: Optional[torch.Tensor], |
| | compute_metrics: bool, |
| | position_ids: Optional[torch.Tensor] = None, |
| | **kwargs, |
| | ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[IterationMetrics]]: |
| | """ |
| | Execute a single iteration of recursive refinement. |
| | |
| | Args: |
| | t: Current iteration index (0 to T-1) |
| | T: Total number of iterations |
| | soft_embeds: Soft embeddings for mask positions |
| | base_embeds: Base token embeddings from input_ids |
| | mask_pos: Boolean mask of [MASK] positions (B, L) |
| | attention_mask: Attention mask for MLM |
| | labels: Target labels for loss computation |
| | compute_metrics: Whether to compute iteration metrics |
| | |
| | Returns: |
| | logits: Output logits from MLM (B, L, V) |
| | weighted_loss: Loss weighted by step_weight(t, T), or None if no labels |
| | metrics: IterationMetrics, or None if not requested |
| | """ |
| | |
| | inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds) |
| |
|
| | |
| | outputs = self.mlm( |
| | inputs_embeds=inputs_embeds, |
| | attention_mask=attention_mask, |
| | position_ids=position_ids, |
| | labels=labels, |
| | return_dict=True, |
| | **kwargs, |
| | ) |
| |
|
| | |
| | weighted_loss = outputs.loss |
| | if labels is not None: |
| | if weighted_loss is None: |
| | |
| | |
| | masked_logits = outputs.logits[mask_pos] |
| | masked_labels = labels[mask_pos] |
| | loss_fct = CrossEntropyLoss() |
| | weighted_loss = loss_fct(masked_logits, masked_labels) |
| | weighted_loss *= self.step_weight(t, T) |
| |
|
| | |
| | metrics = None |
| | if compute_metrics and labels is not None: |
| | metrics = self._compute_iteration_metrics(outputs.logits, labels) |
| |
|
| | return outputs.logits, weighted_loss, metrics |
| |
|
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        num_recursions: Optional[int] = None,
        compute_iteration_metrics: bool = False,
        use_recursion_checkpointing: Optional[bool] = None,
        # Deprecated single-iteration mode (see run_set_iteration below)
        prev_soft_embeds: Optional[torch.Tensor] = None,
        run_set_iteration: Optional[int] = None,
        # Convergence schedule overrides (None -> config value)
        schedule: Optional[str] = None,
        causal_strength: Optional[float] = None,
        # Effect strength overrides (None -> config value)
        temperature_max: Optional[float] = None,
        entropy_target_max: Optional[float] = None,
        entropy_floor_max: Optional[float] = None,
        smear_sigma_max: Optional[float] = None,
        noise_std_max: Optional[float] = None,
        iteration_rope_dim_fraction: Optional[float] = None,
        **kwargs,
    ) -> RecursiveMaskedLMOutput:
        """
        Forward with recursive refinement.

        Supports three modes:
        1. Checkpointed mode (default): Run all T recursions with gradient checkpointing.
           Gradients flow through the entire chain; activations recomputed during backward.
        2. Non-checkpointed mode (use_recursion_checkpointing=False): Store all activations.
           Faster backward but higher memory.
        3. Single-iteration mode (DEPRECATED - run_set_iteration is not None): Run only one
           iteration. Use use_recursion_checkpointing=True instead.

        Loss Weighting (config.loss_weight):
            "last_1": Only final iteration loss (enables learning convergence behavior)
            "last_2": Last 2 iterations
            "linear": All iterations, linearly weighted (default)
            "uniform": All iterations, uniformly weighted

        Recursion Checkpointing:
            use_recursion_checkpointing: Enable gradient checkpointing for iterations.
                True = checkpoint each iteration, recompute during backward (default).
                False = store all activations (higher memory, faster backward).

        Convergence Schedule Parameters:
            All schedule/effect parameters default to their config values if not specified.
            Pass explicit values to override config for this forward pass.

            schedule: "linear" or "causal" - controls when positions can converge
            causal_strength: How much faster early positions converge (causal only)
            temperature_max: Max temperature boost for uncertain positions
            entropy_target_max: Target entropy at progress=0 (two-sided, recommended)
            entropy_floor_max: Min entropy floor (one-sided)
            smear_sigma_max: Max Gaussian sigma for position smearing
            noise_std_max: Max std of Gaussian noise on logits
            iteration_rope_dim_fraction: Fraction of dims for iteration RoPE

        Note:
            In modes 1/2 the returned loss is None even when labels are given;
            per-iteration logits are exposed via all_logits (training only),
            presumably so the caller computes the weighted loss externally —
            verify against the training loop.
        """
        B, L = input_ids.shape
        V = self.embed_weight.shape[0]
        mask_id = self.config.mask_token_id

        if mask_id is None:
            raise ValueError("mask_token_id must be set")

        # Resolve checkpointing flag (None -> config default).
        use_recursion_checkpointing = (
            use_recursion_checkpointing
            if use_recursion_checkpointing is not None
            else self.config.use_recursion_checkpointing
        )

        mask_pos = (input_ids == mask_id)
        base_embeds = self.get_input_embeddings()(input_ids)
        T = num_recursions or self.config.num_recursions
        # Normalizer for step-weighted losses; only consumed by the deprecated
        # single-iteration path below.
        weight_sum = sum(self.step_weight(i, T) for i in range(T))

        # Bundle schedule/effect overrides for _compute_next_soft_embeds.
        schedule_kwargs = dict(
            schedule=schedule,
            causal_strength=causal_strength,
            temperature_max=temperature_max,
            entropy_target_max=entropy_target_max,
            entropy_floor_max=entropy_floor_max,
            smear_sigma_max=smear_sigma_max,
            noise_std_max=noise_std_max,
            iteration_rope_dim_fraction=iteration_rope_dim_fraction,
        )

        # ---- Mode 3: deprecated single-iteration path ----
        if run_set_iteration is not None:
            warnings.warn(
                "run_set_iteration is deprecated. Use use_recursion_checkpointing=True instead, "
                "which provides proper gradient flow through all iterations.",
                DeprecationWarning,
                stacklevel=2,
            )
            t = run_set_iteration

            if t == 0:
                # First iteration: initialize masked slots with the mean
                # embedding (uninformative prior) plus the [MASK] embedding.
                soft_embeds = base_embeds.clone()
                if mask_pos.any():
                    avg_embed = self.embed_weight.mean(dim=0)
                    mask_emb = self.embed_weight[mask_id]
                    soft_embeds[mask_pos] = avg_embed + mask_emb
            else:
                # Later iterations require the soft embeddings produced by the
                # previous call (next_soft_embeds in its output).
                if prev_soft_embeds is None:
                    raise ValueError(f"prev_soft_embeds must be provided for iteration {t}")
                soft_embeds = prev_soft_embeds

            logits, weighted_loss, metrics = self._single_iteration(
                t, T, soft_embeds, base_embeds, mask_pos,
                attention_mask, labels, compute_iteration_metrics,
                position_ids=position_ids, **kwargs
            )

            # Normalize the step-weighted loss so summing over all iterations
            # yields a weighted average.
            loss = weighted_loss / weight_sum if weighted_loss is not None else None

            # Prepare the next iteration's input (no grad), unless this was
            # the final iteration.
            next_soft_embeds = None
            if t < T - 1:
                next_soft_embeds = self._compute_next_soft_embeds(
                    logits, mask_pos, base_embeds,
                    iteration=t,
                    total_iterations=T,
                    **schedule_kwargs,
                )

            return RecursiveMaskedLMOutput(
                loss=loss,
                logits=logits,
                next_soft_embeds=next_soft_embeds,
                iteration_metrics={t: metrics} if metrics is not None else None,
            )

        # ---- Modes 1/2: full multi-iteration path ----
        embed_weight = self.embed_weight
        mask_emb = embed_weight[mask_id]

        # Temperature tensor forwarded to the checkpointable iteration fn.
        temperature = torch.tensor(
            self.config.temperature,
            device=input_ids.device,
            dtype=base_embeds.dtype,
        )

        if attention_mask is None:
            attention_mask = torch.ones(B, L, device=input_ids.device, dtype=base_embeds.dtype)

        # Initialize iteration-0 soft embeddings.
        soft_embeds = base_embeds.clone()
        flow_noise_embed = None
        flow_t_per_token = None

        if self.config.flow_matching_enabled and self.training and labels is not None and mask_pos.any():
            # Flow-matching init: start masked slots on a noise->target
            # interpolation instead of the uninformative mean embedding.
            num_masked = mask_pos.sum().item()
            V = embed_weight.shape[0]
            device = input_ids.device

            # Per-token interpolation time t in [0, 1].
            flow_t_per_token = self._sample_flow_matching_t(num_masked, device)

            # Random point on the probability simplex -> noise embedding.
            z = torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype)
            p_noise = F.softmax(z * self.config.flow_matching_noise_scale, dim=-1).to(base_embeds.dtype)
            flow_noise_embed = p_noise @ embed_weight

            # Ground-truth embedding for each masked position.
            target_ids = labels[mask_pos]
            target_embed = embed_weight[target_ids]

            # Linear interpolation: t=0 is pure noise, t=1 is the target.
            t_col = flow_t_per_token.unsqueeze(-1).to(base_embeds.dtype)
            interp_embed = (1 - t_col) * flow_noise_embed + t_col * target_embed

            # [MASK] embedding added either scaled by (1 - t) or at full
            # strength, per config.
            if self.config.flow_matching_mask_scale:
                soft_embeds[mask_pos] = interp_embed + (1 - t_col) * mask_emb
            else:
                soft_embeds[mask_pos] = interp_embed + mask_emb
        elif mask_pos.any():
            # Standard init: mean embedding + [MASK] embedding.
            avg_embed = embed_weight.mean(dim=0)
            soft_embeds[mask_pos] = avg_embed + mask_emb

        iteration_metrics = {} if compute_iteration_metrics and labels is not None else None

        # Run T refinement iterations; each produces logits and the soft
        # embeddings consumed by the next iteration.
        all_logits = []
        for t in range(T):
            if self.training and use_recursion_checkpointing:
                # Checkpointed: activations recomputed during backward.
                logits, soft_embeds = torch_checkpoint(
                    self._single_iteration_checkpointable,
                    soft_embeds,
                    base_embeds,
                    mask_pos,
                    attention_mask,
                    embed_weight,
                    mask_emb,
                    temperature,
                    position_ids,
                    use_reentrant=False,
                )
            else:
                # Non-checkpointed: all activations kept in memory.
                logits, soft_embeds = self._single_iteration_checkpointable(
                    soft_embeds,
                    base_embeds,
                    mask_pos,
                    attention_mask,
                    embed_weight,
                    mask_emb,
                    temperature,
                    position_ids,
                )
            all_logits.append(logits)

            # Per-iteration diagnostics (scalars only; no grad).
            if iteration_metrics is not None and labels is not None:
                with torch.no_grad():
                    iteration_metrics[t] = self._compute_iteration_metrics(logits, labels)

        # loss is intentionally None here (see docstring note); all_logits is
        # only kept during training to bound inference memory.
        return RecursiveMaskedLMOutput(
            loss=None,
            logits=logits,
            all_logits=all_logits if self.training else None,
            iteration_metrics=iteration_metrics or None,
            flow_noise_embed=flow_noise_embed,
            flow_t=flow_t_per_token,
        )
| |
|
| | @torch.no_grad() |
| | def _generate_flow_map( |
| | self, |
| | input_ids: torch.Tensor, |
| | attention_mask: Optional[torch.Tensor], |
| | position_ids: Optional[torch.Tensor], |
| | num_steps: int, |
| | ) -> torch.Tensor: |
| | """Fill in mask positions using the CFM flow map update rule. |
| | |
| | Starts from a random point on the probability simplex and iteratively |
| | moves toward the model's predictions using the flow map step rule. |
| | |
| | Args: |
| | input_ids: Input with [MASK] tokens at positions to fill |
| | attention_mask: Attention mask |
| | position_ids: Position IDs |
| | num_steps: Number of flow map steps (finer = better, 1 step = greedy) |
| | |
| | Returns: |
| | Tensor with [MASK] positions filled with predicted tokens |
| | """ |
| | mask_pos = (input_ids == self.config.mask_token_id) |
| | num_masked = mask_pos.sum().item() |
| |
|
| | if num_masked == 0: |
| | return input_ids.clone() |
| |
|
| | device = input_ids.device |
| | V = self.embed_weight.shape[0] |
| | embed_weight = self.embed_weight |
| | mask_emb = embed_weight[self.config.mask_token_id] |
| | base_embeds = self.get_input_embeddings()(input_ids) |
| |
|
| | |
| | noise_scale = self.config.flow_matching_noise_scale |
| | p = F.softmax(torch.randn(num_masked, V, device=device, dtype=base_embeds.dtype) * noise_scale, dim=-1).to(base_embeds.dtype) |
| |
|
| | times = torch.linspace(0, 1, num_steps + 1, device=device) |
| |
|
| | for i in range(num_steps): |
| | t_now = times[i] |
| | t_next = times[i + 1] |
| | step_size = (t_next - t_now) / (1 - t_now) |
| |
|
| | |
| | if self.config.flow_matching_mask_scale: |
| | mask_signal = (1 - t_now) * mask_emb |
| | else: |
| | mask_signal = mask_emb |
| |
|
| | |
| | embed = p @ embed_weight + mask_signal |
| |
|
| | soft_embeds = base_embeds.clone() |
| | soft_embeds[mask_pos] = embed |
| | inputs_embeds = torch.where(mask_pos.unsqueeze(-1), soft_embeds, base_embeds) |
| |
|
| | outputs = self.mlm( |
| | inputs_embeds=inputs_embeds, |
| | attention_mask=attention_mask, |
| | position_ids=position_ids, |
| | return_dict=True, |
| | ) |
| | pi = F.softmax(outputs.logits[mask_pos], dim=-1).to(p.dtype) |
| |
|
| | |
| | p = p + step_size * (pi - p) |
| |
|
| | |
| | p = p.clamp(min=0) |
| | p = p / p.sum(dim=-1, keepdim=True) |
| |
|
| | result = input_ids.clone() |
| | result[mask_pos] = p.argmax(dim=-1) |
| | return result |
| |
|
| | @torch.no_grad() |
| | def generate( |
| | self, |
| | input_ids: torch.Tensor, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | position_ids: Optional[torch.Tensor] = None, |
| | num_recursions: Optional[int] = None, |
| | |
| | schedule: Optional[str] = None, |
| | causal_strength: Optional[float] = None, |
| | |
| | temperature_max: Optional[float] = None, |
| | entropy_target_max: Optional[float] = None, |
| | entropy_floor_max: Optional[float] = None, |
| | smear_sigma_max: Optional[float] = None, |
| | noise_std_max: Optional[float] = None, |
| | iteration_rope_dim_fraction: Optional[float] = None, |
| | ) -> torch.Tensor: |
| | """Fill in mask positions via iterative refinement. |
| | |
| | When flow_matching_enabled, uses the CFM flow map update rule. |
| | Otherwise, uses standard recursive soft-token refinement. |
| | |
| | Args: |
| | input_ids: Input token IDs with [MASK] tokens at positions to fill |
| | attention_mask: Attention mask |
| | num_recursions: Override number of recursions/steps (default: config value) |
| | schedule: "linear" or "causal" convergence schedule |
| | causal_strength: How much faster early positions converge (causal only) |
| | temperature_max: Max temperature boost for uncertain positions |
| | entropy_target_max: Target entropy at progress=0 (two-sided) |
| | entropy_floor_max: Min entropy floor (one-sided) |
| | smear_sigma_max: Max Gaussian sigma for position smearing |
| | noise_std_max: Max std of Gaussian noise on logits |
| | iteration_rope_dim_fraction: Fraction of dims for iteration RoPE |
| | |
| | Returns: |
| | Tensor with [MASK] positions filled with predicted tokens |
| | """ |
| | num_steps = num_recursions or self.config.num_recursions |
| |
|
| | if self.config.flow_matching_enabled: |
| | return self._generate_flow_map( |
| | input_ids, attention_mask, position_ids, num_steps |
| | ) |
| |
|
| | out = self.forward( |
| | input_ids, |
| | attention_mask, |
| | position_ids=position_ids, |
| | num_recursions=num_steps, |
| | schedule=schedule, |
| | causal_strength=causal_strength, |
| | temperature_max=temperature_max, |
| | entropy_target_max=entropy_target_max, |
| | entropy_floor_max=entropy_floor_max, |
| | smear_sigma_max=smear_sigma_max, |
| | noise_std_max=noise_std_max, |
| | iteration_rope_dim_fraction=iteration_rope_dim_fraction, |
| | ) |
| | result = input_ids.clone() |
| | mask_pos = (input_ids == self.config.mask_token_id) |
| | result[mask_pos] = out.logits.argmax(dim=-1)[mask_pos] |
| | return result |
| |
|