Pringled committed
Commit 524bccc · verified · 1 Parent(s): eb05299

Move training script to train.py

Files changed (1)
  1. train.py +303 -0
train.py ADDED
@@ -0,0 +1,303 @@
"""Reproduction script for potion-code-16M.

Runs the full pipeline: distill → tokenlearn → contrastive fine-tuning.

Requirements:
    pip install model2vec tokenlearn sentence-transformers datasets skeletoken einops

The three model checkpoints are saved to:
    ./models/potion-code-16M-distilled
    ./models/potion-code-16M-tokenlearn
    ./models/potion-code-16M-contrastive  ← final model
"""
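
# Inference sketch (hypothetical usage, not exercised by this script): once the
# pipeline finishes, the final checkpoint can be loaded directly with model2vec.
#
#   from model2vec import StaticModel
#   model = StaticModel.from_pretrained("models/potion-code-16M-contrastive")
#   embeddings = model.encode(["def add(a, b): return a + b"])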

from __future__ import annotations

import logging
import random
from pathlib import Path

import numpy as np
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from huggingface_hub import snapshot_download
from model2vec import StaticModel
from model2vec.distill import distill_from_model
from model2vec.distill.inference import post_process_embeddings
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.models import StaticEmbedding
from sentence_transformers.training_args import BatchSamplers
from skeletoken import TokenizerModel
from sklearn.decomposition import PCA
from tokenlearn.losses import Loss
from tokenlearn.model import StaticModelForFineTuning
from tokenlearn.utils import create_vocab
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

TEACHER_MODEL = "nomic-ai/CodeRankEmbed"
OUTPUT_DIR = Path("models")

# Distill
VOCAB_SIZE = 42_000  # extra tokens mined from CornStack → ~62.5k total → ~16M params
PCA_DIMS = 256
SIF_COEFFICIENT = 1e-4
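# Rough size check (derived from the constants above): ~62.5k vocabulary rows
# × 256 dims ≈ 16M float32 parameters (about 64 MB), hence "16M" in the name.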

# Tokenlearn
TOKENLEARN_DOCS_DATASET = "minishlab/tokenlearn-cornstack-docs-coderankembed"
TOKENLEARN_QUERIES_DATASET = "minishlab/tokenlearn-cornstack-queries-coderankembed"
TOKENLEARN_LANGUAGES = ["go", "java", "javascript", "php", "python", "ruby"]
TOKENLEARN_MAX_PER_LANGUAGE = 20_000  # 20k docs + 20k queries × 6 langs = 240k total
TOKENLEARN_LR = 1e-3
TOKENLEARN_MAX_EPOCHS = 20  # early stopping (patience=5) typically kicks in earlier
TOKENLEARN_BATCH_SIZE = 128

# Contrastive
CORNSTACK_DATASETS = {
    "python": "nomic-ai/cornstack-python-v1",
    "java": "nomic-ai/cornstack-java-v1",
    "php": "nomic-ai/cornstack-php-v1",
    "go": "nomic-ai/cornstack-go-v1",
    "javascript": "nomic-ai/cornstack-javascript-v1",
    "ruby": "nomic-ai/cornstack-ruby-v1",
}
CONTRASTIVE_MAX_PER_LANGUAGE = 20_000  # 20k × 6 langs = 120k pairs total
CONTRASTIVE_LR = 5e-3
CONTRASTIVE_EPOCHS = 3
CONTRASTIVE_BATCH_SIZE = 512
CONTRASTIVE_SEED = 42


def apply_post_sif(model: StaticModel, pca_dims: int, sif_coefficient: float) -> StaticModel:
    """Apply post-SIF re-regularization to a static model."""
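    # Hedged summary of what post_process_embeddings is assumed to do here:
    # re-apply PCA and SIF weighting to the trained embedding matrix, down-weighting
    # frequent tokens (roughly weight = a / (a + p(token)), with a = sif_coefficient)
    # so they contribute less to the mean-pooled sentence vectors.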
    embeddings_np = model.embedding.astype(np.float32)
    processed, weights = post_process_embeddings(embeddings_np, pca_dims=pca_dims, sif_coefficient=sif_coefficient)
    logger.info("post_process_embeddings: %s → %s", embeddings_np.shape, processed.shape)
    model.embedding = processed
    model.weights = weights
    return model


def run_distill(save_path: Path) -> None:
    """Distill CodeRankEmbed into a static model with an extended code vocabulary."""
    logger.info("Downloading %s ...", TEACHER_MODEL)
    local_path = snapshot_download(TEACHER_MODEL)
    model = AutoModel.from_pretrained(local_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True, use_fast=True)

    # Load tokenlearn corpus texts for vocab mining (docs + queries, 20k/lang)
    logger.info("Loading texts for vocabulary mining ...")
    shards = []
    for lang in TOKENLEARN_LANGUAGES:
        docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
        queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
        shards.extend([docs, queries])
    corpus = concatenate_datasets(shards)
    texts: list[str] = list(corpus["text"])
    logger.info("Loaded %d texts for vocab mining.", len(texts))

    logger.info("Mining vocabulary (target size=%d) ...", VOCAB_SIZE)
    vocab = create_vocab(texts=texts, vocab_size=VOCAB_SIZE)
    logger.info("Mined %d tokens.", len(vocab))

    # Filter: keep only new single-token entries not already in CodeRankEmbed vocabulary.
    tokenizer_model = TokenizerModel.from_transformers_tokenizer(tokenizer).prune_added_tokens()
    preprocessor = tokenizer_model.preprocessor
    seen = set(tokenizer_model.sorted_vocabulary)
    filtered = []
    for token in vocab:
        preprocessed = preprocessor.preprocess(token)
        if len(preprocessed) == 1 and preprocessed[0] not in seen:
            seen.add(preprocessed[0])
            filtered.append(preprocessed[0])
    logger.info("Vocabulary after filtering: %d tokens added to CodeRankEmbed.", len(filtered))

    # NomicBERT requires monkey-patched embedding accessors.
    model.get_input_embeddings = lambda: model.embeddings.word_embeddings
    model.set_input_embeddings = lambda v: setattr(model.embeddings, "word_embeddings", v)

    logger.info("Distilling (pca_dims=%d, sif=%g) ...", PCA_DIMS, SIF_COEFFICIENT)
    static_model = distill_from_model(
        model=model,
        tokenizer=tokenizer,
        vocabulary=filtered,
        pca_dims=PCA_DIMS,
        sif_coefficient=SIF_COEFFICIENT,
        pooling="mean",
        quantize_to="float32",
    )

    save_path.mkdir(parents=True, exist_ok=True)
    static_model.save_pretrained(str(save_path))
    logger.info(
        "Distilled model saved to %s (vocab=%d, dims=%d)",
        save_path,
        static_model.embedding.shape[0],
        static_model.embedding.shape[1],
    )


def run_tokenlearn(base_model_path: Path, save_path: Path) -> None:
    """Fine-tune the distilled model on CornStack using cosine similarity loss."""
    # Load 20k docs + 20k queries per language → 240k total
    logger.info(
        "Loading tokenlearn data (docs + queries, %d/lang × %d langs) ...",
        TOKENLEARN_MAX_PER_LANGUAGE,
        len(TOKENLEARN_LANGUAGES),
    )
    shards = []
    for lang in TOKENLEARN_LANGUAGES:
        docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
        queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
        shards.extend([docs, queries])
    dataset = concatenate_datasets(shards)
    logger.info("Total samples: %d", len(dataset))

    train_txt: list[str] = list(dataset["text"])
    train_vec = np.array(dataset["embedding"], dtype=np.float32)
    non_nan_mask = ~np.isnan(train_vec).any(axis=1)
    train_txt = np.array(train_txt)[non_nan_mask].tolist()
    train_vec = train_vec[non_nan_mask]
    logger.info("Loaded %d samples, raw vector shape: %s", len(train_txt), train_vec.shape)

    logger.info("Fitting PCA to %d dims ...", PCA_DIMS)
    pca = PCA(n_components=PCA_DIMS)
    train_vec = pca.fit_transform(train_vec)
    logger.info("Explained variance: %.4f. Shape: %s", pca.explained_variance_ratio_.cumsum()[-1], train_vec.shape)

    logger.info("Loading base model from %s ...", base_model_path)
    base_model = StaticModel.from_pretrained(str(base_model_path), force_download=False)
    if base_model.embedding.dtype != np.float32:
        base_model.embedding = base_model.embedding.astype(np.float32)

    trainable = StaticModelForFineTuning.from_static_model(
        model=base_model,
        out_dim=PCA_DIMS,
        loss=Loss("cosine"),
    )
    logger.info(
        "Training tokenlearn (lr=%g, max_epochs=%d, batch=%d) ...",
        TOKENLEARN_LR,
        TOKENLEARN_MAX_EPOCHS,
        TOKENLEARN_BATCH_SIZE,
    )
    trainable.fit(
        X=train_txt,
        y=torch.from_numpy(train_vec.astype(np.float32)),
        batch_size=TOKENLEARN_BATCH_SIZE,
        learning_rate=TOKENLEARN_LR,
        max_epochs=TOKENLEARN_MAX_EPOCHS,
        early_stopping_patience=5,
        use_wandb=False,
    )
    logger.info("Tokenlearn training complete.")

    trained_model = trainable.to_static_model()
    trained_model = apply_post_sif(trained_model, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)

    save_path.mkdir(parents=True, exist_ok=True)
    trained_model.save_pretrained(str(save_path))
    logger.info("Tokenlearn model saved to %s", save_path)


def run_contrastive(base_model_path: Path, save_path: Path) -> None:
    """Fine-tune the tokenlearn model using MultipleNegativesRankingLoss on CornStack pairs."""
    random.seed(CONTRASTIVE_SEED)

    logger.info(
        "Streaming CornStack pairs (%d/lang × %d langs) ...", CONTRASTIVE_MAX_PER_LANGUAGE, len(CORNSTACK_DATASETS)
    )
    all_queries: list[str] = []
    all_docs: list[str] = []
    for lang, hf_name in CORNSTACK_DATASETS.items():
        hf_ds = load_dataset(hf_name, split="train", streaming=True)
        hf_ds = hf_ds.shuffle(seed=CONTRASTIVE_SEED, buffer_size=10_000)
        kept = 0
        seen_q: set[str] = set()
        seen_d: set[str] = set()
        for row in hf_ds:
            q, d = row.get("query"), row.get("document")
            if not isinstance(q, str) or not isinstance(d, str):
                continue
            if len(q) < 32 or len(d) < 32:
                continue
            if q in seen_q or d in seen_d:
                continue
            seen_q.add(q)
            seen_d.add(d)
            all_queries.append(q)
            all_docs.append(d)
            kept += 1
            if kept >= CONTRASTIVE_MAX_PER_LANGUAGE:
                break
        logger.info(" %s: %d pairs", lang, kept)

    logger.info("Total pairs: %d", len(all_queries))
    train_dataset = Dataset.from_dict({"anchor": all_queries, "positive": all_docs})

    static_embedding = StaticEmbedding.from_model2vec(str(base_model_path))
    model = SentenceTransformer(modules=[static_embedding])
    loss = MultipleNegativesRankingLoss(model)
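    # Note on the loss: MultipleNegativesRankingLoss treats every other document in
    # the batch as a negative, so with CONTRASTIVE_BATCH_SIZE = 512 each query is
    # contrasted against 1 positive and up to 511 in-batch negatives; the
    # NO_DUPLICATES batch sampler below keeps duplicate texts out of a batch so
    # they are not scored as false negatives.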

    training_args = SentenceTransformerTrainingArguments(
        output_dir=str(save_path) + "-checkpoints",
        num_train_epochs=CONTRASTIVE_EPOCHS,
        per_device_train_batch_size=CONTRASTIVE_BATCH_SIZE,
        learning_rate=CONTRASTIVE_LR,
        warmup_ratio=0.1,  # linear warmup over the first 10% of steps (warmup_steps expects an integer step count)
        fp16=False,
        bf16=False,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        save_strategy="no",
        logging_steps=100,
        logging_first_step=True,
        report_to=[],
    )
    logger.info(
        "Training contrastive (lr=%g, epochs=%d, batch=%d) ...",
        CONTRASTIVE_LR,
        CONTRASTIVE_EPOCHS,
        CONTRASTIVE_BATCH_SIZE,
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()
    logger.info("Contrastive training complete.")

    base_m2v = StaticModel.from_pretrained(str(base_model_path), force_download=False)
    base_m2v.embedding = model[0].embedding.weight.detach().cpu().float().numpy()

    final_model = apply_post_sif(base_m2v, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)

    save_path.mkdir(parents=True, exist_ok=True)
    final_model.save_pretrained(str(save_path))
    logger.info("Final model saved to %s", save_path)


if __name__ == "__main__":
    distilled_path = OUTPUT_DIR / "potion-code-16M-distilled"
    tokenlearn_path = OUTPUT_DIR / "potion-code-16M-tokenlearn"
    contrastive_path = OUTPUT_DIR / "potion-code-16M-contrastive"

    logger.info("=== Step 1/3: Distill ===")
    run_distill(save_path=distilled_path)

    logger.info("=== Step 2/3: Tokenlearn ===")
    run_tokenlearn(base_model_path=distilled_path, save_path=tokenlearn_path)

    logger.info("=== Step 3/3: Contrastive ===")
    run_contrastive(base_model_path=tokenlearn_path, save_path=contrastive_path)

    logger.info("Done. Final model: %s", contrastive_path)