AbstractPhil commited on
Commit
4e43091
Β·
verified Β·
1 Parent(s): 408da7e

Create lexical_atlas.py

Browse files
experiments/exp_007_aleph_routed_attention/lexical_atlas.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # lexical_atlas.py
2
+ """
3
+ Lexical Atlas β€” the full wordnet-lexical-topology vocabulary on the sphere
4
+ ===========================================================================
5
+
6
+ Extracts the ENTIRE AbstractPhil/wordnet-lexical-topology setup (~12.8M
7
+ n-grams across {nltk, hf, unicode} x {char, word} x {1..5}gram configs) into
8
+ spherical coordinates, correctly spaced β€” where "correct" is determined by
9
+ capacity mathematics, not hope.
10
+
11
+ THE CAPACITY LAW (computed exactly, 2026-06-09):
12
+ 12.8M uniformly spaced points on S^(D-1), median nearest-neighbor angle:
13
+ D=4 : 0.363 deg -> 0.06 logits of address contrast at tau=0.1
14
+ (neighbors indistinguishable through K=64; fp16
15
+ cannot resolve the cosines, fp32 marginal)
16
+ D=32: 39.1 deg | D=48: 47.6 deg -> 7-8 logits, comfortable
17
+ The CM-band result (band-valid D=32-112, sweet spot 32-56) independently
18
+ prescribes the same range. THEREFORE the atlas is TWO-TIER:
19
+
20
+ TIER 1 (base) : deterministic low-discrepancy placement at band-valid D
21
+ (default 48) β€” scrambled-Sobol -> Gaussian -> normalize.
22
+ Uniform by construction, reproducible by seed, unique
23
+ per n-gram. This is "spaced on the sphere correctly."
24
+ TIER 2 (view) : the LEARNED D_addr=4 address-space view extracted from a
25
+ trained AlephLM checkpoint β€” per n-gram: bytes -> pad ->
26
+ trigrams -> kappa rows (W_kappa o byte_emb) -> mean ->
27
+ normalize. This is where the model actually PLACED the
28
+ vocabulary; crowded by necessity (see law), meaningful
29
+ as geometry-of-content, not as unique identity.
30
+
31
+ Honesty on the learned view: mean composition is order-insensitive, so
32
+ anagrammatic n-grams (same trigram multiset) COLLIDE; collisions are counted
33
+ and reported per config. The deterministic tier never collides.
34
+
35
+ Per-config outputs:
36
+ atlas/{config}.parquet columns: ngram, rank, frequency, n_tri,
37
+ vec_base (D_base floats), vec_view (4 floats)
38
+ atlas/{config}.stats.json NN-angle distribution (sampled), statute of
39
+ both tiers (4k subsample), collision count
40
+
41
+ Usage:
42
+ python lexical_atlas.py --checkpoint aleph_lm_hybrid_corpus.pt \\
43
+ --configs char_eng_unigram char_eng_2gram char_eng_3gram \\
44
+ char_eng_4gram char_eng_5gram --d-base 48
45
+ # --configs all -> every config in the dataset (~12.8M rows total)
46
+
47
+ Depends: aleph_lm.py (+ its deps), pyarrow, huggingface_hub.
48
+ Author: AbstractPhil + Mirel Date: 2026-06-09 License: MIT
49
+ """
50
+
51
+ from __future__ import annotations
52
+
53
+ import json
54
+ import math
55
+ import os
56
+ from dataclasses import dataclass, field
57
+ from typing import Dict, List, Optional, Tuple
58
+
59
+ import numpy as np
60
+ import torch
61
+ import torch.nn.functional as F
62
+ from torch import Tensor
63
+
64
+ DATASET = "AbstractPhil/wordnet-lexical-topology"
65
+ PAD_BYTE = 0x00 # reserved pad symbol (documented, learned slot)
66
+
67
+ _GRAMS = ("unigram", "2gram", "3gram", "4gram", "5gram")
68
+ SOURCE_CONFIGS = ([f"nltk_{k}_eng_{n}" for k in ("char", "word") for n in _GRAMS]
69
+ + [f"hf_{k}_eng_{n}" for k in ("char", "word") for n in _GRAMS]
70
+ + [f"unicode_global_{n}" for n in _GRAMS])
71
+ LEGACY_CONFIGS = [f"{k}_eng_{n}" for k in ("char", "word") for n in _GRAMS]
72
+ # legacy unprefixed configs are pre-merged ANCESTORS (verified: char_eng_3gram
73
+ # is a superset of nltk_char_eng_3gram, freq corr 0.914) β€” excluded from 'all'
74
+ # to avoid double counting; available explicitly.
75
+ ALL_CONFIGS = SOURCE_CONFIGS
76
+
77
+
78
+ def source_of(config: str) -> str:
79
+ for s in ("nltk", "hf", "unicode"):
80
+ if config.startswith(s):
81
+ return s
82
+ return "legacy"
83
+
84
+
85
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
86
+ # Config
87
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
88
+
89
+ @dataclass
90
+ class AtlasConfig:
91
+ checkpoint: Optional[str] = None # AlephLM ckpt (None = base tier only)
92
+ configs: List[str] = field(default_factory=lambda: [
93
+ "char_eng_unigram", "char_eng_2gram", "char_eng_3gram",
94
+ "char_eng_4gram", "char_eng_5gram"])
95
+ d_base: int = 48 # band-valid (CM sweet spot 32-56)
96
+ base_seed: int = 1234 # determinism of Tier 1
97
+ out_dir: str = "atlas"
98
+ batch: int = 65536
99
+ max_tri: int = 16 # n-grams longer than 48 bytes truncated
100
+ stats_sample: int = 4000
101
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
102
+
103
+
104
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━���━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
105
+ # Tier 1 β€” deterministic band-valid base (correct spacing by construction)
106
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
107
+
108
+ class SobolSphere:
109
+ """Low-discrepancy points on S^(D-1): scrambled Sobol -> inverse-normal ->
110
+ normalize. Deterministic per (seed, global index): the same n-gram (by its
111
+ global rank position) always receives the same point. Never collides."""
112
+
113
+ def __init__(self, D: int, seed: int):
114
+ self.D, self.seed = D, seed
115
+ self.eng = torch.quasirandom.SobolEngine(D, scramble=True, seed=seed)
116
+ self._cursor = 0
117
+
118
+ def take(self, n: int) -> Tensor:
119
+ u = self.eng.draw(n).clamp(1e-6, 1 - 1e-6)
120
+ g = torch.erfinv(2 * u - 1) * math.sqrt(2.0) # inverse normal CDF
121
+ self._cursor += n
122
+ return F.normalize(g, dim=-1)
123
+
124
+
125
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
126
+ # Tier 2 β€” learned address-space view (the model's own placement)
127
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
128
+
129
+ class LearnedView:
130
+ """kappa-row composer from a trained AlephLM checkpoint."""
131
+
132
+ def __init__(self, checkpoint: str, device: str):
133
+ from aleph_lm import AlephLM, AlephLMConfig
134
+ d = torch.load(checkpoint, map_location=device, weights_only=False)
135
+ fields = AlephLMConfig.__dataclass_fields__
136
+ cfg = AlephLMConfig(**{k: v for k, v in d["config"].items() if k in fields})
137
+ bank = d.get("bank", None)
138
+ self.model = AlephLM(cfg, bank=bank).to(device)
139
+ self.model.load_state_dict(d["model_state_dict"])
140
+ self.model.eval()
141
+ self.cfg, self.device = cfg, device
142
+
143
+ @torch.no_grad()
144
+ def compose(self, tri: Tensor, n_tri: Tensor) -> Tensor:
145
+ """tri: (B, T, 3) padded trigram bytes; n_tri: (B,) valid counts.
146
+ Returns (B, D_addr) unit rows: normalized mean of per-trigram
147
+ kappa rows over the valid prefix. Order-insensitive (collisions
148
+ among anagrams; counted upstream)."""
149
+ m = self.model
150
+ tri = tri.to(self.device)
151
+ e = sum(emb(tri[..., i]) for i, emb in enumerate(m.byte_emb)) # (B,T,d)
152
+ rows = F.normalize(m.W_kappa(e), dim=-1) # (B,T,Da)
153
+ mask = (torch.arange(tri.shape[1], device=self.device)[None, :]
154
+ < n_tri.to(self.device)[:, None]).float().unsqueeze(-1)
155
+ mean = (rows * mask).sum(1) / mask.sum(1).clamp_min(1e-9)
156
+ return F.normalize(mean, dim=-1).cpu()
157
+
158
+
159
+ def ngrams_to_trigrams(ngrams: List[str], max_tri: int
160
+ ) -> Tuple[Tensor, Tensor, np.ndarray]:
161
+ """UTF-8 encode, pad to multiple of 3 with PAD_BYTE, frame as trigrams.
162
+ Returns (B, max_tri, 3) bytes, (B,) counts, and the trigram-multiset hash
163
+ per n-gram (for anagram-collision counting)."""
164
+ B = len(ngrams)
165
+ out = np.zeros((B, max_tri, 3), dtype=np.int64)
166
+ counts = np.zeros(B, dtype=np.int64)
167
+ mhash = np.zeros(B, dtype=np.uint64)
168
+ for i, s in enumerate(ngrams):
169
+ b = str(s).encode("utf-8", errors="ignore")[: 3 * max_tri]
170
+ if len(b) % 3:
171
+ b = b + bytes([PAD_BYTE]) * (3 - len(b) % 3)
172
+ t = np.frombuffer(b, dtype=np.uint8).reshape(-1, 3).astype(np.int64)
173
+ n = len(t)
174
+ out[i, :n] = t
175
+ counts[i] = max(n, 1)
176
+ ids = (t[:, 0] * 65536 + t[:, 1] * 256 + t[:, 2]).astype(np.uint64)
177
+ h = np.uint64(0)
178
+ for v in np.sort(ids): # order-free multiset hash
179
+ h = (h * np.uint64(1099511628211)) ^ (v + np.uint64(0x9E3779B9))
180
+ mhash[i] = h ^ np.uint64(n)
181
+ return torch.from_numpy(out), torch.from_numpy(counts), mhash
182
+
183
+
184
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
185
+ # Spacing battery
186
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
187
+
188
+ def spacing_stats(vecs: Tensor, sample: int, seed: int = 0) -> Dict:
189
+ """Sampled NN-angle distribution + statute on a subsample."""
190
+ g = torch.Generator().manual_seed(seed)
191
+ idx = torch.randperm(len(vecs), generator=g)[: min(sample, len(vecs))]
192
+ X = F.normalize(vecs[idx].float(), dim=-1)
193
+ cos = (X @ X.t()).clamp(-1, 1)
194
+ cos.fill_diagonal_(-1)
195
+ nn_deg = torch.acos(cos.max(dim=-1).values) * 180 / math.pi
196
+ st = statute(X)
197
+ return {"nn_deg_median": nn_deg.median().item(),
198
+ "nn_deg_p05": nn_deg.quantile(0.05).item(),
199
+ "nn_deg_p95": nn_deg.quantile(0.95).item(),
200
+ "statute": st}
201
+
202
+
203
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
204
+ # Extraction
205
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
206
+
207
+ def extract_config(name: str, cfg: AtlasConfig, sobol: SobolSphere,
208
+ view: Optional[LearnedView]) -> Dict:
209
+ import pyarrow as pa
210
+ import pyarrow.parquet as pq
211
+ from huggingface_hub import hf_hub_download
212
+
213
+ path = hf_hub_download(DATASET, f"data/{name}-00000-of-00001.parquet",
214
+ repo_type="dataset")
215
+ t = pq.read_table(path, columns=["ngram", "rank", "frequency"]) \
216
+ .to_pandas().sort_values("rank").reset_index(drop=True)
217
+ N = len(t)
218
+ print(f"[{name}] {N:,} n-grams")
219
+
220
+ base = sobol.take(N) # (N, D_base)
221
+ views, counts, hashes = [], [], []
222
+ if view is not None:
223
+ for s0 in range(0, N, cfg.batch):
224
+ chunk = t["ngram"].iloc[s0: s0 + cfg.batch].tolist()
225
+ tri, n_tri, mh = ngrams_to_trigrams(chunk, cfg.max_tri)
226
+ views.append(view.compose(tri, n_tri))
227
+ counts.append(n_tri)
228
+ hashes.append(mh)
229
+ vview = torch.cat(views)
230
+ n_tri = torch.cat(counts)
231
+ mh = np.concatenate(hashes)
232
+ n_coll = int(N - len(np.unique(mh)))
233
+ else:
234
+ vview, n_tri, n_coll = None, None, 0
235
+
236
+ os.makedirs(cfg.out_dir, exist_ok=True)
237
+ cols = {"ngram": pa.array(t["ngram"].astype(str)),
238
+ "rank": pa.array(t["rank"].astype("int64")),
239
+ "frequency": pa.array(t["frequency"].astype("float64")),
240
+ "vec_base": pa.array(base.numpy().tolist(),
241
+ type=pa.list_(pa.float32(), cfg.d_base))}
242
+ if vview is not None:
243
+ cols["n_tri"] = pa.array(n_tri.numpy().astype("int8"))
244
+ cols["vec_view"] = pa.array(vview.numpy().tolist(),
245
+ type=pa.list_(pa.float32(), vview.shape[1]))
246
+ out_path = os.path.join(cfg.out_dir, f"{name}.parquet")
247
+ pq.write_table(pa.table(cols), out_path)
248
+
249
+ stats = {"config": name, "n": N, "d_base": cfg.d_base,
250
+ "anagram_collisions_view": n_coll,
251
+ "base": spacing_stats(base, cfg.stats_sample)}
252
+ if vview is not None:
253
+ stats["view"] = spacing_stats(vview, cfg.stats_sample)
254
+ with open(os.path.join(cfg.out_dir, f"{name}.stats.json"), "w") as f:
255
+ json.dump(stats, f, indent=2, default=str)
256
+ print(f" base NN {stats['base']['nn_deg_median']:.2f} deg "
257
+ f"(statute {stats['base']['statute']['statute']})"
258
+ + (f" view NN {stats['view']['nn_deg_median']:.3f} deg "
259
+ f"(statute {stats['view']['statute']['statute']}, "
260
+ f"collisions {n_coll})" if vview is not None else "")
261
+ + f" -> {out_path}")
262
+ return stats
263
+
264
+
265
+ def build_atlas(cfg: AtlasConfig) -> List[Dict]:
266
+ names = ALL_CONFIGS if cfg.configs == ["all"] else cfg.configs
267
+ sobol = SobolSphere(cfg.d_base, cfg.base_seed) # ONE stream:
268
+ view = LearnedView(cfg.checkpoint, cfg.device) if cfg.checkpoint else None
269
+ # global index = unique placement across ALL configs (never reused)
270
+ all_stats = []
271
+ for name in names:
272
+ all_stats.append(extract_config(name, cfg, sobol, view))
273
+ total = sum(s["n"] for s in all_stats)
274
+ print(f"\n[atlas] {total:,} n-grams placed at D={cfg.d_base} "
275
+ f"(Tier 1, deterministic, collision-free)"
276
+ + (f" + learned D=4 view (Tier 2)" if view else ""))
277
+ return all_stats
278
+
279
+
280
+
281
+
282
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
283
+ # Canon β€” weighted dedupe across sources: ONE STRING, ONE POINT
284
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
285
+ # Cross-config duplicates of the same n-gram must not receive different
286
+ # Tier-1 placements. Canonization: per-config frequencies are normalized
287
+ # (sum to 1 within config β€” scale-free across sources), scaled by a
288
+ # per-source weight (HF elevated: frequency-weighted definitions with
289
+ # cardinality), summed per unique string, re-ranked, and placed once.
290
+
291
+ DEFAULT_SOURCE_WEIGHTS = {"hf": 5.0, "nltk": 1.0, "unicode": 1.0, "legacy": 0.0}
292
+
293
+
294
+ def canonize(cfg: AtlasConfig,
295
+ source_weights: Optional[Dict[str, float]] = None,
296
+ configs: Optional[List[str]] = None) -> Dict:
297
+ """Build the canonical deduplicated atlas directly from the dataset."""
298
+ import pandas as pd
299
+ import pyarrow as pa
300
+ import pyarrow.parquet as pq
301
+ from huggingface_hub import hf_hub_download
302
+
303
+ W = dict(DEFAULT_SOURCE_WEIGHTS)
304
+ if source_weights:
305
+ W.update(source_weights)
306
+ names = configs or SOURCE_CONFIGS
307
+
308
+ frames, prov = [], []
309
+ for name in names:
310
+ lam = W.get(source_of(name), 0.0)
311
+ if lam <= 0:
312
+ print(f"[canon] {name}: weight 0 β€” skipped")
313
+ continue
314
+ p = hf_hub_download(DATASET, f"data/{name}-00000-of-00001.parquet",
315
+ repo_type="dataset")
316
+ t = pq.read_table(p, columns=["ngram", "frequency"]).to_pandas()
317
+ t["ngram"] = t["ngram"].astype(str)
318
+ t["w"] = lam * t["frequency"] / max(t["frequency"].sum(), 1e-30)
319
+ t["src"] = source_of(name)
320
+ frames.append(t[["ngram", "w", "src"]])
321
+ print(f"[canon] {name}: {len(t):,} rows (lambda={lam})")
322
+ allrows = pd.concat(frames, ignore_index=True)
323
+ print(f"[canon] total rows {len(allrows):,}")
324
+
325
+ agg = allrows.groupby("ngram", sort=False).agg(
326
+ weight=("w", "sum"),
327
+ n_sources=("src", "nunique"),
328
+ sources=("src", lambda s: "+".join(sorted(set(s)))))
329
+ agg = agg.sort_values("weight", ascending=False).reset_index()
330
+ N = len(agg)
331
+ dup = len(allrows) - N
332
+ print(f"[canon] unique n-grams {N:,} (merged {dup:,} duplicate rows)")
333
+
334
+ # Tier 1: one fresh stream over the canonical ranking β€” one string, one point
335
+ sobol = SobolSphere(cfg.d_base, cfg.base_seed)
336
+ base = sobol.take(N)
337
+
338
+ # Tier 2: learned view regenerated per unique string (pure function)
339
+ view = LearnedView(cfg.checkpoint, cfg.device) if cfg.checkpoint else None
340
+ vview, n_tri_all, n_coll = None, None, 0
341
+ if view is not None:
342
+ views, counts, hashes = [], [], []
343
+ for s0 in range(0, N, cfg.batch):
344
+ chunk = agg["ngram"].iloc[s0: s0 + cfg.batch].tolist()
345
+ tri, n_tri, mh = ngrams_to_trigrams(chunk, cfg.max_tri)
346
+ views.append(view.compose(tri, n_tri))
347
+ counts.append(n_tri)
348
+ hashes.append(mh)
349
+ vview = torch.cat(views)
350
+ n_tri_all = torch.cat(counts)
351
+ mh = np.concatenate(hashes)
352
+ n_coll = int(N - len(np.unique(mh)))
353
+
354
+ os.makedirs(cfg.out_dir, exist_ok=True)
355
+ cols = {"ngram": pa.array(agg["ngram"]),
356
+ "weight": pa.array(agg["weight"].astype("float64")),
357
+ "rank": pa.array(np.arange(1, N + 1, dtype=np.int64)),
358
+ "n_sources": pa.array(agg["n_sources"].astype("int8")),
359
+ "sources": pa.array(agg["sources"]),
360
+ "vec_base": pa.array(base.numpy().tolist(),
361
+ type=pa.list_(pa.float32(), cfg.d_base))}
362
+ if vview is not None:
363
+ cols["n_tri"] = pa.array(n_tri_all.numpy().astype("int8"))
364
+ cols["vec_view"] = pa.array(vview.numpy().tolist(),
365
+ type=pa.list_(pa.float32(), vview.shape[1]))
366
+ out_path = os.path.join(cfg.out_dir, "canon.parquet")
367
+ pq.write_table(pa.table(cols), out_path)
368
+
369
+ stats = {"unique": N, "merged_duplicates": dup,
370
+ "source_weights": W, "configs": names,
371
+ "anagram_collisions_view": n_coll,
372
+ "base": spacing_stats(base, cfg.stats_sample)}
373
+ if vview is not None:
374
+ stats["view"] = spacing_stats(vview, cfg.stats_sample)
375
+ with open(os.path.join(cfg.out_dir, "canon.stats.json"), "w") as f:
376
+ json.dump(stats, f, indent=2, default=str)
377
+ print(f"[canon] -> {out_path} "
378
+ f"(base NN {stats['base']['nn_deg_median']:.2f} deg"
379
+ + (f", view collisions {n_coll}" if vview is not None else "") + ")")
380
+ return stats
381
+
382
+
383
+
384
+
385
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
386
+ # Stratified bank β€” round-robin across the granularity ladder
387
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
388
+ # Breadth-first sampling: rank-1 of every config, then rank-2, ... with
389
+ # weighted dedupe along the way, until `target` unique n-grams. Yields a
390
+ # compact multi-granularity candidate vocabulary stratified across
391
+ # {source} x {char, word} x {1..5}gram. Two outputs:
392
+ # bank_{target}.parquet the full multi-granularity bank
393
+ # bank_{target}_tri.pt the 3-byte-exact subset as an (M, 3) tensor β€”
394
+ # a DROP-IN AlephLM trigram bank (only exact
395
+ # 3-byte strings can match raw next-trigram
396
+ # targets; variable-length candidates await the
397
+ # span-prediction head β€” v2, noted in log)
398
+
399
+ def stratified_bank(cfg: AtlasConfig, target: int = 4096,
400
+ source_weights: Optional[Dict[str, float]] = None,
401
+ configs: Optional[List[str]] = None) -> Dict:
402
+ import pandas as pd
403
+ import pyarrow as pa
404
+ import pyarrow.parquet as pq
405
+ from huggingface_hub import hf_hub_download
406
+
407
+ W = dict(DEFAULT_SOURCE_WEIGHTS)
408
+ if source_weights:
409
+ W.update(source_weights)
410
+ names = [n for n in (configs or SOURCE_CONFIGS)
411
+ if W.get(source_of(n), 0.0) > 0]
412
+
413
+ tables = []
414
+ for name in names:
415
+ p = hf_hub_download(DATASET, f"data/{name}-00000-of-00001.parquet",
416
+ repo_type="dataset")
417
+ t = pq.read_table(p, columns=["ngram", "rank", "frequency"]).to_pandas()
418
+ t["ngram"] = t["ngram"].astype(str)
419
+ lam = W[source_of(name)]
420
+ t["w"] = lam * t["frequency"] / max(t["frequency"].sum(), 1e-30)
421
+ t["config"] = name
422
+ tables.append(t.sort_values("rank").reset_index(drop=True))
423
+
424
+ chosen: Dict[str, Dict] = {}
425
+ depth = 0
426
+ while len(chosen) < target:
427
+ progressed = False
428
+ for t in tables:
429
+ if depth >= len(t):
430
+ continue
431
+ progressed = True
432
+ row = t.iloc[depth]
433
+ rec = chosen.get(row.ngram)
434
+ if rec is None:
435
+ chosen[row.ngram] = {"weight": row.w, "configs": {row.config},
436
+ "first_depth": depth}
437
+ else:
438
+ rec["weight"] += row.w
439
+ rec["configs"].add(row.config)
440
+ if len(chosen) >= target:
441
+ break
442
+ depth += 1
443
+ if not progressed:
444
+ break
445
+ rows = [{"ngram": k, "weight": v["weight"],
446
+ "n_configs": len(v["configs"]),
447
+ "configs": "+".join(sorted(v["configs"])),
448
+ "first_depth": v["first_depth"],
449
+ "n_bytes": len(k.encode("utf-8", errors="ignore"))}
450
+ for k, v in chosen.items()]
451
+ bank = pd.DataFrame(rows).sort_values(
452
+ ["first_depth", "weight"], ascending=[True, False]).reset_index(drop=True)
453
+
454
+ os.makedirs(cfg.out_dir, exist_ok=True)
455
+ out_pq = os.path.join(cfg.out_dir, f"bank_{target}.parquet")
456
+ pq.write_table(pa.Table.from_pandas(bank, preserve_index=False), out_pq)
457
+
458
+ tri_rows = [list(k.encode("utf-8")) for k in bank["ngram"]
459
+ if len(k.encode("utf-8", errors="ignore")) == 3]
460
+ tri = torch.tensor(tri_rows, dtype=torch.long) if tri_rows else torch.empty(0, 3, dtype=torch.long)
461
+ out_pt = os.path.join(cfg.out_dir, f"bank_{target}_tri.pt")
462
+ torch.save({"bank": tri, "source": "stratified_atlas",
463
+ "target": target, "weights": W, "configs": names}, out_pt)
464
+
465
+ comp = bank.groupby(bank["configs"].str.split("+").str[0]).size().to_dict()
466
+ print(f"[bank] {len(bank):,} unique n-grams at round-robin depth {depth}"
467
+ f" (3-byte-exact: {len(tri):,} -> {out_pt})")
468
+ print(f"[bank] multi-config members: {(bank.n_configs > 1).sum():,}"
469
+ f" byte-length histogram: "
470
+ f"{bank.n_bytes.value_counts().sort_index().to_dict()}")
471
+ print(f"[bank] -> {out_pq}")
472
+ return {"n": len(bank), "depth": depth, "n_tri": len(tri),
473
+ "paths": [out_pq, out_pt]}
474
+
475
+
476
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
477
+ # Activation
478
+ # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
479
+
480
+ if __name__ == "__main__":
481
+ import argparse
482
+ ap = argparse.ArgumentParser(description="Full lexical-topology atlas")
483
+ ap.add_argument("--checkpoint", default=None)
484
+ ap.add_argument("--configs", nargs="+",
485
+ default=["all"])#"char_eng_unigram", "char_eng_2gram",
486
+ #"char_eng_3gram", "char_eng_4gram",
487
+ #"char_eng_5gram"])
488
+ ap.add_argument("--d-base", type=int, default=48)
489
+ ap.add_argument("--out", default="atlas")
490
+ ap.add_argument("--device",
491
+ default="cuda" if torch.cuda.is_available() else "cpu")
492
+ ap.add_argument("--canon", action="store_true",
493
+ help="weighted dedupe across sources: one string, one point")
494
+ ap.add_argument("--weights", default="hf=5,nltk=1,unicode=1,legacy=0",
495
+ help="per-source lambdas, e.g. hf=5,nltk=1,unicode=1")
496
+ ap.add_argument("--bank", type=int, default=0,
497
+ help="build a stratified bank of this many unique n-grams")
498
+ args, _unknown = ap.parse_known_args()
499
+ acfg = AtlasConfig(checkpoint=args.checkpoint, configs=args.configs,
500
+ d_base=args.d_base, out_dir=args.out, device=args.device)
501
+ sw = {k: float(v) for k, v in
502
+ (kv.split("=") for kv in args.weights.split(","))}
503
+ if args.bank:
504
+ stratified_bank(acfg, target=args.bank, source_weights=sw,
505
+ configs=None if args.configs in (["all"], ["sources"])
506
+ else args.configs)
507
+ elif args.canon:
508
+ canonize(acfg, source_weights=sw,
509
+ configs=None if args.configs == ["all"] else
510
+ (SOURCE_CONFIGS if args.configs == ["sources"] else args.configs))
511
+ else:
512
+ build_atlas(acfg)