bbkdevops's picture
download
raw
13.3 kB
"""
TinyMind Omega — Core Layers (High-Efficiency Edition)
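Three building blocks:
1. GatedLinearAttention — O(n) kernel (linear) attention with a running-sum KV cache for inference
2. SelectiveSSM — selective state-space layer; the scan is currently a sequential O(n) reference loop (see parallel_scan), not yet a true parallel scan
3. KANFeedForward — Kolmogorov-Arnold-style feed-forward built from learnable univariate basis functions, parameter-efficient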
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .config import OmegaConfig
# ─── RMSNorm (roughly 30% faster than LayerNorm) ─────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-channel gain.

    The last dimension of the input is divided by its root mean square and
    multiplied by ``weight``. There is no mean subtraction and no bias.

    Args:
        dim: Size of the last (feature) dimension; ``weight`` has this shape.
        eps: Constant added to the mean square before the reciprocal square
            root so that an all-zero row does not divide by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / rms(x) * weight``, reducing over the last dimension."""
        # Squaring float16 values above ~255 exceeds the float16 maximum
        # (65504), which turns the mean into inf and the output into 0.
        # Reduce in float32 for the low-precision dtypes; float32/float64
        # inputs keep their own precision so their results are unchanged.
        low_precision = x.dtype in (torch.float16, torch.bfloat16)
        compute_dtype = torch.float32 if low_precision else x.dtype
        x_c = x.to(compute_dtype)
        inv_rms = torch.rsqrt(x_c.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back before applying the gain so the output dtype follows the
        # same promotion rule (input dtype vs. weight dtype) as before.
        return (x_c * inv_rms).to(x.dtype) * self.weight
# ─── RoPE ────────────────────────────────────────────────────────────────────
class RotaryEmbedding(nn.Module):
    """Lazily grown cache of rotary position-embedding cos/sin tables.

    Args:
        dim: Width of each table row; ``dim // 2`` frequencies are generated
            and duplicated to fill the row.
        max_seq: Number of positions to precompute at construction time.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, max_seq: int = 4096, theta: float = 10_000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)
        self._seq_len_cached = 0
        # Declared for type checkers only; the tables are created by _build.
        self._cos_cached: torch.Tensor
        self._sin_cached: torch.Tensor
        self._build(max_seq)

    def _build(self, seq_len: int):
        """(Re)compute the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self._seq_len_cached = seq_len
        inv_freq = self.inv_freq
        positions = torch.arange(seq_len, device=inv_freq.device).float()
        angles = torch.outer(positions, inv_freq)
        # Each angle appears twice so a row spans the full ``dim`` width.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("_cos_cached", table.cos(), persistent=False)
        self.register_buffer("_sin_cached", table.sin(), persistent=False)

    def forward(self, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the first ``seq_len`` rows of the cos and sin tables.

        When the request exceeds the cached length the tables are rebuilt at
        twice the requested size.
        """
        if self._seq_len_cached < seq_len:
            self._build(2 * seq_len)
        cos_rows = self._cos_cached[:seq_len]
        sin_rows = self._sin_cached[:seq_len]
        return cos_rows, sin_rows
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    With the last dimension split at ``size // 2`` into ``(a, b)``, the result
    is ``(-b, a)`` concatenated along that dimension.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)
def apply_rope(q: torch.Tensor, k: torch.Tensor,
               cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to a query/key pair.

    ``cos`` and ``sin`` receive new singleton axes at positions 0 and 2 so
    that 2-D ``(T, D)`` tables broadcast as ``(1, T, 1, D)`` against ``q`` and
    ``k``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos_b = cos[None, :, None, :]  # (1, T, 1, D)
    sin_b = sin[None, :, None, :]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Standard rotary form: t * cos + rotate_half(t) * sin.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    q_rot = _rotate(q)
    k_rot = _rotate(k)
    return q_rot, k_rot
# ─── 1. Gated Linear Attention + KV Cache ────────────────────────────────────
class GatedLinearAttention(nn.Module):
    """Causal linear (kernel) attention with an output gate and a running-sum cache.

    Computes ``phi(Q) · [phi(K)^T V] / phi(Q) · [phi(K)^T 1]`` with the feature
    map ``phi(x) = ELU(x) + 1``, which is strictly positive so the sums can be
    accumulated associatively. The cache holds the running sums ``S`` and ``z``
    rather than the full key/value sequence.

    Args:
        cfg: Configuration providing ``n_heads``, ``head_dim``, ``dim``,
            ``max_seq_len`` and ``rope_theta``.
    """

    def __init__(self, cfg: OmegaConfig):
        super().__init__()
        self.H = cfg.n_heads
        self.D = cfg.head_dim
        inner = cfg.n_heads * cfg.head_dim
        self.qkv = nn.Linear(cfg.dim, inner * 3, bias=False)
        self.gate = nn.Linear(cfg.dim, inner, bias=False)
        self.out = nn.Linear(inner, cfg.dim, bias=False)
        self.norm = RMSNorm(inner)
        self.rope = RotaryEmbedding(cfg.head_dim, cfg.max_seq_len, cfg.rope_theta)

    @staticmethod
    def phi(x: torch.Tensor) -> torch.Tensor:
        """Feature map ``ELU(x) + 1`` (strictly positive)."""
        return F.elu(x) + 1.0

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: dict | None = None,  # inference cache
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, dict | None]:
        """Run causal linear attention over ``x``.

        Args:
            x: Input of shape ``(B, T, dim)``.
            kv_cache: ``None`` for the stateless path. Otherwise a dict
                (possibly empty) with optional keys ``"S"`` of shape
                ``(B, H, D, D)``, ``"z"`` of shape ``(B, H, D)`` and
                ``"offset"`` (number of positions already consumed).
            mask: Accepted for interface compatibility; currently unused.

        Returns:
            ``(output, cache)``. ``output`` has shape ``(B, T, dim)``.
            ``cache`` is ``None`` when ``kv_cache`` was ``None``, otherwise a
            new dict with updated ``"S"``, ``"z"`` and ``"offset"``.
        """
        B, T, _ = x.shape
        H = self.H

        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = rearrange(q, "b t (h d) -> b t h d", h=H)
        k = rearrange(k, "b t (h d) -> b t h d", h=H)
        v = rearrange(v, "b t (h d) -> b t h d", h=H)
        g = torch.sigmoid(self.gate(x))  # (B, T, H*D)

        # Positions continue from wherever the cache left off. A cache that
        # lacks "offset" is treated as starting at position 0.
        offset = kv_cache.get("offset", 0) if kv_cache is not None else 0
        cos, sin = self.rope(offset + T)
        q, k = apply_rope(q, k, cos[offset:offset + T], sin[offset:offset + T])
        qk = self.phi(q)  # (B, T, H, D)
        kk = self.phi(k)

        # Causal prefix sums over the current chunk: position t only sees
        # keys/values at positions <= t.
        S_cum = torch.einsum("bthd,bthe->bthde", kk, v).cumsum(dim=1)  # (B,T,H,D,D)
        z_cum = kk.cumsum(dim=1)  # (B,T,H,D)

        if kv_cache is not None:
            # Seed the prefix sums with the state of all earlier chunks. Using
            # the same per-position sums as the stateless path keeps a
            # multi-token chunk causal, so chunked decoding reproduces the
            # full-sequence output. For T == 1 this equals a plain
            # running-sum update.
            S_prev = kv_cache.get("S")
            z_prev = kv_cache.get("z")
            if S_prev is not None:
                S_cum = S_cum + S_prev.unsqueeze(1)
            if z_prev is not None:
                z_cum = z_cum + z_prev.unsqueeze(1)
            if T > 0:
                # clone(): a slice is a view that would otherwise keep the
                # whole (B, T, H, D, D) prefix-sum tensor alive in the cache.
                kv_cache = {
                    "S": S_cum[:, -1].clone(),
                    "z": z_cum[:, -1].clone(),
                    "offset": offset + T,
                }

        out_t = torch.einsum("bthd,bthde->bthe", qk, S_cum)
        denom = torch.einsum("bthd,bthd->bth", qk, z_cum).clamp(min=1e-6).unsqueeze(-1)
        out_t = out_t / denom

        out = rearrange(out_t, "b t h d -> b t (h d)") * g
        out = self.norm(out)
        return self.out(out), kv_cache
# ─── 2. Selective SSM — Parallel Scan ────────────────────────────────────────
def parallel_scan(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Evaluate the linear recurrence ``h_t = A_t * h_{t-1} + B_t`` with ``h_{-1} = 0``.

    Despite the name, this is a sequential reference implementation: it runs
    one elementwise update per time step (T steps in total). It is kept as a
    numerically straightforward baseline until a dedicated scan kernel
    replaces it.

    Args:
        A: Per-step state decay, shape ``(B, T, D, S)``.
        B: Per-step input contribution, same shape as ``A``.

    Returns:
        Tensor of shape ``(B, T, D, S)`` holding ``h_t`` for every step.

    Raises:
        ValueError: If ``A`` and ``B`` do not have the same shape.
    """
    if A.shape != B.shape:
        raise ValueError(f"A and B must have the same shape, got {A.shape} and {B.shape}")
    B_, T, D, S = A.shape
    if T == 0:
        # torch.stack rejects an empty list; an empty sequence simply has no
        # states to return.
        return A.new_zeros(A.shape, dtype=torch.result_type(A, B))

    states = []
    h_t = torch.zeros(B_, D, S, device=A.device, dtype=A.dtype)
    for a_t, b_t in zip(A.unbind(dim=1), B.unbind(dim=1)):
        h_t = a_t * h_t + b_t
        states.append(h_t)
    return torch.stack(states, dim=1)  # (B, T, D, S)
class SelectiveSSM(nn.Module):
    """Mamba-style selective state-space layer.

    The step size ``dt`` and the ``B``/``C`` projections are computed from the
    input (hence "selective"). The recurrence is evaluated by
    ``parallel_scan`` over the whole chunk. With a cache, the recurrent state
    and the tail of the depthwise-conv input are carried between calls so
    chunked or step-wise decoding reproduces the full-sequence result.

    Args:
        cfg: Configuration providing ``dim``, ``ssm_expand``, ``ssm_d_state``
            and ``ssm_d_conv``.
    """

    def __init__(self, cfg: OmegaConfig):
        super().__init__()
        d = cfg.dim * cfg.ssm_expand
        self.d_inner = d
        self.d_state = cfg.ssm_d_state
        self.d_conv = cfg.ssm_d_conv
        self.in_proj = nn.Linear(cfg.dim, d * 2, bias=False)
        self.conv1d = nn.Conv1d(d, d, cfg.ssm_d_conv,
                                padding=cfg.ssm_d_conv - 1,
                                groups=d, bias=True)
        self.x_proj = nn.Linear(d, cfg.ssm_d_state * 2 + d, bias=False)
        self.dt_proj = nn.Linear(d, d, bias=True)
        # softplus(log(expm1(1))) == 1, so dt starts at 1 for a zero projection.
        nn.init.constant_(self.dt_proj.bias, math.log(math.expm1(1.0)))
        A = torch.arange(1, cfg.ssm_d_state + 1, dtype=torch.float32).repeat(d, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D_ = nn.Parameter(torch.ones(d))
        self.out_proj = nn.Linear(d, cfg.dim, bias=False)
        self.norm = RMSNorm(d)

    def forward(
        self,
        x: torch.Tensor,
        ssm_cache: dict | None = None,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, dict | None]:
        """Apply the selective SSM to ``x``.

        Args:
            x: Input of shape ``(B, T, dim)``.
            ssm_cache: ``None`` for the stateless path. Otherwise a dict
                (possibly empty) with optional keys ``"h"`` of shape
                ``(B, d_inner, d_state)`` and ``"conv"`` of shape
                ``(B, d_inner, d_conv - 1)``; missing entries are treated as
                zeros.
            mask: Accepted for interface compatibility; currently unused.

        Returns:
            ``(output, cache)``. ``output`` has shape ``(B, T, dim)``.
            ``cache`` is ``None`` when ``ssm_cache`` was ``None``, otherwise a
            new dict with the updated ``"h"`` and ``"conv"`` entries.
        """
        B, T, _ = x.shape
        x_in, z = self.in_proj(x).chunk(2, dim=-1)

        # Depthwise causal convolution.
        xc = rearrange(x_in, "b t d -> b d t")
        conv_state = None
        if ssm_cache is None:
            # Symmetric zero padding; keeping the first T outputs makes the
            # convolution causal.
            xc = self.conv1d(xc)[..., :T]
        else:
            # Prepend the last d_conv-1 inputs of the previous call instead of
            # zeros so the convolution sees the same window it would see in a
            # full-sequence pass.
            history = ssm_cache.get("conv")
            if history is None:
                history = xc.new_zeros(B, self.d_inner, self.d_conv - 1)
            window = torch.cat([history, xc], dim=-1)  # (B, d, d_conv-1+T)
            xc = F.conv1d(window, self.conv1d.weight, self.conv1d.bias,
                          groups=self.d_inner)  # (B, d, T)
            # The trailing d_conv-1 inputs become the history for the next
            # call; clone() so the cache does not pin the whole window.
            conv_state = window[..., T:].clone()
        xc = F.silu(rearrange(xc, "b d t -> b t d"))

        # Input-dependent ("selective") SSM parameters.
        bcd = self.x_proj(xc)  # (B, T, 2S + d)
        d_s = self.d_state
        B_s = bcd[..., :d_s]  # (B, T, S)
        C_s = bcd[..., d_s:2 * d_s]  # (B, T, S)
        dt = F.softplus(self.dt_proj(bcd[..., 2 * d_s:]))  # (B, T, d)
        A = -torch.exp(self.A_log.float())  # (d, S)

        dA = torch.exp(dt.unsqueeze(-1) * A)  # (B, T, d, S)
        B_in = dt.unsqueeze(-1) * B_s.unsqueeze(2) * xc.unsqueeze(-1)  # (B, T, d, S)

        # Zero-initialised recurrence over every token of the chunk.
        h_seq = parallel_scan(dA, B_in)  # (B, T, d, S)

        if ssm_cache is not None:
            h_prev = ssm_cache.get("h")
            if h_prev is not None:
                # The recurrence is linear, so a non-zero initial state adds
                # (prod_{r<=t} dA_r) * h_prev to each step. For T == 1 this is
                # the single-step update dA * h_prev + dB * x.
                h_seq = h_seq + torch.cumprod(dA, dim=1) * h_prev.unsqueeze(1)
            # clone(): avoid keeping the full (B, T, d, S) state history alive.
            ssm_cache = {"h": h_seq[:, -1].clone(), "conv": conv_state}

        y_out = (h_seq * C_s.unsqueeze(2)).sum(-1) + self.D_ * xc  # (B, T, d)
        y_out = y_out * F.silu(z)
        y_out = self.norm(y_out)
        return self.out_proj(y_out), ssm_cache
# ─── 3. KAN FeedForward (Parameter-Efficient) ────────────────────────────────
class KANLinear(nn.Module):
    """KAN-style linear layer: learnable univariate functions plus a residual path.

    Every (output, input) pair owns ``grid + order`` coefficients that weight a
    fixed set of basis functions of the clamped input. The result is summed
    with a SiLU-activated bias-free linear projection.

    Args:
        in_f: Number of input features.
        out_f: Number of output features.
        grid: Grid size; contributes to the number of basis functions.
        order: Spline order; contributes to the number of basis functions.
    """

    def __init__(self, in_f: int, out_f: int, grid: int = 5, order: int = 3):
        super().__init__()
        self.in_f, self.out_f = in_f, out_f
        self.grid, self.order = grid, order
        # Learnable coefficients, one vector per (output, input) pair.
        self.coeff = nn.Parameter(0.1 * torch.randn(out_f, in_f, grid + order))
        # Residual linear branch.
        self.base = nn.Linear(in_f, out_f, bias=False)
        nn.init.kaiming_uniform_(self.base.weight, a=math.sqrt(5))
        self.register_buffer("pts", torch.linspace(-1, 1, grid + 1), persistent=False)

    def bspline_basis(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the basis functions; ``(N, in_f) -> (N, in_f, grid + order)``.

        Inputs are clamped to ``[-1, 1]``. The basis consists of Gaussian
        bumps centred on evenly spaced points in ``[-1, 1]`` whose width is
        the spacing between neighbouring centres.
        """
        count = self.grid + self.order
        clipped = torch.clamp(x, min=-1, max=1)[..., None]
        centers = torch.linspace(-1, 1, count, device=clipped.device, dtype=clipped.dtype)
        width = 2.0 / max(count - 1, 1)
        scaled = (clipped - centers) / (width + 1e-6)
        return torch.exp(-scaled.pow(2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(..., in_f)`` to ``(..., out_f)``."""
        lead_dims = x.shape[:-1]
        flat = x.reshape(-1, self.in_f)
        weights = self.bspline_basis(flat)  # (N, in_f, K)
        spline_out = torch.einsum("nig,oig->no", weights, self.coeff)  # (N, out_f)
        base_out = F.silu(self.base(flat))  # (N, out_f)
        combined = spline_out + base_out
        return combined.reshape(*lead_dims, self.out_f)
class KANFeedForward(nn.Module):
    """Gated feed-forward block with a KAN up-projection.

    - The up-projection is a ``KANLinear``.
    - The down-projection is a plain bias-free linear layer.
    - The KAN output is multiplied by a sigmoid gate computed from the same
      input, then normalised, passed through dropout and projected back.

    Args:
        cfg: Configuration providing ``dim``, ``ffn_mult``, ``kan_grid``,
            ``kan_order`` and ``dropout``.
    """

    def __init__(self, cfg: OmegaConfig):
        super().__init__()
        # Gated FFN: use 2/3 of the standard hidden width, then round up to
        # the next multiple of 64.
        raw_hidden = int(cfg.dim * cfg.ffn_mult * 2 / 3)
        hidden = -(-raw_hidden // 64) * 64
        self.kan_up = KANLinear(cfg.dim, hidden, grid=cfg.kan_grid, order=cfg.kan_order)
        self.gate = nn.Linear(cfg.dim, hidden, bias=False)
        self.down = nn.Linear(hidden, cfg.dim, bias=False)
        self.norm = RMSNorm(hidden)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down(drop(norm(KAN(x) * sigmoid(gate(x)))))``."""
        features = self.kan_up(x)
        gate_values = torch.sigmoid(self.gate(x))
        gated = features * gate_values
        return self.down(self.drop(self.norm(gated)))

Xet Storage Details

Size:
13.3 kB
·
Xet hash:
a655404dc1da5a2a2b3400f303b918011937b9548c42f6c7b703d6d69d5db215

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.