Create lexical_atlas.py
Browse files
experiments/exp_007_aleph_routed_attention/lexical_atlas.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# lexical_atlas.py
|
| 2 |
+
"""
|
| 3 |
+
Lexical Atlas β the full wordnet-lexical-topology vocabulary on the sphere
|
| 4 |
+
===========================================================================
|
| 5 |
+
|
| 6 |
+
Extracts the ENTIRE AbstractPhil/wordnet-lexical-topology setup (~12.8M
|
| 7 |
+
n-grams across {nltk, hf, unicode} x {char, word} x {1..5}gram configs) into
|
| 8 |
+
spherical coordinates, correctly spaced β where "correct" is determined by
|
| 9 |
+
capacity mathematics, not hope.
|
| 10 |
+
|
| 11 |
+
THE CAPACITY LAW (computed exactly, 2026-06-09):
|
| 12 |
+
12.8M uniformly spaced points on S^(D-1), median nearest-neighbor angle:
|
| 13 |
+
D=4 : 0.363 deg -> 0.06 logits of address contrast at tau=0.1
|
| 14 |
+
(neighbors indistinguishable through K=64; fp16
|
| 15 |
+
cannot resolve the cosines, fp32 marginal)
|
| 16 |
+
D=32: 39.1 deg | D=48: 47.6 deg -> 7-8 logits, comfortable
|
| 17 |
+
The CM-band result (band-valid D=32-112, sweet spot 32-56) independently
|
| 18 |
+
prescribes the same range. THEREFORE the atlas is TWO-TIER:
|
| 19 |
+
|
| 20 |
+
TIER 1 (base) : deterministic low-discrepancy placement at band-valid D
|
| 21 |
+
(default 48) β scrambled-Sobol -> Gaussian -> normalize.
|
| 22 |
+
Uniform by construction, reproducible by seed, unique
|
| 23 |
+
per n-gram. This is "spaced on the sphere correctly."
|
| 24 |
+
TIER 2 (view) : the LEARNED D_addr=4 address-space view extracted from a
|
| 25 |
+
trained AlephLM checkpoint β per n-gram: bytes -> pad ->
|
| 26 |
+
trigrams -> kappa rows (W_kappa o byte_emb) -> mean ->
|
| 27 |
+
normalize. This is where the model actually PLACED the
|
| 28 |
+
vocabulary; crowded by necessity (see law), meaningful
|
| 29 |
+
as geometry-of-content, not as unique identity.
|
| 30 |
+
|
| 31 |
+
Honesty on the learned view: mean composition is order-insensitive, so
|
| 32 |
+
anagrammatic n-grams (same trigram multiset) COLLIDE; collisions are counted
|
| 33 |
+
and reported per config. The deterministic tier never collides.
|
| 34 |
+
|
| 35 |
+
Per-config outputs:
|
| 36 |
+
atlas/{config}.parquet columns: ngram, rank, frequency, n_tri,
|
| 37 |
+
vec_base (D_base floats), vec_view (4 floats)
|
| 38 |
+
atlas/{config}.stats.json NN-angle distribution (sampled), statute of
|
| 39 |
+
both tiers (4k subsample), collision count
|
| 40 |
+
|
| 41 |
+
Usage:
|
| 42 |
+
python lexical_atlas.py --checkpoint aleph_lm_hybrid_corpus.pt \\
|
| 43 |
+
--configs char_eng_unigram char_eng_2gram char_eng_3gram \\
|
| 44 |
+
char_eng_4gram char_eng_5gram --d-base 48
|
| 45 |
+
# --configs all -> every config in the dataset (~12.8M rows total)
|
| 46 |
+
|
| 47 |
+
Depends: aleph_lm.py (+ its deps), pyarrow, huggingface_hub.
|
| 48 |
+
Author: AbstractPhil + Mirel Date: 2026-06-09 License: MIT
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
from __future__ import annotations
|
| 52 |
+
|
| 53 |
+
import json
|
| 54 |
+
import math
|
| 55 |
+
import os
|
| 56 |
+
from dataclasses import dataclass, field
|
| 57 |
+
from typing import Dict, List, Optional, Tuple
|
| 58 |
+
|
| 59 |
+
import numpy as np
|
| 60 |
+
import torch
|
| 61 |
+
import torch.nn.functional as F
|
| 62 |
+
from torch import Tensor
|
| 63 |
+
|
| 64 |
+
DATASET = "AbstractPhil/wordnet-lexical-topology"
|
| 65 |
+
PAD_BYTE = 0x00 # reserved pad symbol (documented, learned slot)
|
| 66 |
+
|
| 67 |
+
_GRAMS = ("unigram", "2gram", "3gram", "4gram", "5gram")
|
| 68 |
+
SOURCE_CONFIGS = ([f"nltk_{k}_eng_{n}" for k in ("char", "word") for n in _GRAMS]
|
| 69 |
+
+ [f"hf_{k}_eng_{n}" for k in ("char", "word") for n in _GRAMS]
|
| 70 |
+
+ [f"unicode_global_{n}" for n in _GRAMS])
|
| 71 |
+
LEGACY_CONFIGS = [f"{k}_eng_{n}" for k in ("char", "word") for n in _GRAMS]
|
| 72 |
+
# legacy unprefixed configs are pre-merged ANCESTORS (verified: char_eng_3gram
|
| 73 |
+
# is a superset of nltk_char_eng_3gram, freq corr 0.914) β excluded from 'all'
|
| 74 |
+
# to avoid double counting; available explicitly.
|
| 75 |
+
ALL_CONFIGS = SOURCE_CONFIGS
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def source_of(config: str) -> str:
|
| 79 |
+
for s in ("nltk", "hf", "unicode"):
|
| 80 |
+
if config.startswith(s):
|
| 81 |
+
return s
|
| 82 |
+
return "legacy"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 86 |
+
# Config
|
| 87 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 88 |
+
|
| 89 |
+
@dataclass
|
| 90 |
+
class AtlasConfig:
|
| 91 |
+
checkpoint: Optional[str] = None # AlephLM ckpt (None = base tier only)
|
| 92 |
+
configs: List[str] = field(default_factory=lambda: [
|
| 93 |
+
"char_eng_unigram", "char_eng_2gram", "char_eng_3gram",
|
| 94 |
+
"char_eng_4gram", "char_eng_5gram"])
|
| 95 |
+
d_base: int = 48 # band-valid (CM sweet spot 32-56)
|
| 96 |
+
base_seed: int = 1234 # determinism of Tier 1
|
| 97 |
+
out_dir: str = "atlas"
|
| 98 |
+
batch: int = 65536
|
| 99 |
+
max_tri: int = 16 # n-grams longer than 48 bytes truncated
|
| 100 |
+
stats_sample: int = 4000
|
| 101 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββββββββββββββββββββββββββ
|
| 105 |
+
# Tier 1 β deterministic band-valid base (correct spacing by construction)
|
| 106 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 107 |
+
|
| 108 |
+
class SobolSphere:
|
| 109 |
+
"""Low-discrepancy points on S^(D-1): scrambled Sobol -> inverse-normal ->
|
| 110 |
+
normalize. Deterministic per (seed, global index): the same n-gram (by its
|
| 111 |
+
global rank position) always receives the same point. Never collides."""
|
| 112 |
+
|
| 113 |
+
def __init__(self, D: int, seed: int):
|
| 114 |
+
self.D, self.seed = D, seed
|
| 115 |
+
self.eng = torch.quasirandom.SobolEngine(D, scramble=True, seed=seed)
|
| 116 |
+
self._cursor = 0
|
| 117 |
+
|
| 118 |
+
def take(self, n: int) -> Tensor:
|
| 119 |
+
u = self.eng.draw(n).clamp(1e-6, 1 - 1e-6)
|
| 120 |
+
g = torch.erfinv(2 * u - 1) * math.sqrt(2.0) # inverse normal CDF
|
| 121 |
+
self._cursor += n
|
| 122 |
+
return F.normalize(g, dim=-1)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
+
# Tier 2 β learned address-space view (the model's own placement)
|
| 127 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 128 |
+
|
| 129 |
+
class LearnedView:
|
| 130 |
+
"""kappa-row composer from a trained AlephLM checkpoint."""
|
| 131 |
+
|
| 132 |
+
def __init__(self, checkpoint: str, device: str):
|
| 133 |
+
from aleph_lm import AlephLM, AlephLMConfig
|
| 134 |
+
d = torch.load(checkpoint, map_location=device, weights_only=False)
|
| 135 |
+
fields = AlephLMConfig.__dataclass_fields__
|
| 136 |
+
cfg = AlephLMConfig(**{k: v for k, v in d["config"].items() if k in fields})
|
| 137 |
+
bank = d.get("bank", None)
|
| 138 |
+
self.model = AlephLM(cfg, bank=bank).to(device)
|
| 139 |
+
self.model.load_state_dict(d["model_state_dict"])
|
| 140 |
+
self.model.eval()
|
| 141 |
+
self.cfg, self.device = cfg, device
|
| 142 |
+
|
| 143 |
+
@torch.no_grad()
|
| 144 |
+
def compose(self, tri: Tensor, n_tri: Tensor) -> Tensor:
|
| 145 |
+
"""tri: (B, T, 3) padded trigram bytes; n_tri: (B,) valid counts.
|
| 146 |
+
Returns (B, D_addr) unit rows: normalized mean of per-trigram
|
| 147 |
+
kappa rows over the valid prefix. Order-insensitive (collisions
|
| 148 |
+
among anagrams; counted upstream)."""
|
| 149 |
+
m = self.model
|
| 150 |
+
tri = tri.to(self.device)
|
| 151 |
+
e = sum(emb(tri[..., i]) for i, emb in enumerate(m.byte_emb)) # (B,T,d)
|
| 152 |
+
rows = F.normalize(m.W_kappa(e), dim=-1) # (B,T,Da)
|
| 153 |
+
mask = (torch.arange(tri.shape[1], device=self.device)[None, :]
|
| 154 |
+
< n_tri.to(self.device)[:, None]).float().unsqueeze(-1)
|
| 155 |
+
mean = (rows * mask).sum(1) / mask.sum(1).clamp_min(1e-9)
|
| 156 |
+
return F.normalize(mean, dim=-1).cpu()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def ngrams_to_trigrams(ngrams: List[str], max_tri: int
|
| 160 |
+
) -> Tuple[Tensor, Tensor, np.ndarray]:
|
| 161 |
+
"""UTF-8 encode, pad to multiple of 3 with PAD_BYTE, frame as trigrams.
|
| 162 |
+
Returns (B, max_tri, 3) bytes, (B,) counts, and the trigram-multiset hash
|
| 163 |
+
per n-gram (for anagram-collision counting)."""
|
| 164 |
+
B = len(ngrams)
|
| 165 |
+
out = np.zeros((B, max_tri, 3), dtype=np.int64)
|
| 166 |
+
counts = np.zeros(B, dtype=np.int64)
|
| 167 |
+
mhash = np.zeros(B, dtype=np.uint64)
|
| 168 |
+
for i, s in enumerate(ngrams):
|
| 169 |
+
b = str(s).encode("utf-8", errors="ignore")[: 3 * max_tri]
|
| 170 |
+
if len(b) % 3:
|
| 171 |
+
b = b + bytes([PAD_BYTE]) * (3 - len(b) % 3)
|
| 172 |
+
t = np.frombuffer(b, dtype=np.uint8).reshape(-1, 3).astype(np.int64)
|
| 173 |
+
n = len(t)
|
| 174 |
+
out[i, :n] = t
|
| 175 |
+
counts[i] = max(n, 1)
|
| 176 |
+
ids = (t[:, 0] * 65536 + t[:, 1] * 256 + t[:, 2]).astype(np.uint64)
|
| 177 |
+
h = np.uint64(0)
|
| 178 |
+
for v in np.sort(ids): # order-free multiset hash
|
| 179 |
+
h = (h * np.uint64(1099511628211)) ^ (v + np.uint64(0x9E3779B9))
|
| 180 |
+
mhash[i] = h ^ np.uint64(n)
|
| 181 |
+
return torch.from_numpy(out), torch.from_numpy(counts), mhash
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 185 |
+
# Spacing battery
|
| 186 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 187 |
+
|
| 188 |
+
def spacing_stats(vecs: Tensor, sample: int, seed: int = 0) -> Dict:
|
| 189 |
+
"""Sampled NN-angle distribution + statute on a subsample."""
|
| 190 |
+
g = torch.Generator().manual_seed(seed)
|
| 191 |
+
idx = torch.randperm(len(vecs), generator=g)[: min(sample, len(vecs))]
|
| 192 |
+
X = F.normalize(vecs[idx].float(), dim=-1)
|
| 193 |
+
cos = (X @ X.t()).clamp(-1, 1)
|
| 194 |
+
cos.fill_diagonal_(-1)
|
| 195 |
+
nn_deg = torch.acos(cos.max(dim=-1).values) * 180 / math.pi
|
| 196 |
+
st = statute(X)
|
| 197 |
+
return {"nn_deg_median": nn_deg.median().item(),
|
| 198 |
+
"nn_deg_p05": nn_deg.quantile(0.05).item(),
|
| 199 |
+
"nn_deg_p95": nn_deg.quantile(0.95).item(),
|
| 200 |
+
"statute": st}
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 204 |
+
# Extraction
|
| 205 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 206 |
+
|
| 207 |
+
def extract_config(name: str, cfg: AtlasConfig, sobol: SobolSphere,
|
| 208 |
+
view: Optional[LearnedView]) -> Dict:
|
| 209 |
+
import pyarrow as pa
|
| 210 |
+
import pyarrow.parquet as pq
|
| 211 |
+
from huggingface_hub import hf_hub_download
|
| 212 |
+
|
| 213 |
+
path = hf_hub_download(DATASET, f"data/{name}-00000-of-00001.parquet",
|
| 214 |
+
repo_type="dataset")
|
| 215 |
+
t = pq.read_table(path, columns=["ngram", "rank", "frequency"]) \
|
| 216 |
+
.to_pandas().sort_values("rank").reset_index(drop=True)
|
| 217 |
+
N = len(t)
|
| 218 |
+
print(f"[{name}] {N:,} n-grams")
|
| 219 |
+
|
| 220 |
+
base = sobol.take(N) # (N, D_base)
|
| 221 |
+
views, counts, hashes = [], [], []
|
| 222 |
+
if view is not None:
|
| 223 |
+
for s0 in range(0, N, cfg.batch):
|
| 224 |
+
chunk = t["ngram"].iloc[s0: s0 + cfg.batch].tolist()
|
| 225 |
+
tri, n_tri, mh = ngrams_to_trigrams(chunk, cfg.max_tri)
|
| 226 |
+
views.append(view.compose(tri, n_tri))
|
| 227 |
+
counts.append(n_tri)
|
| 228 |
+
hashes.append(mh)
|
| 229 |
+
vview = torch.cat(views)
|
| 230 |
+
n_tri = torch.cat(counts)
|
| 231 |
+
mh = np.concatenate(hashes)
|
| 232 |
+
n_coll = int(N - len(np.unique(mh)))
|
| 233 |
+
else:
|
| 234 |
+
vview, n_tri, n_coll = None, None, 0
|
| 235 |
+
|
| 236 |
+
os.makedirs(cfg.out_dir, exist_ok=True)
|
| 237 |
+
cols = {"ngram": pa.array(t["ngram"].astype(str)),
|
| 238 |
+
"rank": pa.array(t["rank"].astype("int64")),
|
| 239 |
+
"frequency": pa.array(t["frequency"].astype("float64")),
|
| 240 |
+
"vec_base": pa.array(base.numpy().tolist(),
|
| 241 |
+
type=pa.list_(pa.float32(), cfg.d_base))}
|
| 242 |
+
if vview is not None:
|
| 243 |
+
cols["n_tri"] = pa.array(n_tri.numpy().astype("int8"))
|
| 244 |
+
cols["vec_view"] = pa.array(vview.numpy().tolist(),
|
| 245 |
+
type=pa.list_(pa.float32(), vview.shape[1]))
|
| 246 |
+
out_path = os.path.join(cfg.out_dir, f"{name}.parquet")
|
| 247 |
+
pq.write_table(pa.table(cols), out_path)
|
| 248 |
+
|
| 249 |
+
stats = {"config": name, "n": N, "d_base": cfg.d_base,
|
| 250 |
+
"anagram_collisions_view": n_coll,
|
| 251 |
+
"base": spacing_stats(base, cfg.stats_sample)}
|
| 252 |
+
if vview is not None:
|
| 253 |
+
stats["view"] = spacing_stats(vview, cfg.stats_sample)
|
| 254 |
+
with open(os.path.join(cfg.out_dir, f"{name}.stats.json"), "w") as f:
|
| 255 |
+
json.dump(stats, f, indent=2, default=str)
|
| 256 |
+
print(f" base NN {stats['base']['nn_deg_median']:.2f} deg "
|
| 257 |
+
f"(statute {stats['base']['statute']['statute']})"
|
| 258 |
+
+ (f" view NN {stats['view']['nn_deg_median']:.3f} deg "
|
| 259 |
+
f"(statute {stats['view']['statute']['statute']}, "
|
| 260 |
+
f"collisions {n_coll})" if vview is not None else "")
|
| 261 |
+
+ f" -> {out_path}")
|
| 262 |
+
return stats
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def build_atlas(cfg: AtlasConfig) -> List[Dict]:
|
| 266 |
+
names = ALL_CONFIGS if cfg.configs == ["all"] else cfg.configs
|
| 267 |
+
sobol = SobolSphere(cfg.d_base, cfg.base_seed) # ONE stream:
|
| 268 |
+
view = LearnedView(cfg.checkpoint, cfg.device) if cfg.checkpoint else None
|
| 269 |
+
# global index = unique placement across ALL configs (never reused)
|
| 270 |
+
all_stats = []
|
| 271 |
+
for name in names:
|
| 272 |
+
all_stats.append(extract_config(name, cfg, sobol, view))
|
| 273 |
+
total = sum(s["n"] for s in all_stats)
|
| 274 |
+
print(f"\n[atlas] {total:,} n-grams placed at D={cfg.d_base} "
|
| 275 |
+
f"(Tier 1, deterministic, collision-free)"
|
| 276 |
+
+ (f" + learned D=4 view (Tier 2)" if view else ""))
|
| 277 |
+
return all_stats
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 283 |
+
# Canon β weighted dedupe across sources: ONE STRING, ONE POINT
|
| 284 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 285 |
+
# Cross-config duplicates of the same n-gram must not receive different
|
| 286 |
+
# Tier-1 placements. Canonization: per-config frequencies are normalized
|
| 287 |
+
# (sum to 1 within config β scale-free across sources), scaled by a
|
| 288 |
+
# per-source weight (HF elevated: frequency-weighted definitions with
|
| 289 |
+
# cardinality), summed per unique string, re-ranked, and placed once.
|
| 290 |
+
|
| 291 |
+
DEFAULT_SOURCE_WEIGHTS = {"hf": 5.0, "nltk": 1.0, "unicode": 1.0, "legacy": 0.0}
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def canonize(cfg: AtlasConfig,
|
| 295 |
+
source_weights: Optional[Dict[str, float]] = None,
|
| 296 |
+
configs: Optional[List[str]] = None) -> Dict:
|
| 297 |
+
"""Build the canonical deduplicated atlas directly from the dataset."""
|
| 298 |
+
import pandas as pd
|
| 299 |
+
import pyarrow as pa
|
| 300 |
+
import pyarrow.parquet as pq
|
| 301 |
+
from huggingface_hub import hf_hub_download
|
| 302 |
+
|
| 303 |
+
W = dict(DEFAULT_SOURCE_WEIGHTS)
|
| 304 |
+
if source_weights:
|
| 305 |
+
W.update(source_weights)
|
| 306 |
+
names = configs or SOURCE_CONFIGS
|
| 307 |
+
|
| 308 |
+
frames, prov = [], []
|
| 309 |
+
for name in names:
|
| 310 |
+
lam = W.get(source_of(name), 0.0)
|
| 311 |
+
if lam <= 0:
|
| 312 |
+
print(f"[canon] {name}: weight 0 β skipped")
|
| 313 |
+
continue
|
| 314 |
+
p = hf_hub_download(DATASET, f"data/{name}-00000-of-00001.parquet",
|
| 315 |
+
repo_type="dataset")
|
| 316 |
+
t = pq.read_table(p, columns=["ngram", "frequency"]).to_pandas()
|
| 317 |
+
t["ngram"] = t["ngram"].astype(str)
|
| 318 |
+
t["w"] = lam * t["frequency"] / max(t["frequency"].sum(), 1e-30)
|
| 319 |
+
t["src"] = source_of(name)
|
| 320 |
+
frames.append(t[["ngram", "w", "src"]])
|
| 321 |
+
print(f"[canon] {name}: {len(t):,} rows (lambda={lam})")
|
| 322 |
+
allrows = pd.concat(frames, ignore_index=True)
|
| 323 |
+
print(f"[canon] total rows {len(allrows):,}")
|
| 324 |
+
|
| 325 |
+
agg = allrows.groupby("ngram", sort=False).agg(
|
| 326 |
+
weight=("w", "sum"),
|
| 327 |
+
n_sources=("src", "nunique"),
|
| 328 |
+
sources=("src", lambda s: "+".join(sorted(set(s)))))
|
| 329 |
+
agg = agg.sort_values("weight", ascending=False).reset_index()
|
| 330 |
+
N = len(agg)
|
| 331 |
+
dup = len(allrows) - N
|
| 332 |
+
print(f"[canon] unique n-grams {N:,} (merged {dup:,} duplicate rows)")
|
| 333 |
+
|
| 334 |
+
# Tier 1: one fresh stream over the canonical ranking β one string, one point
|
| 335 |
+
sobol = SobolSphere(cfg.d_base, cfg.base_seed)
|
| 336 |
+
base = sobol.take(N)
|
| 337 |
+
|
| 338 |
+
# Tier 2: learned view regenerated per unique string (pure function)
|
| 339 |
+
view = LearnedView(cfg.checkpoint, cfg.device) if cfg.checkpoint else None
|
| 340 |
+
vview, n_tri_all, n_coll = None, None, 0
|
| 341 |
+
if view is not None:
|
| 342 |
+
views, counts, hashes = [], [], []
|
| 343 |
+
for s0 in range(0, N, cfg.batch):
|
| 344 |
+
chunk = agg["ngram"].iloc[s0: s0 + cfg.batch].tolist()
|
| 345 |
+
tri, n_tri, mh = ngrams_to_trigrams(chunk, cfg.max_tri)
|
| 346 |
+
views.append(view.compose(tri, n_tri))
|
| 347 |
+
counts.append(n_tri)
|
| 348 |
+
hashes.append(mh)
|
| 349 |
+
vview = torch.cat(views)
|
| 350 |
+
n_tri_all = torch.cat(counts)
|
| 351 |
+
mh = np.concatenate(hashes)
|
| 352 |
+
n_coll = int(N - len(np.unique(mh)))
|
| 353 |
+
|
| 354 |
+
os.makedirs(cfg.out_dir, exist_ok=True)
|
| 355 |
+
cols = {"ngram": pa.array(agg["ngram"]),
|
| 356 |
+
"weight": pa.array(agg["weight"].astype("float64")),
|
| 357 |
+
"rank": pa.array(np.arange(1, N + 1, dtype=np.int64)),
|
| 358 |
+
"n_sources": pa.array(agg["n_sources"].astype("int8")),
|
| 359 |
+
"sources": pa.array(agg["sources"]),
|
| 360 |
+
"vec_base": pa.array(base.numpy().tolist(),
|
| 361 |
+
type=pa.list_(pa.float32(), cfg.d_base))}
|
| 362 |
+
if vview is not None:
|
| 363 |
+
cols["n_tri"] = pa.array(n_tri_all.numpy().astype("int8"))
|
| 364 |
+
cols["vec_view"] = pa.array(vview.numpy().tolist(),
|
| 365 |
+
type=pa.list_(pa.float32(), vview.shape[1]))
|
| 366 |
+
out_path = os.path.join(cfg.out_dir, "canon.parquet")
|
| 367 |
+
pq.write_table(pa.table(cols), out_path)
|
| 368 |
+
|
| 369 |
+
stats = {"unique": N, "merged_duplicates": dup,
|
| 370 |
+
"source_weights": W, "configs": names,
|
| 371 |
+
"anagram_collisions_view": n_coll,
|
| 372 |
+
"base": spacing_stats(base, cfg.stats_sample)}
|
| 373 |
+
if vview is not None:
|
| 374 |
+
stats["view"] = spacing_stats(vview, cfg.stats_sample)
|
| 375 |
+
with open(os.path.join(cfg.out_dir, "canon.stats.json"), "w") as f:
|
| 376 |
+
json.dump(stats, f, indent=2, default=str)
|
| 377 |
+
print(f"[canon] -> {out_path} "
|
| 378 |
+
f"(base NN {stats['base']['nn_deg_median']:.2f} deg"
|
| 379 |
+
+ (f", view collisions {n_coll}" if vview is not None else "") + ")")
|
| 380 |
+
return stats
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 386 |
+
# Stratified bank β round-robin across the granularity ladder
|
| 387 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 388 |
+
# Breadth-first sampling: rank-1 of every config, then rank-2, ... with
|
| 389 |
+
# weighted dedupe along the way, until `target` unique n-grams. Yields a
|
| 390 |
+
# compact multi-granularity candidate vocabulary stratified across
|
| 391 |
+
# {source} x {char, word} x {1..5}gram. Two outputs:
|
| 392 |
+
# bank_{target}.parquet the full multi-granularity bank
|
| 393 |
+
# bank_{target}_tri.pt the 3-byte-exact subset as an (M, 3) tensor β
|
| 394 |
+
# a DROP-IN AlephLM trigram bank (only exact
|
| 395 |
+
# 3-byte strings can match raw next-trigram
|
| 396 |
+
# targets; variable-length candidates await the
|
| 397 |
+
# span-prediction head β v2, noted in log)
|
| 398 |
+
|
| 399 |
+
def stratified_bank(cfg: AtlasConfig, target: int = 4096,
|
| 400 |
+
source_weights: Optional[Dict[str, float]] = None,
|
| 401 |
+
configs: Optional[List[str]] = None) -> Dict:
|
| 402 |
+
import pandas as pd
|
| 403 |
+
import pyarrow as pa
|
| 404 |
+
import pyarrow.parquet as pq
|
| 405 |
+
from huggingface_hub import hf_hub_download
|
| 406 |
+
|
| 407 |
+
W = dict(DEFAULT_SOURCE_WEIGHTS)
|
| 408 |
+
if source_weights:
|
| 409 |
+
W.update(source_weights)
|
| 410 |
+
names = [n for n in (configs or SOURCE_CONFIGS)
|
| 411 |
+
if W.get(source_of(n), 0.0) > 0]
|
| 412 |
+
|
| 413 |
+
tables = []
|
| 414 |
+
for name in names:
|
| 415 |
+
p = hf_hub_download(DATASET, f"data/{name}-00000-of-00001.parquet",
|
| 416 |
+
repo_type="dataset")
|
| 417 |
+
t = pq.read_table(p, columns=["ngram", "rank", "frequency"]).to_pandas()
|
| 418 |
+
t["ngram"] = t["ngram"].astype(str)
|
| 419 |
+
lam = W[source_of(name)]
|
| 420 |
+
t["w"] = lam * t["frequency"] / max(t["frequency"].sum(), 1e-30)
|
| 421 |
+
t["config"] = name
|
| 422 |
+
tables.append(t.sort_values("rank").reset_index(drop=True))
|
| 423 |
+
|
| 424 |
+
chosen: Dict[str, Dict] = {}
|
| 425 |
+
depth = 0
|
| 426 |
+
while len(chosen) < target:
|
| 427 |
+
progressed = False
|
| 428 |
+
for t in tables:
|
| 429 |
+
if depth >= len(t):
|
| 430 |
+
continue
|
| 431 |
+
progressed = True
|
| 432 |
+
row = t.iloc[depth]
|
| 433 |
+
rec = chosen.get(row.ngram)
|
| 434 |
+
if rec is None:
|
| 435 |
+
chosen[row.ngram] = {"weight": row.w, "configs": {row.config},
|
| 436 |
+
"first_depth": depth}
|
| 437 |
+
else:
|
| 438 |
+
rec["weight"] += row.w
|
| 439 |
+
rec["configs"].add(row.config)
|
| 440 |
+
if len(chosen) >= target:
|
| 441 |
+
break
|
| 442 |
+
depth += 1
|
| 443 |
+
if not progressed:
|
| 444 |
+
break
|
| 445 |
+
rows = [{"ngram": k, "weight": v["weight"],
|
| 446 |
+
"n_configs": len(v["configs"]),
|
| 447 |
+
"configs": "+".join(sorted(v["configs"])),
|
| 448 |
+
"first_depth": v["first_depth"],
|
| 449 |
+
"n_bytes": len(k.encode("utf-8", errors="ignore"))}
|
| 450 |
+
for k, v in chosen.items()]
|
| 451 |
+
bank = pd.DataFrame(rows).sort_values(
|
| 452 |
+
["first_depth", "weight"], ascending=[True, False]).reset_index(drop=True)
|
| 453 |
+
|
| 454 |
+
os.makedirs(cfg.out_dir, exist_ok=True)
|
| 455 |
+
out_pq = os.path.join(cfg.out_dir, f"bank_{target}.parquet")
|
| 456 |
+
pq.write_table(pa.Table.from_pandas(bank, preserve_index=False), out_pq)
|
| 457 |
+
|
| 458 |
+
tri_rows = [list(k.encode("utf-8")) for k in bank["ngram"]
|
| 459 |
+
if len(k.encode("utf-8", errors="ignore")) == 3]
|
| 460 |
+
tri = torch.tensor(tri_rows, dtype=torch.long) if tri_rows else torch.empty(0, 3, dtype=torch.long)
|
| 461 |
+
out_pt = os.path.join(cfg.out_dir, f"bank_{target}_tri.pt")
|
| 462 |
+
torch.save({"bank": tri, "source": "stratified_atlas",
|
| 463 |
+
"target": target, "weights": W, "configs": names}, out_pt)
|
| 464 |
+
|
| 465 |
+
comp = bank.groupby(bank["configs"].str.split("+").str[0]).size().to_dict()
|
| 466 |
+
print(f"[bank] {len(bank):,} unique n-grams at round-robin depth {depth}"
|
| 467 |
+
f" (3-byte-exact: {len(tri):,} -> {out_pt})")
|
| 468 |
+
print(f"[bank] multi-config members: {(bank.n_configs > 1).sum():,}"
|
| 469 |
+
f" byte-length histogram: "
|
| 470 |
+
f"{bank.n_bytes.value_counts().sort_index().to_dict()}")
|
| 471 |
+
print(f"[bank] -> {out_pq}")
|
| 472 |
+
return {"n": len(bank), "depth": depth, "n_tri": len(tri),
|
| 473 |
+
"paths": [out_pq, out_pt]}
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 477 |
+
# Activation
|
| 478 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 479 |
+
|
| 480 |
+
if __name__ == "__main__":
|
| 481 |
+
import argparse
|
| 482 |
+
ap = argparse.ArgumentParser(description="Full lexical-topology atlas")
|
| 483 |
+
ap.add_argument("--checkpoint", default=None)
|
| 484 |
+
ap.add_argument("--configs", nargs="+",
|
| 485 |
+
default=["all"])#"char_eng_unigram", "char_eng_2gram",
|
| 486 |
+
#"char_eng_3gram", "char_eng_4gram",
|
| 487 |
+
#"char_eng_5gram"])
|
| 488 |
+
ap.add_argument("--d-base", type=int, default=48)
|
| 489 |
+
ap.add_argument("--out", default="atlas")
|
| 490 |
+
ap.add_argument("--device",
|
| 491 |
+
default="cuda" if torch.cuda.is_available() else "cpu")
|
| 492 |
+
ap.add_argument("--canon", action="store_true",
|
| 493 |
+
help="weighted dedupe across sources: one string, one point")
|
| 494 |
+
ap.add_argument("--weights", default="hf=5,nltk=1,unicode=1,legacy=0",
|
| 495 |
+
help="per-source lambdas, e.g. hf=5,nltk=1,unicode=1")
|
| 496 |
+
ap.add_argument("--bank", type=int, default=0,
|
| 497 |
+
help="build a stratified bank of this many unique n-grams")
|
| 498 |
+
args, _unknown = ap.parse_known_args()
|
| 499 |
+
acfg = AtlasConfig(checkpoint=args.checkpoint, configs=args.configs,
|
| 500 |
+
d_base=args.d_base, out_dir=args.out, device=args.device)
|
| 501 |
+
sw = {k: float(v) for k, v in
|
| 502 |
+
(kv.split("=") for kv in args.weights.split(","))}
|
| 503 |
+
if args.bank:
|
| 504 |
+
stratified_bank(acfg, target=args.bank, source_weights=sw,
|
| 505 |
+
configs=None if args.configs in (["all"], ["sources"])
|
| 506 |
+
else args.configs)
|
| 507 |
+
elif args.canon:
|
| 508 |
+
canonize(acfg, source_weights=sw,
|
| 509 |
+
configs=None if args.configs == ["all"] else
|
| 510 |
+
(SOURCE_CONFIGS if args.configs == ["sources"] else args.configs))
|
| 511 |
+
else:
|
| 512 |
+
build_atlas(acfg)
|