"""Train a fresh 8K BPE on a FineWeb-edu sample.

This is the 50M-scale variant of the 1M project's 4K BPE. We bump the default
vocab to 8192 but keep the document count at 50000, the same as the 1M variant:
that doc count was already saturating BPE merge quality at 4K vocab, and
doubling the vocab needs roughly the same training set, not 2x more data.

We do NOT reuse any FANT tokenizer here -- the point of this experiment family
is a clean small recipe with no external dependencies.

Output: tokenizer.json in the working dir (or wherever specified).
"""
from __future__ import annotations

import argparse
import time
from pathlib import Path

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel as BLPre
from tokenizers.decoders import ByteLevel as BLDec
from tokenizers.processors import ByteLevel as BLPost
from tokenizers.trainers import BpeTrainer


SPECIAL_TOKENS = [
    "<|pad|>",      # 0
    "<|bos|>",      # 1
    "<|eos|>",      # 2
    "<|unk|>",      # 3
    "<|im_start|>", # 4 -- chat role open
    "<|im_end|>",   # 5 -- chat role close
]
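
# Note: the `tokenizers` BpeTrainer adds `special_tokens` to the vocabulary
# first, so the tokens above receive ids 0..5 in the order listed (matching
# the index comments).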


def _iter_fineweb(n_docs: int):
    """Yield up to `n_docs` text strings from the FineWeb-edu streaming feed."""
    from datasets import load_dataset

    ds = load_dataset(
        "HuggingFaceFW/fineweb-edu",
        name="default",
        split="train",
        streaming=True,
    )
    n = 0
    for ex in ds:
        if n >= n_docs:
            return
        text = ex.get("text", "")
        if isinstance(text, str) and text.strip():
            n += 1
            yield text


def train_tokenizer(out_path: str = "tokenizer.json", vocab_size: int = 8192, n_docs: int = 50000) -> str:
    tok = Tokenizer(BPE(unk_token="<|unk|>"))
    tok.pre_tokenizer = BLPre(add_prefix_space=False)
    tok.decoder = BLDec()
    tok.post_processor = BLPost(trim_offsets=False)

    trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=BLPre.alphabet(),
        show_progress=False,
    )

    print(f"[tokenizer] streaming up to {n_docs} FineWeb-edu docs...")
    t0 = time.time()
    docs = list(_iter_fineweb(n_docs))
    print(f"[tokenizer] collected {len(docs)} docs in {time.time() - t0:.1f}s")

    print(f"[tokenizer] training BPE vocab_size={vocab_size}...")
    t0 = time.time()
    tok.train_from_iterator(docs, trainer=trainer)
    print(f"[tokenizer] trained in {time.time() - t0:.1f}s; vocab={tok.get_vocab_size()}")

    out_dir = Path(out_path).parent
    out_dir.mkdir(parents=True, exist_ok=True)
    tok.save(out_path)
    print(f"[tokenizer] saved to {out_path}")
    return out_path


def load_tokenizer(path: str = "tokenizer.json") -> Tokenizer:
    return Tokenizer.from_file(path)


# Convenience accessors used by data.py / train.py
def special_token_id(tok: Tokenizer, name: str) -> int:
    tid = tok.token_to_id(name)
    if tid is None:
        raise KeyError(f"{name} not in tokenizer")
    return tid
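

# Minimal usage sketch (illustrative only, not executed by this script):
#
#   tok = load_tokenizer("tokenizer.json")
#   bos = special_token_id(tok, "<|bos|>")
#   eos = special_token_id(tok, "<|eos|>")
#   ids = [bos] + tok.encode("Hello, world!").ids + [eos]
#   text = tok.decode(ids)  # special tokens are skipped by default on decode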


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", default="tokenizer.json")
    ap.add_argument("--vocab", type=int, default=8192)
    ap.add_argument("--docs", type=int, default=50000)
    args = ap.parse_args()
    train_tokenizer(args.out, args.vocab, args.docs)


if __name__ == "__main__":
    main()