| | """ |
| | MonoidForCausalLM โ Causal Monoid Language Model (HuggingFace Compatible) |
| | MonoidForCausalLM โ ๅนบๅ็พคๅ ๆ่ฏญ่จๆจกๅ (ๅ
ผๅฎน HuggingFace) |
| | |
| | Architecture / ๆถๆๆฆ่ฆ: |
| | Replace softmax attention with a monoid parallel-scan recurrence. |
| | ็จๅนบๅ็พคๅนถ่กๆซๆ้ๆจๆฟไปฃ softmax ๆณจๆๅใ |
| | |
| | Core idea / ๆ ธๅฟๆๆณ: |
| | Softmax attention computes o_t = ฮฃ_{iโคt} softmax(q_tยทk_i) v_i |
| | โ requires O(T) KV-cache per layer at inference. |
| | Softmax ๆณจๆๅ่ฎก็ฎ o_t = ฮฃ_{iโคt} softmax(q_tยทk_i) v_i |
| | โ ๆจ็ๆถๆฏๅฑ้่ฆ O(T) ็ KV ็ผๅญใ |
| | |
| | Monoid attention compresses the entire causal history into a |
| | fixed-size state matrix S_t โ โ^{dรd} per head: |
| | S_t = diag(ฮฑ_t) ยท S_{t-1} + k_t โ v_t (vector decay recurrence) |
| | o_t = q_t ยท S_t (state readout) |
| | where ฮฑ_t โ โ^d is a per-dimension vector decay gate. |
| | ๅนบๅ็พคๆณจๆๅๅฐๅฎๆดๅ ๆๅๅฒๅ็ผฉๅฐๆฏไธชๅคดไธไธชๅบๅฎๅคงๅฐ็็ถๆ็ฉ้ต S_t: |
| | S_t = diag(ฮฑ_t) ยท S_{t-1} + k_t โ v_t (ๅ้่กฐๅ้ๆจ) |
| | o_t = q_t ยท S_t (็ถๆ่ฏปๅบ) |
| | ๅ
ถไธญ ฮฑ_t โ โ^d ๆฏ้็ปดๅบฆ็ๅ้่กฐๅ้จใ |
| | |
| | This is a monoid because the binary operator: |
| | (ฮฑ, S) โ (ฮฒ, X) = (ฮฑยทฮฒ, diag(ฮฒ)ยทS + X) |
| | is associative โ enables parallel prefix scan for training, |
| | and O(1) sequential update for inference. |
| | ่ฟๆฏไธไธชๅนบๅ็พค๏ผๅ ไธบไบๅ
็ฎๅญ: |
| | (ฮฑ, S) โ (ฮฒ, X) = (ฮฑยทฮฒ, diag(ฮฒ)ยทS + X) |
| | ๆปก่ถณ็ปๅๅพ โ ่ฎญ็ปๆถๅฏ็จๅนถ่กๅ็ผๆซๆ๏ผๆจ็ๆถ O(1) ้ๆญฅ้ๆจใ |
| | |
| | Key properties / ๅ
ณ้ฎ็นๆง: |
| | โ Explicit causal modeling โ ฮฑ_t gate explicitly controls how fast |
| | past information decays, making causality a first-class citizen. |
| | ๆพๅผๅ ๆๅปบๆจก โ ฮฑ_t ่กฐๅ้จๆพๅผๆงๅถๅๅฒไฟกๆฏ็้ๅฟ้็๏ผ |
| | ๅ ๆๆงๆฏไธ็ญๅ
ฌๆฐ่้้ mask ๆฝๅ ็็บฆๆใ |
| | |
| | โ Monoid state compression โ the full causal prefix x_{1:t} is |
| | lossily compressed into a fixed-size (dรd) state matrix per head. |
| | No O(T) KV-cache needed; inference is O(1) per token per layer. |
| | ๅนบๅ็พค็ถๆๅ็ผฉ โ ๅฎๆดๅ ๆๅ็ผ x_{1:t} ่ขซๆๆๅ็ผฉๅฐๆฏไธชๅคด |
| | ๅบๅฎๅคงๅฐ็ (dรd) ็ถๆ็ฉ้ตไธญใๆ ้ O(T) KV ็ผๅญ๏ผ |
| | ๆจ็ๆถๆฏๅฑๆฏ token O(1)ใ |
| | |
| | โ Parallel training โ associativity of โ enables O(T) parallel |
| | prefix scan (vs O(Tยฒ) for softmax attention). |
| | ๅนถ่ก่ฎญ็ป โ โ ็็ปๅๅพไฝฟ O(T) ๅนถ่กๅ็ผๆซๆๆไธบๅฏ่ฝ |
| | (ๅฏนๆฏ softmax ๆณจๆๅ็ O(Tยฒ))ใ |
| | |
| | Reuses LlamaMLP + LlamaRMSNorm from HuggingFace Transformers. |
| | ๅค็จ HuggingFace Transformers ็ LlamaMLP + LlamaRMSNormใ |
| | """ |

from __future__ import annotations

from typing import Optional, Union

import torch
import torch.nn as nn
from torch import Tensor

from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin, AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm

try:
    from monoid_scan_cuda import parallel_scan, parallel_scan_with_state
except ImportError:
    # CUDA kernel unavailable: fall back to a pure-PyTorch sequential scan.
|
    def parallel_scan(alpha: Tensor, kv: Tensor) -> Tensor:
        """Sequential prefix-scan fallback: S_t[i, :] = α_t[i]·S_{t-1}[i, :] + kv_t[i, :]."""
        B, H, T, d1, d2 = kv.shape
        states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
        S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
        for t in range(T):
            decay = alpha[:, :, t]
            while decay.dim() < S.dim():
                decay = decay.unsqueeze(-1)
            S = S * decay + kv[:, :, t]
            states[:, :, t] = S
        return states

    def parallel_scan_with_state(alpha: Tensor, kv: Tensor):
        """Sequential prefix scan that also returns the final (decay_acc, S) state."""
        B, H, T, d1, d2 = kv.shape
        states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
        S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
        decay_acc = torch.ones(B, H, d1, device=alpha.device, dtype=alpha.dtype)
        for t in range(T):
            decay = alpha[:, :, t]
            while decay.dim() < S.dim():
                decay = decay.unsqueeze(-1)
            S = S * decay + kv[:, :, t]
            states[:, :, t] = S
            decay_acc = decay_acc * alpha[:, :, t]
        return states, (decay_acc, S)
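

# A minimal sketch (illustrative only, not used by the model) of the composition
# property the scan relies on: scanning a sequence in one pass matches scanning two
# chunks and stitching them together with the second chunk's cumulative decay.
# Assumes the pure-PyTorch fallback above (or a kernel that accepts CPU tensors).
def _chunked_scan_sketch(T: int = 8, split: int = 3, d: int = 4) -> float:
    B, H = 1, 1
    alpha = torch.sigmoid(torch.randn(B, H, T, d))                 # per-dimension decay gates in (0, 1)
    kv = torch.randn(B, H, T, d, d)                                # per-step k_t ⊗ v_t outer products

    full = parallel_scan(alpha, kv)                                # scan the whole sequence at once

    left = parallel_scan(alpha[:, :, :split], kv[:, :, :split])    # scan chunk 1 from S = 0
    right = parallel_scan(alpha[:, :, split:], kv[:, :, split:])   # scan chunk 2 from S = 0
    cum = torch.cumprod(alpha[:, :, split:], dim=2).unsqueeze(-1)  # cumulative decay inside chunk 2
    stitched = right + cum * left[:, :, -1:]                       # carry chunk 1's final state forward

    return (full[:, :, split:] - stitched).abs().max().item()     # ~0 up to floating-point error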
|
class MonoidConfig(PretrainedConfig):
    """
    Configuration for the Monoid causal language model.

    Mirrors LlamaConfig for the shared components (MLP, RMSNorm, embedding)
    so that weights can be transferred directly from Llama checkpoints.
    """
    model_type = "monoid"

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 576,
        intermediate_size: int = 1536,
        num_hidden_layers: int = 30,
        num_attention_heads: int = 9,
        head_dim: int = 64,
        max_position_embeddings: int = 2048,
        rms_norm_eps: float = 1e-5,
        hidden_act: str = "silu",
        mlp_bias: bool = False,
        attention_bias: bool = False,
        tie_word_embeddings: bool = True,
        initializer_range: float = 0.041666666666666664,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.hidden_act = hidden_act
        self.mlp_bias = mlp_bias
        self.attention_bias = attention_bias
        self.initializer_range = initializer_range
|
class MonoidCache:
    """
    Per-layer monoid state cache for autoregressive inference.

    Unlike a Transformer KV-cache, which stores all past keys and values
    (O(T) memory), each layer here stores exactly ONE state tuple:
        (decay_acc, S)  where  S ∈ ℝ^{B×H×d×d}
    This is the monoid "sum" of all past (α_i, k_i ⊗ v_i) under ⊕.
    Memory is O(1) per layer regardless of sequence length.
    """

    def __init__(self):
        self.states: list[tuple[Tensor, Tensor] | None] = []
        self.seen_tokens: int = 0

    def get_seq_length(self, layer_idx: int = 0) -> int:
        return self.seen_tokens

    def update(self, layer_idx: int, state: tuple[Tensor, Tensor]):
        """Store the accumulated monoid state for a given layer."""
        while len(self.states) <= layer_idx:
            self.states.append(None)
        self.states[layer_idx] = state

    def get_state(self, layer_idx: int) -> tuple[Tensor, Tensor] | None:
        """Retrieve the accumulated monoid state for a given layer."""
        if layer_idx < len(self.states):
            return self.states[layer_idx]
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the cached states along the batch dimension for beam search."""
        for i, state in enumerate(self.states):
            if state is not None:
                log_d, kv = state
                self.states[i] = (log_d[beam_idx], kv[beam_idx])
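

# A minimal usage sketch (illustrative only, not part of the model): no matter how many
# tokens are folded in, the cache keeps a single fixed-size (decay_acc, S) tuple per layer.
# Shapes and the inline recurrence below mirror MonoidAttention's decode path.
def _monoid_cache_sketch(steps: int = 100, H: int = 2, d: int = 4) -> MonoidCache:
    cache = MonoidCache()
    decay_acc = torch.ones(1, H, d)
    S = torch.zeros(1, H, d, d)
    for _ in range(steps):
        alpha = torch.sigmoid(torch.randn(1, H, d))                       # decay gate α_t
        k, v = torch.randn(1, H, d), torch.randn(1, H, d)
        S = alpha.unsqueeze(-1) * S + torch.einsum('bhd, bhe -> bhde', k, v)  # S_t = diag(α_t)·S_{t-1} + k_t ⊗ v_t
        decay_acc = decay_acc * alpha
        cache.update(0, (decay_acc, S))                                   # overwrite, never append
        cache.seen_tokens += 1
    return cache  # len(cache.states) == 1 for layer 0, independent of `steps`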
|
def monoid_op(
    a: tuple[Tensor, Tensor],
    b: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]:
    """
    The monoid binary operator ⊕ on (vector decay, state matrix) pairs.

    Definition:
        (α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)
    where α, β ∈ (0,1)^d are per-dimension vector decay gates (sigmoid outputs).

    Why this is a monoid:
        • Associativity: (a ⊕ b) ⊕ c = a ⊕ (b ⊕ c).
          This enables a parallel prefix scan for training (reduction tree)
          and an O(1) left fold for inference (sequential append).
        • Identity: e = (1, 0), with e ⊕ a = a ⊕ e = a.

    Causal semantics:
        S_t = α_t · S_{t-1} + k_t ⊗ v_t
        The decay α_t ∈ (0,1) explicitly controls how much of the past the
        model retains. This is *explicit causal modeling*: the model must
        learn to balance retention vs. novelty at every timestep.
    """
    decay_a, kv_a = a
    decay_b, kv_b = b

    new_decay = decay_a * decay_b
    while decay_b.dim() < kv_a.dim():
        decay_b = decay_b.unsqueeze(-1)

    return new_decay, kv_a * decay_b + kv_b
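

# A small sketch (illustrative only) of why associativity matters: combining
# (α_t, k_t ⊗ v_t) elements with ⊕ gives the same result whether they are folded
# sequentially left-to-right (the inference order) or pairwise in a balanced tree
# (the shape of a parallel scan's reduction), which is what licenses O(T) training.
def _monoid_fold_sketch(T: int = 6, d: int = 4) -> float:
    elems = []
    for _ in range(T):
        alpha = torch.sigmoid(torch.randn(1, 1, d))                        # decay gate in (0, 1)
        k, v = torch.randn(1, 1, d), torch.randn(1, 1, d)
        elems.append((alpha, torch.einsum('bhd, bhe -> bhde', k, v)))      # (α_t, k_t ⊗ v_t)

    # Sequential left fold: (((e_1 ⊕ e_2) ⊕ e_3) ⊕ ...)
    left = elems[0]
    for e in elems[1:]:
        left = monoid_op(left, e)

    # Balanced fold: combine adjacent pairs repeatedly, preserving order.
    level = elems
    while len(level) > 1:
        nxt = [monoid_op(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            nxt.append(level[-1])
        level = nxt
    balanced = level[0]

    return (left[1] - balanced[1]).abs().max().item()  # ~0 up to floating-point error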
|
class MonoidAttention(nn.Module):
    """
    Monoid causal attention: replaces softmax attention entirely.

    Key differences from standard attention:
    • No RoPE / positional encoding: position is implicitly encoded by the
      causal decay gate α_t. The model learns *when* to forget rather than
      encoding *where* tokens are.
    • No KV-cache: replaced by a MonoidCache holding an O(1) state per layer.
      Each state S ∈ ℝ^{H×d×d} is a compressed summary of ALL past tokens.
    • No attention mask: causality is built into the recurrence itself.
      S_t depends only on S_{t-1} and the current token by construction.

    Computation:
        Training (parallel scan, O(T)):
            k_t = SiLU(k_proj(x_t))            # gated keys for a well-behaved state
            S_t = α_t · S_{t-1} + k_t ⊗ v_t    # monoid recurrence via prefix scan
            o_t = q_t · S_t                    # linear readout from the state

        Inference (RNN mode, O(1) per token):
            The same recurrence, applied one token at a time.
    """

    def __init__(self, config: MonoidConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.scaling = self.head_dim ** -0.5

        # Query / key / value / output projections (same shapes as standard multi-head attention).
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)

        # Output gate: a SiLU gate applied to the state readout before o_proj.
        self.gate_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        # Decay gate projection: produces the per-dimension decay α_t = sigmoid(decay_proj(x_t)).
        self.decay_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True)

        # Per-head RMS norms for queries, keys, and the state readout.
        self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.o_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Learnable initial state h0, shared across the batch: [1, H, d, d].
        self.h0 = nn.Parameter(torch.zeros(1, self.num_heads, self.head_dim, self.head_dim))
|
    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor | None = None,
        monoid_cache: MonoidCache | None = None,
        use_cache: bool = False,
    ) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
        """
        Args:
            hidden_states: [B, T, hidden_size]
            attention_mask: [B, T] with 1 = real token, 0 = padding.
                For PAD positions: α = 1 (preserve state), kv = 0 (no contribution).
            monoid_cache: O(1) state cache for inference.
            use_cache: whether to use/update the cache.

        Returns:
            output: [B, T, hidden_size]
            final_state: (decay_acc, S) or None
        """
        B, T, _ = hidden_states.shape
        H, d = self.num_heads, self.head_dim

        # Project to per-head queries, keys, values: [B, H, T, d].
        q = self.q_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2)

        # Output gate, applied to the state readout before o_proj.
        gate = torch.nn.functional.silu(self.gate_proj(hidden_states))

        # Normalize queries and keys per head; scale queries by 1/sqrt(d).
        q = self.q_norm(q) * self.scaling
        k = self.k_norm(k)

        # SiLU keys (predominantly non-negative) keep the accumulated state well-scaled.
        k = torch.nn.functional.silu(k)

        # Per-dimension decay gate α_t ∈ (0, 1): [B, H, T, d].
        raw = self.decay_proj(hidden_states)
        alpha = torch.sigmoid(raw)
        alpha = alpha.view(B, T, H, d).transpose(1, 2)

        # Padding positions: α = 1 (state passes through unchanged), k = v = 0 (no contribution).
        if attention_mask is not None:
            mask = attention_mask[:, None, :, None].to(alpha.dtype)
            alpha = alpha * mask + (1 - mask)
            k = k * mask
            v = v * mask

        # ---- Single-token decode (RNN mode, O(1) per token) ----
        if use_cache and T == 1:
            kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
            alpha_t = alpha[:, :, 0]

            prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
            if prev is None:
                # First token: decay the learnable initial state h0 and add the new outer product.
                decay_t = alpha_t
                while decay_t.dim() < self.h0.dim():
                    decay_t = decay_t.unsqueeze(-1)
                new_state = (alpha_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t)
            else:
                # Subsequent tokens: one monoid step, prev ⊕ (α_t, k_t ⊗ v_t).
                new_state = monoid_op(prev, (alpha_t, kv_t))

            if monoid_cache is not None:
                monoid_cache.update(self.layer_idx, new_state)

            # Readout: o_t = q_t · S_t.
            o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1])
            o = self.o_norm(o)
            o = o.contiguous().view(B, 1, -1)
            return self.o_proj(gate * o), new_state

        # ---- Prefill with cache (parallel scan + final state) ----
        if use_cache:
            kv = torch.einsum('bhtd, bhte -> bhtde', k, v)
            states, (decay_acc, S_T) = parallel_scan_with_state(alpha, kv)

            # Add the decayed learnable initial state h0 to every prefix state.
            cum_alpha = torch.exp(torch.cumsum(torch.log(alpha + 1e-8), dim=2))
            h0_decay = cum_alpha.unsqueeze(-1)
            states = states + h0_decay * self.h0.unsqueeze(2)

            # Fold h0 into the final state stored in the cache.
            total_h0_decay = decay_acc.unsqueeze(-1)
            S_final = S_T + total_h0_decay * self.h0.squeeze(0)
            final_state = (decay_acc, S_final)

            if monoid_cache is not None:
                monoid_cache.update(self.layer_idx, final_state)

            # Readout: o_t = q_t · S_t for every position.
            o = torch.einsum('bhtd, bhtde -> bhte', q, states)
            o = self.o_norm(o)
            o = o.transpose(1, 2).contiguous().view(B, T, -1)
            return self.o_proj(gate * o), final_state

        # ---- Training (parallel scan over the full sequence, no cache) ----
        kv = torch.einsum('bhtd, bhte -> bhtde', k, v)
        states = parallel_scan(alpha, kv)

        # Add the decayed learnable initial state h0 to every prefix state.
        cum_alpha = torch.exp(torch.cumsum(torch.log(alpha + 1e-8), dim=2))
        h0_decay = cum_alpha.unsqueeze(-1)
        states = states + h0_decay * self.h0.unsqueeze(2)

        # Readout: o_t = q_t · S_t.
        o = torch.einsum('bhtd, bhtde -> bhte', q, states)
        o = self.o_norm(o)
        o = o.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(gate * o), None
|
class MonoidDecoderLayer(nn.Module):
    """
    Pre-Norm Transformer block with Monoid attention.

    Data flow:
        x → RMSNorm → MonoidAttn → +residual → RMSNorm → LlamaMLP → +residual → out

    The MLP and RMSNorm are identical to Llama (weights transfer directly);
    only MonoidAttention is the novel component.
    """
    gradient_checkpointing = False

    def __init__(self, config: MonoidConfig, layer_idx: int):
        super().__init__()
        self.self_attn = MonoidAttention(config, layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor | None = None,
        monoid_cache: MonoidCache | None = None,
        use_cache: bool = False,
    ) -> Tensor:
        # Attention sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, attention_mask=attention_mask, monoid_cache=monoid_cache, use_cache=use_cache)
        hidden_states = residual + hidden_states

        # MLP sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
|
class MonoidPreTrainedModel(PreTrainedModel):
    config_class = MonoidConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MonoidDecoderLayer"]

    def _init_weights(self, module: nn.Module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

        if isinstance(module, MonoidAttention):
            # Bias the decay gate towards slow forgetting: sigmoid(3.0) ≈ 0.95.
            nn.init.constant_(module.decay_proj.bias, 3.0)
            # Small gate weights so the attention branch starts with a near-zero output gate.
            nn.init.normal_(module.gate_proj.weight, mean=0.0, std=0.01)
            # Identity-scale readout norm.
            nn.init.ones_(module.o_norm.weight)

| | class MonoidModel(MonoidPreTrainedModel): |
| | """ |
| | Stack of MonoidDecoderLayers with token embedding and final norm. |
| | ๅนบๅ็พค่งฃ็ ๅฑๅ ๅ , ๅธฆ token ๅตๅ
ฅๅๆ็ปๅฝไธๅใ |
| | |
| | Forward: embed_tokens โ N ร MonoidDecoderLayer โ final_norm |
| | ๅๅ: embed_tokens โ N ร MonoidDecoderLayer โ final_norm |
| | """ |
| |
|
| | def __init__(self, config: MonoidConfig): |
| | super().__init__(config) |
| | self.padding_idx = config.pad_token_id |
| | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
| | self.layers = nn.ModuleList( |
| | [MonoidDecoderLayer(config, i) for i in range(config.num_hidden_layers)] |
| | ) |
| | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| | self.gradient_checkpointing = False |
| | self.post_init() |
| |
|
| | def forward( |
| | self, |
| | input_ids: Tensor | None = None, |
| | attention_mask: Tensor | None = None, |
| | inputs_embeds: Tensor | None = None, |
| | monoid_cache: MonoidCache | None = None, |
| | use_cache: bool = False, |
| | ) -> BaseModelOutputWithPast: |
| | if inputs_embeds is None: |
| | inputs_embeds = self.embed_tokens(input_ids) |
| |
|
| | hidden_states = inputs_embeds |
| | for layer in self.layers: |
| | if self.gradient_checkpointing and self.training and not use_cache: |
| | hidden_states = self._gradient_checkpointing_func( |
| | layer.__call__, |
| | hidden_states, |
| | attention_mask, |
| | monoid_cache, |
| | use_cache, |
| | ) |
| | else: |
| | hidden_states = layer(hidden_states, attention_mask=attention_mask, monoid_cache=monoid_cache, use_cache=use_cache) |
| |
|
| | hidden_states = self.norm(hidden_states) |
| |
|
| | return BaseModelOutputWithPast( |
| | last_hidden_state=hidden_states, |
| | past_key_values=monoid_cache, |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
    """
    Monoid-based causal language model with LM head.

    The architecture in one sentence: "Llama body + Monoid mind".
    Reuse Llama's proven MLP/embeddings and replace attention with
    monoid state compression for O(1) inference.
    """
    _tied_weights_keys = ["lm_head.weight"]

    # Signal to generation utilities that this model carries recurrent state
    # rather than a standard KV-cache.
    _is_stateful = True

    def __init__(self, config: MonoidConfig):
        super().__init__(config)
        self.model = MonoidModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: Tensor,
        past_key_values=None,
        attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        **kwargs,
    ) -> dict:
        """
        Called by GenerationMixin at each decoding step.

        HuggingFace may pass a DynamicCache; we intercept and replace it,
        since this model uses MonoidCache instead of a standard KV-cache.
        """
        # Drop any non-Monoid cache handed over by the generation loop.
        if past_key_values is not None and not isinstance(past_key_values, MonoidCache):
            past_key_values = None

        if past_key_values is not None and past_key_values.seen_tokens > 0:
            # All past tokens are already folded into the monoid state:
            # feed only the newest token and drop the attention mask.
            input_ids = input_ids[:, -1:]
            attention_mask = None

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "monoid_cache": past_key_values,
            "use_cache": True,
        }
        return model_inputs

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        position_ids: Tensor | None = None,  # accepted for API compatibility; unused (no RoPE)
        past_key_values: MonoidCache | None = None,
        inputs_embeds: Tensor | None = None,
        labels: Tensor | None = None,
        use_cache: bool | None = None,
        monoid_cache: MonoidCache | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        logits_to_keep: int | Tensor = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        # Accept the cache under either name (HF passes past_key_values).
        cache = monoid_cache or past_key_values

        # Ignore any non-Monoid cache (e.g. a DynamicCache created by generate()).
        if cache is not None and not isinstance(cache, MonoidCache):
            cache = None

        if use_cache and cache is None:
            cache = MonoidCache()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            monoid_cache=cache,
            use_cache=bool(use_cache),
        )

        hidden_states = outputs.last_hidden_state

        # Optionally compute logits only for the last `logits_to_keep` positions.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        # Standard shifted cross-entropy loss.
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if cache is not None:
            cache.seen_tokens += (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1])

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=cache,
        )
|
AutoConfig.register("monoid", MonoidConfig)
AutoModelForCausalLM.register(MonoidConfig, MonoidForCausalLM)
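

# A quick sketch (illustrative only, not executed at import time) of what the
# registration above enables: building the model through the Auto* factories
# from a registered "monoid" config.
def _auto_factory_sketch() -> "MonoidForCausalLM":
    cfg = AutoConfig.for_model("monoid", num_hidden_layers=2, vocab_size=1000)
    return AutoModelForCausalLM.from_config(cfg)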
|
if __name__ == '__main__':
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print(f'Device: {device}')

    config = MonoidConfig(
        vocab_size=49152,
        hidden_size=576,
        intermediate_size=1536,
        num_hidden_layers=30,
        num_attention_heads=9,
        head_dim=64,
        rms_norm_eps=1e-5,
        hidden_act="silu",
        tie_word_embeddings=True,
    )
    model = MonoidForCausalLM(config).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Parameters: {n_params:,}')

    # Training-mode forward (parallel scan) on a dummy batch.
    B, T = 2, 64
    ids = torch.randint(0, config.vocab_size, (B, T), device=device)
    out = model(ids, labels=ids)
    print(f'Train → logits: {out.logits.shape}, loss: {out.loss:.4f}')

    # Inference: prefill the monoid cache over a prompt, then decode one token in O(1).
    prompt = torch.randint(0, config.vocab_size, (1, 8), device=device)
    cache = MonoidCache()
    prefill_out = model(prompt, use_cache=True, monoid_cache=cache)
    print(f'Prefill → logits: {prefill_out.logits.shape}, cache seen: {cache.seen_tokens}')

    next_tok = prefill_out.logits[:, -1:].argmax(dim=-1)
    step_out = model(next_tok, use_cache=True, monoid_cache=cache)
    print(f'Decode → logits: {step_out.logits.shape}, cache seen: {cache.seen_tokens}')
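
    # Consistency check (a sketch; small differences are expected from floating-point
    # accumulation): the O(1) cached decode of the 9th token should match running the
    # parallel scan over all 9 tokens at once.
    full_ids = torch.cat([prompt, next_tok], dim=1)
    full_out = model(full_ids)
    diff = (full_out.logits[:, -1] - step_out.logits[:, -1]).abs().max().item()
    print(f'Cache vs. parallel scan, max |Δlogits| = {diff:.2e}')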

    # Sanity check: the monoid operator is associative.
    print('\nMonoid associativity check:')
    a = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    b = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    c = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    ab_c = monoid_op(monoid_op(a, b), c)
    a_bc = monoid_op(a, monoid_op(b, c))
    err = (ab_c[1] - a_bc[1]).abs().max().item()
    print(f' |(a⊕b)⊕c - a⊕(b⊕c)| = {err:.2e}')
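
    # Generation through the standard HF API (a sketch; assumes GenerationMixin routes the
    # MonoidCache back in via past_key_values as outlined in prepare_inputs_for_generation).
    print('\nGreedy generation via GenerationMixin:')
    gen = model.generate(prompt, max_new_tokens=16, do_sample=False)
    print(f' prompt {tuple(prompt.shape)} → generated {tuple(gen.shape)}')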

    print('\nDone.')
|