| from __future__ import annotations |
| from typing import Optional |
|
|
| from transformers import PretrainedConfig |
|
|
|
|
class RecursiveMLMConfig(PretrainedConfig):
    """
    Configuration for RecursiveMaskedLM.

    Stores the base MLM config plus recursive refinement parameters.

    Convergence Schedule System
    ---------------------------
    The convergence schedule controls WHEN each position is allowed to converge
    to a confident prediction during iterative refinement.

    Schedule types:
    - "linear": All positions converge at the same rate (iteration-based only)
    - "causal": Early positions converge first, late positions last

    Effects (mechanisms to enforce the schedule):
    - temperature_max: Raise temperature for positions not yet allowed to converge
    - entropy_target_max: Force exact entropy via bisection search (two-sided, recommended)
    - entropy_floor_max: Force minimum entropy (one-sided, only raises)
    - smear_sigma_max: Spread probability across neighboring positions
    - noise_std_max: Add Gaussian noise to logits
    - iteration_rope_dim_fraction: Apply rotary embedding based on iteration progress

    Soft Embedding Methods
    ----------------------
    Controls how logits are converted to soft embeddings for the next iteration:
    - "softmax": Standard softmax normalization (default). Creates sparse, probabilistic
      mixing but can cause gradient bottlenecks through the softmax Jacobian.
    - "l2_normalize": L2 normalize logits before mixing with embeddings. Removes the
      softmax bottleneck for smoother gradients through long recursion chains.
    - "none": No normalization - use raw logits directly. Warning: this can cause
      scale explosion without additional mechanisms like EMA accumulation.

    - soft_embedding_ema_step: Controls EMA blending with previous soft embeddings.
      1.0 (default) = full update (no EMA), 0.1 = slow update (90% previous + 10% new).
      Formula: new = (1 - ema_step) * prev + ema_step * current

    Recursion Checkpointing
    -----------------------
    Controls gradient flow through the entire recursion chain for memory-efficient training.

    Parameters:
    - use_recursion_checkpointing: Enable gradient checkpointing for iterations
    - loss_weight: Use "last_1" for final-iteration-only loss (learns convergence behavior)

    Flow Matching (CFM-inspired)
    ----------------------------
    Replaces the old temperature-based self-distillation with a Continuous Flow Matching
    framework. Training inputs are interpolated on the probability simplex between random
    noise and the target one-hot, distillation gives the student a noisier (earlier-time)
    version of the same interpolation path, and inference uses a flow map update rule.

    Parameters:
    - flow_matching_enabled: Enable the flow matching framework
    - flow_matching_lambda: Weight of distillation KL loss relative to CE loss
    - flow_matching_t_distribution: How to sample time t ("logit_normal" or "uniform")
    - flow_matching_t_logit_mean: Mean of logit-normal distribution (-0.4 biases toward noisy)
    - flow_matching_t_logit_std: Std of logit-normal distribution
    - flow_matching_t_min: Minimum time value (clamp)
    - flow_matching_t_max: Maximum time value (clamp)
    - flow_matching_noise_scale: Scale of the random noise endpoint of the
      interpolation path (default 2.0)
    - flow_matching_mask_scale: If True, scale mask_emb by (1-t); if False, binary mask signal

    Time levels are sampled independently per masked token. At t=0 the input is pure noise,
    at t=1 it is the clean target embedding.

    Self-Distillation (legacy, temperature-based)
    ---------------------------------------------
    Kept for backward compatibility. Ignored when flow_matching_enabled=True.

    Parameters:
    - self_distillation_enabled: Enable the self-distillation KL loss
    - self_distillation_lambda: Weight of distillation loss relative to CE loss
    - self_distillation_temperature_min: Minimum degradation temperature
    - self_distillation_temperature_max: Maximum degradation temperature
    - self_distillation_temperature_distribution: How to sample temperature
    - self_distillation_teacher: Which logits to use as teacher ("first" or "last")
    """
    # Registered model type string used by the transformers Auto* machinery.
    model_type = "recursive-mlm"

    def __init__(
        self,
        # --- Base model / core recursion settings ---
        base_model_config: Optional[dict] = None,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # --- Convergence schedule ---
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # --- Schedule-enforcement effects (0.0 disables each effect) ---
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # --- Soft embedding construction ---
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # --- Flow matching (supersedes self-distillation when enabled) ---
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_noise_scale: float = 2.0,
        flow_matching_mask_scale: bool = False,
        # --- Legacy self-distillation (ignored if flow matching is on) ---
        self_distillation_enabled: bool = False,
        self_distillation_lambda: float = 0.5,
        self_distillation_temperature_min: float = 1.5,
        self_distillation_temperature_max: float = 10.0,
        self_distillation_temperature_distribution: str = "log_uniform",
        self_distillation_teacher: str = "first",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Base model / core recursion settings.
        self.base_model_config = base_model_config
        self.num_recursions = num_recursions
        self.normalization = normalization
        self.loss_weight = loss_weight
        self.mask_token_id = mask_token_id
        self.temperature = temperature
        self.gradient_steps = gradient_steps

        # Convergence schedule (see class docstring: "Convergence Schedule System").
        self.schedule = schedule
        self.causal_strength = causal_strength

        # Schedule-enforcement effects; each is disabled at 0.0.
        self.temperature_max = temperature_max
        self.entropy_target_max = entropy_target_max
        self.entropy_floor_max = entropy_floor_max
        self.smear_sigma_max = smear_sigma_max
        self.noise_std_max = noise_std_max
        self.iteration_rope_dim_fraction = iteration_rope_dim_fraction

        self.use_recursion_checkpointing = use_recursion_checkpointing

        # Soft embedding construction for the next iteration's input.
        self.soft_embedding_method = soft_embedding_method
        self.soft_embedding_ema_step = soft_embedding_ema_step

        # Flow matching (CFM-inspired) training settings.
        self.flow_matching_enabled = flow_matching_enabled
        self.flow_matching_lambda = flow_matching_lambda
        self.flow_matching_t_distribution = flow_matching_t_distribution
        self.flow_matching_t_logit_mean = flow_matching_t_logit_mean
        self.flow_matching_t_logit_std = flow_matching_t_logit_std
        self.flow_matching_t_min = flow_matching_t_min
        self.flow_matching_t_max = flow_matching_t_max
        self.flow_matching_noise_scale = flow_matching_noise_scale
        self.flow_matching_mask_scale = flow_matching_mask_scale

        # Legacy temperature-based self-distillation; ignored when
        # flow_matching_enabled is True (enforced by the model, not here).
        self.self_distillation_enabled = self_distillation_enabled
        self.self_distillation_lambda = self_distillation_lambda
        self.self_distillation_temperature_min = self_distillation_temperature_min
        self.self_distillation_temperature_max = self_distillation_temperature_max
        self.self_distillation_temperature_distribution = self_distillation_temperature_distribution
        self.self_distillation_teacher = self_distillation_teacher

    @classmethod
    def from_base_model_config(
        cls,
        base_config: PretrainedConfig,
        **kwargs,
    ) -> "RecursiveMLMConfig":
        """Create a config from a base MLM's config.

        Args:
            base_config: The base model's ``PretrainedConfig``; serialized via
                ``to_dict()`` and stored as ``base_model_config``.
            **kwargs: Any ``RecursiveMLMConfig`` constructor overrides.

        Returns:
            A new ``RecursiveMLMConfig`` wrapping the base model config.
        """
        return cls(
            base_model_config=base_config.to_dict(),
            **kwargs,
        )
|
|