Commit 59e2c8a by MSG
Parent(s): abab738
Merge pull request #4 from MSghais/experiment/small_model_building_testing
This view is limited to 50 files because it contains too many changes.
- .cursor/plans/ensemble_refactor_benchmark_be904eec.plan.md +267 -0
- .cursor/plans/lm-eval_harness_integration_34af70c7.plan.md +339 -0
- .cursor/plans/model_verification_pipeline_ed9d35ab.plan.md +280 -0
- .env.example +25 -0
- .gitignore +3 -1
- README.md +3 -0
- libs/inference/pyproject.toml +1 -0
- libs/inference/src/inference/config.py +6 -0
- libs/inference/src/inference/transformers.py +12 -0
- models.yaml +21 -3
- pyproject.toml +21 -0
- research/README.md +48 -0
- research/USAGE.md +311 -0
- research/data/benchmark-kb.jsonl +10 -0
- research/data/benchmark-qa.jsonl +10 -0
- research/data/education-lesson-chat.jsonl +8 -0
- research/docs/overview.md +114 -0
- research/ensemble/README.md +113 -0
- research/ensemble/pyproject.toml +16 -0
- research/ensemble/scripts/smoke.sh +35 -0
- research/ensemble/src/ensemble/__init__.py +15 -0
- research/ensemble/src/ensemble/backends.py +418 -0
- research/ensemble/src/ensemble/bridge.py +28 -0
- research/ensemble/src/ensemble/checkpoint.py +149 -0
- research/ensemble/src/ensemble/config.py +163 -0
- research/ensemble/src/ensemble/energy.py +45 -0
- research/ensemble/src/ensemble/eval/__init__.py +1 -0
- research/ensemble/src/ensemble/eval/jepa_harness.py +266 -0
- research/ensemble/src/ensemble/eval/metrics.py +42 -0
- research/ensemble/src/ensemble/eval/world_harness.py +174 -0
- research/ensemble/src/ensemble/eval_harness.py +309 -0
- research/ensemble/src/ensemble/jepa.py +75 -0
- research/ensemble/src/ensemble/jepa_ensemble.py +232 -0
- research/ensemble/src/ensemble/llm_emb_jepa_ensemble_pluggable.py +507 -0
- research/ensemble/src/ensemble/memory.py +46 -0
- research/ensemble/src/ensemble/pretrain.py +198 -0
- research/ensemble/src/ensemble/world_ensemble.py +228 -0
- research/ensemble/src/ensemble/world_model.py +40 -0
- research/ensemble/src/ensemble/world_model_ensemble.py +499 -0
- research/eval_harness.py +6 -0
- research/evals/README.md +98 -0
- research/evals/USAGE.md +376 -0
- research/evals/configs/ensemble_jepa_lesson.yaml +24 -0
- research/evals/configs/eval_profiles.yaml +160 -0
- research/evals/configs/experiment_001.yaml +53 -0
- research/evals/configs/lm_eval_code.yaml +19 -0
- research/evals/configs/lm_eval_compare_study.yaml +42 -0
- research/evals/configs/lm_eval_instructions.yaml +18 -0
- research/evals/configs/lm_eval_minicpm5.yaml +26 -0
- research/evals/configs/lm_eval_reasoning.yaml +20 -0
.cursor/plans/ensemble_refactor_benchmark_be904eec.plan.md
ADDED
@@ -0,0 +1,267 @@
---
name: Ensemble refactor benchmark
overview: Compare Option A vs Option B, then reorganize the three research scripts into a shared `research/ensemble/` package (Option B) and add a clear run + benchmark workflow from CPU smoke tests through real-model ablations.
todos:
  - id: scaffold-package
    content: "Create research/ensemble/ package: pyproject.toml, workspace member, ensemble dependency group"
    status: completed
  - id: extract-shared
    content: Extract shared modules (backends, memory, jepa, bridge, eval/metrics) from the three existing scripts
    status: completed
  - id: migrate-jepa
    content: Move Ensemble + demo into jepa_ensemble.py; migrate eval_harness to eval/jepa_harness.py
    status: completed
  - id: migrate-world
    content: Move WorldEnsemble into world_ensemble.py; add world_model.py and energy.py
    status: completed
  - id: world-harness
    content: Add eval/world_harness.py with energy vs random vs oracle selector comparison
    status: completed
  - id: benchmark-data
    content: Add research/data/benchmark-qa.jsonl and benchmark-kb.jsonl from lesson topics
    status: completed
  - id: run-docs
    content: Add research/ensemble/README.md, scripts/smoke.sh, and .env.example ensemble vars with Tier 1-3 commands
    status: completed
  - id: cleanup-shims
    content: Remove or thin re-export old top-level research/*.py files; verify smoke + toy harness pass
    status: completed
isProject: false
---

# Ensemble package (Option B) + run/benchmark plan

## Option A vs Option B

| | **Option A** (folder split only) | **Option B** (shared package) |
|---|---|---|
| **Layout** | `research/ensemble/jepa/` + `research/ensemble/world_model/`: move files, fix imports | `research/ensemble/` with shared `backends.py`, `jepa.py`, `memory.py`, etc. |
| **Effort** | ~1 hour | ~3-4 hours |
| **Duplication** | Keeps ~400 lines duplicated (two LLM backends, two JEPAs, two VectorStores) | Single implementation; bugfixes apply to both ensembles |
| **Imports** | Still fragile (`from jepa.ensemble import ...`) | Clean: `from ensemble.jepa_ensemble import Ensemble` |
| **Benchmarking** | Works for JEPA only ([`research/eval_harness.py`](research/eval_harness.py)); world model has no harness | Shared [`research/ensemble/eval/metrics.py`](research/ensemble/eval/metrics.py) + harnesses for both tracks |
| **Future agent hook** | Harder: two divergent code paths | One package to import from `libs/agent` later if needed |
| **When to pick A** | You only need tidy folders for a demo and will not touch world-model eval | n/a |
| **When to pick B** | You will run ablations, add world-model benchmarks, or train bridge checkpoints | **Your choice** |

**Recommendation:** Go straight to **Option B under `research/ensemble/`** (experiments stay out of `libs/`). Option A is a reasonable **fallback** if time runs out: do the folder split and defer deduplication.

```mermaid
flowchart LR
    subgraph shared [research/ensemble shared]
        backends[backends.py]
        memory[memory.py]
        jepa_mod[jepa.py]
        bridge[bridge.py]
        metrics[eval/metrics.py]
    end

    subgraph jepa_track [JEPA track]
        jepa_ens[jepa_ensemble.py]
        jepa_eval[eval/jepa_harness.py]
    end

    subgraph world_track [World-model track]
        world_mod[world_model.py]
        energy[energy.py]
        world_ens[world_ensemble.py]
        world_eval[eval/world_harness.py]
    end

    backends --> jepa_ens
    memory --> jepa_ens
    jepa_mod --> jepa_ens
    bridge --> jepa_ens
    jepa_ens --> jepa_eval
    metrics --> jepa_eval

    backends --> world_ens
    memory --> world_ens
    jepa_mod --> world_ens
    world_mod --> world_ens
    energy --> world_ens
    bridge --> world_ens
    world_ens --> world_eval
    metrics --> world_eval
```

---

## Target layout

```
research/
  ensemble/
    pyproject.toml              # workspace member, optional deps
    README.md                   # run + benchmark commands
    ensemble/
      __init__.py
      backends.py               # TinyBackend + HFBackend (from jepa file)
      memory.py                 # Embedder, VectorStore, Router
      jepa.py                   # JEPA module
      bridge.py
      jepa_ensemble.py          # Ensemble class (from llm_emb_jepa_ensemble_pluggable.py)
      world_model.py            # WorldModel
      energy.py                 # EnergyModel
      world_ensemble.py         # WorldEnsemble (from world_model_ensemble.py)
      eval/
        __init__.py
        metrics.py              # EM, F1, paired_bootstrap (from eval_harness)
        jepa_harness.py         # ablation ladder + best-of-N (from eval_harness)
        world_harness.py        # NEW: energy vs random vs oracle on shared drafts
    scripts/
      smoke.sh                  # toy CPU checks for both ensembles
  data/
    education-lesson-chat.jsonl # existing
    benchmark-qa.jsonl          # NEW: short QA for eval (derived from lesson topics)
    benchmark-kb.jsonl          # NEW: 1-2 sentence facts per topic for RAG
  finetune.py                   # unchanged
```

Delete after migration (or leave thin re-export shims for one release):
- [`research/llm_emb_jepa_ensemble_pluggable.py`](research/llm_emb_jepa_ensemble_pluggable.py)
- [`research/eval_harness.py`](research/eval_harness.py)
- [`research/world_model_ensemble.py`](research/world_model_ensemble.py)

---

## Package wiring

1. Add [`research/ensemble/pyproject.toml`](research/ensemble/pyproject.toml) with `name = "ensemble"`, `torch` required, `transformers`/`peft`/`accelerate` optional (same pattern as the root [`finetune`](pyproject.toml) group).
2. Extend root [`pyproject.toml`](pyproject.toml):
   - `[tool.uv.workspace] members`: add `"research/ensemble"`
   - `[dependency-groups] ensemble = [...]` mirroring `finetune`
   - Optional root dep: `"ensemble"` for `uv run` convenience
3. Add ensemble env vars to [`.env.example`](.env.example) (model path, QA/KB paths, checkpoint path).

---

## How to try the models (3 tiers)

### Tier 1: Smoke (CPU, no HF download, ~30s)

Validates imports and inference plumbing after refactor.

```bash
uv sync --group ensemble
uv run --package ensemble python -m ensemble.jepa_ensemble tiny # train 50 steps + answer
uv run --package ensemble python -m ensemble.world_ensemble tiny # train 60 steps + answer
bash research/ensemble/scripts/smoke.sh
```

Uses `TinyBackend` / `TinyLLM` (random weights, synthetic segments). Confirms modules load and forward/generate paths work.

### Tier 2: Micro demo (real small model, ~2-5 min)

A quick "does it run on GPU/CPU with a real tokenizer?" check.

```bash
uv run --package ensemble python -m ensemble.jepa_ensemble Qwen/Qwen2.5-0.5B-Instruct
uv run --package ensemble python -m ensemble.world_ensemble Qwen/Qwen2.5-0.5B-Instruct
```

Or a local/finetuned path from [`models.yaml`](models.yaml) / `FINETUNE_OUT`:

```bash
uv run --package ensemble python -m ensemble.jepa_ensemble ./models/finetuned/minicpm5-1b-lora-merged
```

### Tier 3: Benchmark (ablation + significance)

**JEPA ablation ladder** (existing logic, moved to `eval/jepa_harness.py`):

```bash
# Toy benchmark (no download)
uv run --package ensemble python -m ensemble.eval.jepa_harness \
  --llm tiny --toy --limit 20 --n_drafts 8

# Real model + project-aligned QA
uv run --package ensemble python -m ensemble.eval.jepa_harness \
  --llm Qwen/Qwen2.5-0.5B-Instruct \
  --qa research/data/benchmark-qa.jsonl \
  --kb research/data/benchmark-kb.jsonl \
  --limit 50 --n_drafts 8

# With bridge-trained checkpoint (C5)
uv run --package ensemble python -m ensemble.eval.jepa_harness \
  --llm ./models/finetuned/minicpm5-1b-lora-merged \
  --qa research/data/benchmark-qa.jsonl \
  --kb research/data/benchmark-kb.jsonl \
  --ckpt ./checkpoints/ensemble_bridge.pt
```

**World-model benchmark** (new `eval/world_harness.py`, parallel API):

- Same QA/KB inputs
- Compare selectors on shared drafts: `first | random | energy | oracle`
- Report mean energy gap and `P(energy > random)` via shared bootstrap
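
As a rough illustration of the selector comparison listed above, the sketch below scores the four selectors on the same pool of drafts per question. It is not the code in `eval/world_harness.py`: the function names, the `(drafts, energies, gold)` tuple shape, and the convention that lower energy means a better draft are all assumptions made for the example.

```python
import random
from typing import Sequence


def is_correct(pred: str, gold: str) -> bool:
    """Case- and whitespace-insensitive exact match (hypothetical helper)."""
    return pred.strip().lower() == gold.strip().lower()


def select(drafts: Sequence[str], energies: Sequence[float], gold: str,
           mode: str, rng: random.Random) -> str:
    """Pick one draft from a shared pool according to the selector mode."""
    if mode == "first":
        return drafts[0]
    if mode == "random":
        return rng.choice(list(drafts))
    if mode == "energy":
        # Assumption: lower energy means a better draft.
        best = min(range(len(drafts)), key=lambda i: energies[i])
        return drafts[best]
    if mode == "oracle":
        # Upper bound: return a correct draft if the pool contains one.
        for draft in drafts:
            if is_correct(draft, gold):
                return draft
        return drafts[0]
    raise ValueError(f"unknown selector: {mode}")


def compare_selectors(examples, seed: int = 0) -> dict:
    """Accuracy per selector; every selector sees the same drafts."""
    rng = random.Random(seed)
    modes = ("first", "random", "energy", "oracle")
    hits = {mode: 0 for mode in modes}
    for drafts, energies, gold in examples:
        for mode in modes:
            pick = select(drafts, energies, gold, mode, rng)
            hits[mode] += int(is_correct(pick, gold))
    total = max(len(examples), 1)
    return {mode: hits[mode] / total for mode in modes}


toy_examples = [
    (["a", "b"], [0.9, 0.1], "b"),
    (["c", "d"], [0.2, 0.8], "c"),
]
toy_accuracy = compare_selectors(toy_examples)
```

Because all selectors draw from identical drafts, any accuracy gap reflects the selector alone rather than sampling luck in generation, which is what makes a paired bootstrap over per-question scores appropriate.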

**Optional continual-forgetting flag** (JEPA only): `--continual` as today.

---

## Benchmark data

Add small files derived from lesson topics in [`research/data/education-lesson-chat.jsonl`](research/data/education-lesson-chat.jsonl):

**`benchmark-qa.jsonl`** (~8-10 rows):

```json
{"question": "What is photosynthesis?", "answer": "Plants make food using sunlight, water, and carbon dioxide.", "domain": "science"}
```

**`benchmark-kb.jsonl`** (~8-10 rows):

```json
{"text": "Photosynthesis: plants use sunlight, water, and CO2 to make glucose and release oxygen."}
```

Keeps eval aligned with the hackathon domain (education) without needing a large external QA set.
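
A minimal sketch of reading and validating these JSONL files, assuming only the field names shown in the two sample rows (`question`, `answer`, `text`). The `load_jsonl` helper is hypothetical and is not taken from the harness.

```python
import json
from pathlib import Path


def load_jsonl(path, required):
    """Read one JSON object per line and check that required fields exist."""
    rows = []
    with Path(path).open(encoding="utf-8") as handle:
        for lineno, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines between records
            row = json.loads(line)
            missing = [key for key in required if key not in row]
            if missing:
                raise ValueError(f"{path}:{lineno} is missing fields {missing}")
            rows.append(row)
    return rows


# Typical calls, with the paths used elsewhere in this plan:
# qa = load_jsonl("research/data/benchmark-qa.jsonl", ("question", "answer"))
# kb = load_jsonl("research/data/benchmark-kb.jsonl", ("text",))
```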

---

## What "good" benchmark output looks like

**JEPA harness** (from existing [`eval_harness.py`](research/eval_harness.py)):

```
config            EM     F1     lat(s)
C1_base           0.120  0.180  0.05
C2_rag            0.240  0.310  0.06
C3_rag_router     0.260  0.330  0.06
C4_full_jepa      0.320  0.400  0.45

best-of-N selector comparison:
  first   EM=0.28
  random  EM=0.31
  jepa    EM=0.38
  oracle  EM=0.52
P(jepa > random) = 0.97   JEPA critic WORKS
```

On **toy + untrained weights**, deltas will be noise; the smoke tier only checks that the pipeline runs. Meaningful numbers need Tier 2/3 with a real LLM and/or a trained bridge checkpoint.

**World harness** (new): same table shape but selectors `first | random | energy | oracle`.
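
For readers unfamiliar with the EM, F1, and `P(jepa > random)` figures in the sample output, here is one common way to compute them. This is a sketch under assumptions, not the contents of `eval/metrics.py`: the normalization rules, the function names other than `paired_bootstrap`, and the resample count are illustrative choices.

```python
import random
import re
from collections import Counter


def normalize(text: str) -> str:
    """Lowercase, replace punctuation with spaces, collapse whitespace."""
    text = re.sub(r"[^\w\s]", " ", text.lower())
    return " ".join(text.split())


def exact_match(pred: str, gold: str) -> float:
    return float(normalize(pred) == normalize(gold))


def token_f1(pred: str, gold: str) -> float:
    """Harmonic mean of token precision and recall."""
    pred_tokens, gold_tokens = normalize(pred).split(), normalize(gold).split()
    if not pred_tokens or not gold_tokens:
        return float(pred_tokens == gold_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def paired_bootstrap(a, b, n_resamples: int = 2000, seed: int = 0) -> float:
    """Fraction of resamples in which system A's total score beats system B's.

    `a` and `b` are per-question scores for the same questions, in the same order.
    """
    if len(a) != len(b) or not a:
        raise ValueError("score lists must be non-empty and the same length")
    rng = random.Random(seed)
    n = len(a)
    wins = 0
    for _ in range(n_resamples):
        idx = [rng.randrange(n) for _ in range(n)]
        if sum(a[i] for i in idx) > sum(b[i] for i in idx):
            wins += 1
    return wins / n_resamples
```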

---

## Implementation notes (minimal scope)

- Preserve existing public classes: `Ensemble`, `WorldEnsemble`, same `answer_ids` / `answer` APIs.
- `HFBackend` in JEPA track keeps LoRA adapter bank + router; world track keeps simpler `HFLLM` (no router). Unify only where behavior matches.
- Add `if __name__ == "__main__"` entrypoints on `jepa_ensemble.py` and `world_ensemble.py` (move demos from current files).
- No integration with Gradio/`libs/agent` in this pass: research-only per your choice.
- No new pytest suite unless requested; `smoke.sh` + toy harness are the acceptance check.

---

## If time is short (Option A fallback)

Skip deduplication; only:

```
research/ensemble/jepa/{ensemble.py, eval_harness.py}
research/ensemble/world_model/{ensemble.py}
```

Fix imports, add `README.md` with Tier 1-3 commands. Revisit shared modules post-hackathon.
.cursor/plans/lm-eval_harness_integration_34af70c7.plan.md
ADDED
@@ -0,0 +1,339 @@
---
name: lm-eval harness integration
overview: Add EleutherAI lm-evaluation-harness as the primary academic benchmark track for finetuned HF/LoRA models and ensemble checkpoints, with preset-aware CLI, fair-comparison configs, and structured results, building on existing slm-evals and ensemble harnesses without replacing them.
todos:
  - id: deps-lm-eval
    content: Add lm-eval[hf] optional dependency group to root pyproject.toml and slm-evals extras
    status: completed
  - id: preset-resolver
    content: Implement preset_resolver.py mapping models.yaml presets to lm-eval model_args (base, LoRA, merged)
    status: completed
  - id: ensemble-lm-backend
    content: Implement ensemble-lm custom LM class (generate_until + loglikelihood via underlying HF LLM)
    status: completed
  - id: slm-lm-eval-cli
    content: Build run_lm_eval.py CLI with YAML config, simple_evaluate(), results + summary.md + --compare-to
    status: completed
  - id: lm-eval-configs
    content: Add lm_eval_minicpm5.yaml, lm_eval_smoke.yaml, lm_eval_compare_study.yaml under research/evals/configs/
    status: completed
  - id: docs-workflows
    content: Document finetune + ensemble lm-eval workflows in research/evals/USAGE.md and research/USAGE.md
    status: completed
  - id: finetune-hook
    content: "Optional: add --lm-eval-after to finetune.py after CLI smoke tests pass"
    status: completed
  - id: phase2-verify
    content: "Deferred: slm-compare, education_qa, harness JSON export from verification pipeline plan"
    status: cancelled
isProject: false
---

# lm-evaluation-harness Integration for research/ Models

## Context: what you already have

Your repo runs **three eval tracks** today; none use standard academic benchmarks (GSM8K, ARC, HellaSwag):

| Track | Tool | Best for |
|-------|------|----------|
| Agentic | [`slm-benchmark`](research/evals/src/slm_evals/run_benchmark.py) | BFCL, τ-bench, GAIA, SWE |
| Finetune training | [`finetune.py`](research/finetune.py) | eval_loss / perplexity only |
| Ensemble components | [`jepa_harness`](research/ensemble/src/ensemble/eval/jepa_harness.py) | RAG/router/JEPA ablation on custom QA |

The existing [Model Verification Pipeline plan](.cursor/plans/model_verification_pipeline_ed9d35ab.plan.md) defers lm-eval; per your choice, **lm-eval comes first**. `slm-compare` / `education_qa` stay Phase 2.

```mermaid
flowchart TB
    subgraph models [Model artifacts]
        basePreset[models.yaml preset]
        loraDir[LoRA adapter dir]
        mergedDir[Merged checkpoint]
        ensCkpt[Ensemble manifest checkpoint]
    end

    subgraph newTrack [NEW: lm-eval track]
        resolver[PresetResolver]
        slmLmEval[slm-lm-eval CLI]
        hfLm[lm_eval hf backend]
        ensLm[ensemble-lm custom backend]
        tasks[Standard tasks]
        results[results/lm_eval/]
    end

    subgraph existing [Existing tracks unchanged]
        agentic[slm-benchmark]
        ablation[jepa_harness]
    end

    basePreset --> resolver
    loraDir --> resolver
    mergedDir --> resolver
    resolver --> hfLm
    ensCkpt --> ensLm
    slmLmEval --> hfLm
    slmLmEval --> ensLm
    hfLm --> tasks
    ensLm --> tasks
    tasks --> results
```

---

## Benchmark matrix (what to run for each claim)

Match tasks to model size (~1B) and claim type:

| Model | lm-eval tasks (primary) | Existing complement |
|-------|-------------------------|---------------------|
| **Base preset** (`minicpm5-1b`, `gemma4-e2b-mobile`) | `arc_easy`, `arc_challenge`, `hellaswag`, `piqa`, `boolq`, `gsm8k` | `slm-benchmark --benchmarks bfcl,tau_bench` |
| **Finetuned LoRA** (`minicpm5-1b-lesson-lora`) | Same tasks; `pretrained=base,peft=adapter` | Same agentic suite |
| **Merged finetune** (`minicpm5-1b-lesson-merged`, `gemma-merged-local`) | Same tasks; `pretrained=merged_path` | Same |
| **Ensemble checkpoint** | Same generative + MC tasks via custom backend | `jepa_harness` for component ablation; `slm-benchmark --model-type ensemble` for agentic E2E |

**Smoke profile** (CPU/GPU quick check, ~5-15 min): `--tasks arc_easy,hellaswag --num_fewshot 0 --limit 50`

**Full profile** (reportable): `--tasks arc_easy,arc_challenge,hellaswag,piqa,boolq,gsm8k --num_fewshot 5` (gsm8k uses 8-shot per harness default)

**Fair comparison rules** (from your guide, enforced in YAML):
- Identical `tasks`, `num_fewshot`, `limit`, `seed`, `batch_size`
- Same base tokenizer (preset resolution guarantees this for LoRA)
- `temperature=0` / greedy decoding (the lm-eval default for multiple-choice tasks; gsm8k is generative)
- Never compare the `result_score` in `training_results.json` to lm-eval accuracy
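
The first rule can be checked mechanically before any scores are compared. The helper below is a hypothetical sketch, not part of `slm-lm-eval`; it assumes each run's settings are available as a plain dict with the keys named in the rule.

```python
SHARED_KEYS = ("tasks", "num_fewshot", "limit", "seed", "batch_size")


def comparison_mismatches(baseline: dict, candidate: dict, keys=SHARED_KEYS) -> dict:
    """Return {key: (baseline_value, candidate_value)} for settings that differ."""
    diffs = {}
    for key in keys:
        left, right = baseline.get(key), candidate.get(key)
        if key == "tasks":
            # Task order does not affect scores, so compare as sorted lists.
            left, right = sorted(left or []), sorted(right or [])
        if left != right:
            diffs[key] = (left, right)
    return diffs


baseline_cfg = {"tasks": ["arc_easy", "hellaswag"], "num_fewshot": 0,
                "limit": 50, "seed": 42, "batch_size": "auto"}
candidate_cfg = dict(baseline_cfg, tasks=["hellaswag", "arc_easy"])
# An empty dict means the two runs are comparable under the first rule.
mismatches = comparison_mismatches(baseline_cfg, candidate_cfg)
```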

---

## Phase 1: Dependencies and install surface

**Root [`pyproject.toml`](pyproject.toml)**: add optional group:

```toml
lm-eval = [
    "lm-eval[hf]>=0.4.9",
]
```

Install: `uv sync --group evals --group lm-eval --group finetune`

Keep lm-eval **optional** so agentic-only workflows stay lightweight.

---

## Phase 2: Preset → lm-eval model_args resolver

New module: `research/evals/src/slm_evals/lm_eval/preset_resolver.py`

Reuse [`inference.config.get_model_config`](libs/inference/src/inference/config.py) (same as [`finetune.py`](research/finetune.py)):

| Preset shape | lm-eval invocation |
|--------------|-------------------|
| `model_id` only (base) | `pretrained={model_id},trust_remote_code=True` |
| `model_id` + `adapter_path` (LoRA) | `pretrained={model_id},peft={adapter_path},trust_remote_code=True` |
| Local merged dir | `pretrained={model_id},trust_remote_code=True` |
| Ensemble (`jepa-ensemble-lesson` or path with `manifest.json`) | `--model ensemble-lm` (custom backend) |

Reject multimodal / llama_cpp presets with a clear error (same rule as finetune).
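
The mapping in the table can be sketched as a single function. This is an illustration only, not the code in `preset_resolver.py`: it takes a plain dict rather than the object returned by `get_model_config`, the `backend` field used for rejection is a hypothetical name, and the `checkpoint=` argument passed to the ensemble backend is likewise assumed.

```python
from pathlib import Path

UNSUPPORTED_BACKENDS = {"llama_cpp", "multimodal"}  # hypothetical field values


def resolve_model_args(preset: dict) -> tuple:
    """Map a preset dict to (lm-eval model name, model_args string)."""
    backend = preset.get("backend")
    if backend in UNSUPPORTED_BACKENDS:
        raise ValueError(f"preset backend {backend!r} is not supported by lm-eval runs")

    model_id = preset["model_id"]
    # A directory containing manifest.json is treated as an ensemble checkpoint.
    if (Path(model_id) / "manifest.json").is_file():
        return "ensemble-lm", f"checkpoint={model_id}"  # argument name is assumed

    parts = [f"pretrained={model_id}"]
    adapter_path = preset.get("adapter_path")
    if adapter_path:
        parts.append(f"peft={adapter_path}")
    parts.append("trust_remote_code=True")
    return "hf", ",".join(parts)


base_args = resolve_model_args({"model_id": "openbmb/MiniCPM5-1B"})
lora_args = resolve_model_args(
    {"model_id": "openbmb/MiniCPM5-1B", "adapter_path": "./adapters/lesson"}
)
```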

---

## Phase 3: CLI wrapper `slm-lm-eval`

New entry: `research/evals/src/slm_evals/run_lm_eval.py`
Register in [`research/evals/pyproject.toml`](research/evals/pyproject.toml):

```toml
[project.scripts]
slm-lm-eval = "slm_evals.run_lm_eval:main"
```

**Flags:**

```bash
uv run --package slm-evals slm-lm-eval \
  --config research/evals/configs/lm_eval_minicpm5.yaml \
  --preset minicpm5-1b \
  --experiment-name minicpm5-1b__lm-eval-baseline
```

| Flag | Purpose |
|------|---------|
| `--preset` | Resolve from [`models.yaml`](models.yaml) |
| `--model` | Override path/Hub id (merged dir or ensemble ckpt) |
| `--adapter` | Override LoRA path (alternative to preset) |
| `--config` | YAML: tasks, num_fewshot, limit, seed, device, batch_size |
| `--tasks` | CLI override of task list |
| `--compare-to` | Path to a prior `results.json`; prints a delta table (lightweight stand-in until `slm-compare` exists) |
| `--experiment-name` / `--output-dir` | Write under `results/lm_eval/{name}/` |

Implementation: a programmatic call to `lm_eval.simple_evaluate()` is preferred over shelling out to the `lm_eval` CLI, because it makes custom model registration easier to inject.

Output artifacts per run:
- `results.json` (lm-eval native)
- `summary.md` (human-readable table: task → acc/score)
- `run_meta.json` (preset, base_model, adapter_path, tasks, seed, git hash optional)
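
The delta table behind `--compare-to` and the `summary.md` layout could be produced along the following lines. This is a sketch, not the `run_lm_eval.py` implementation: it assumes `results.json` has a top-level `results` mapping of task name to metric values, and the function names and column headings are invented for the example.

```python
def _is_number(value) -> bool:
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def metric_deltas(baseline: dict, candidate: dict) -> list:
    """Rows of (task, metric, baseline, candidate, delta) for shared numeric metrics."""
    base_results = baseline.get("results", {})
    cand_results = candidate.get("results", {})
    rows = []
    for task in sorted(set(base_results) & set(cand_results)):
        for metric in sorted(cand_results[task]):
            old = base_results[task].get(metric)
            new = cand_results[task][metric]
            # Skip standard-error entries and non-numeric fields such as aliases.
            if "stderr" in metric or not (_is_number(old) and _is_number(new)):
                continue
            rows.append((task, metric, old, new, new - old))
    return rows


def to_markdown(rows) -> str:
    """Render the delta rows as a Markdown table."""
    lines = ["| task | metric | baseline | candidate | delta |",
             "|------|--------|----------|-----------|-------|"]
    for task, metric, old, new, delta in rows:
        lines.append(f"| {task} | {metric} | {old:.4f} | {new:.4f} | {delta:+.4f} |")
    return "\n".join(lines)


# In practice both dicts would come from json.load() on the two results.json files.
example_rows = metric_deltas(
    {"results": {"arc_easy": {"acc,none": 0.5, "acc_stderr,none": 0.01, "alias": "arc_easy"}}},
    {"results": {"arc_easy": {"acc,none": 0.75, "acc_stderr,none": 0.01, "alias": "arc_easy"}}},
)
```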

---

## Phase 4: Ensemble custom backend

New file: `research/evals/src/slm_evals/lm_eval/ensemble_lm.py`

Register with lm-eval:

```python
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model

@register_model("ensemble-lm")
class EnsembleLM(LM):
    ...
```

Load via existing [`load_ensemble_model`](research/evals/src/slm_evals/utils/model_loader.py) / `ensemble.checkpoint.load_checkpoint`.

| lm-eval method | Implementation |
|----------------|----------------|
| `generate_until` | `ens.generate_text(prompt, max_new_tokens=..., temperature=0)`: full JEPA+RAG+router stack |
| `loglikelihood` | Delegate to underlying `ens.llm` HF model (default adapter index 0) for MC tasks; document that this evaluates the **base LLM head**, not the selector. Pair with `generate_until` scores for full-stack generative tasks |

This split is intentional: MC benchmarks need token logprobs; the ensemble's value on generative QA shows up in `generate_until` (gsm8k) and in [`jepa_harness`](research/ensemble/src/ensemble/eval/jepa_harness.py).

Import side-effect: ensure `ensemble_lm` is imported before `simple_evaluate()` so registration runs.
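
To make the `generate_until` row concrete, the sketch below shows the delegation without importing lm-eval. It rests on several assumptions: each request is reduced to a `(context, gen_kwargs)` pair, `gen_kwargs` carries `until` stop strings and a `max_gen_toks` budget, and the ensemble exposes the `generate_text` call quoted in the table. The helper names and the default budget of 256 are hypothetical.

```python
from typing import Sequence


def truncate_at_stops(text: str, stops: Sequence[str]) -> str:
    """Cut the text at the earliest occurrence of any non-empty stop string."""
    cut = len(text)
    for stop in stops:
        if not stop:
            continue
        index = text.find(stop)
        if index != -1:
            cut = min(cut, index)
    return text[:cut]


def generate_until(ens, requests) -> list:
    """Run each (context, gen_kwargs) request through the full ensemble stack."""
    outputs = []
    for context, gen_kwargs in requests:
        until = gen_kwargs.get("until") or []
        if isinstance(until, str):
            until = [until]
        max_new_tokens = int(gen_kwargs.get("max_gen_toks", 256))  # assumed default
        raw = ens.generate_text(context, max_new_tokens=max_new_tokens, temperature=0)
        outputs.append(truncate_at_stops(raw, until))
    return outputs


class FakeEnsemble:
    """Stand-in for a loaded ensemble, used only to exercise the sketch."""

    def generate_text(self, prompt, max_new_tokens, temperature):
        return "42\n\nQuestion: what comes next?"


fake_outputs = generate_until(FakeEnsemble(), [("Q: 6 * 7 = ?", {"until": ["\n\n"]})])
```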

---

## Phase 5: Experiment configs

Add under `research/evals/configs/`:

**`lm_eval_minicpm5.yaml`**: baseline template:

```yaml
tasks:
  - arc_easy
  - arc_challenge
  - hellaswag
  - piqa
  - boolq
  - gsm8k
num_fewshot: 5 # gsm8k harness may override to 8 internally
limit: null # null = full; 100 for dev
seed: 42
batch_size: auto
device: auto
dtype: bfloat16
trust_remote_code: true
```

**`lm_eval_smoke.yaml`**: `limit: 25`, tasks `[arc_easy, hellaswag]`

**`lm_eval_compare_study.yaml`**: documents baseline + candidate preset names and shared settings

---
| 230 |
+
|
| 231 |
+
## Phase 6 β End-to-end workflows (finetune + ensemble)
|
| 232 |
+
|
| 233 |
+
### Finetuned model verification
|
| 234 |
+
|
| 235 |
+
```bash
|
| 236 |
+
# 1. Baseline (same config, pinned seed)
|
| 237 |
+
uv run --package slm-evals slm-lm-eval \
|
| 238 |
+
--config research/evals/configs/lm_eval_minicpm5.yaml \
|
| 239 |
+
--preset minicpm5-1b \
|
| 240 |
+
--experiment-name minicpm5-1b__baseline
|
| 241 |
+
|
| 242 |
+
# 2. Train
|
| 243 |
+
uv run python research/finetune.py --preset minicpm5-1b --mode lora --epochs 3
|
| 244 |
+
|
| 245 |
+
# 3. Candidate β LoRA via preset (no merge required)
|
| 246 |
+
uv run --package slm-evals slm-lm-eval \
|
| 247 |
+
--config research/evals/configs/lm_eval_minicpm5.yaml \
|
| 248 |
+
--preset minicpm5-1b-lesson-lora \
|
| 249 |
+
--experiment-name minicpm5-1b-lora__v1 \
|
| 250 |
+
--compare-to results/lm_eval/minicpm5-1b__baseline/results.json
|
| 251 |
+
```
|
| 252 |
+
|
| 253 |
+
For Gemma: use `gemma-lora-local` / `gemma-merged-local` presets after notebook or `finetune.py` training.
|
| 254 |
+
|
| 255 |
+
### Ensemble verification
|
| 256 |
+
|
| 257 |
+
```bash
|
| 258 |
+
# Component ablation (domain QA) β existing
|
| 259 |
+
uv run --package ensemble python -m ensemble.eval.jepa_harness \
|
| 260 |
+
--llm openbmb/MiniCPM5-1B \
|
| 261 |
+
--qa research/data/benchmark-qa.jsonl \
|
| 262 |
+
--kb research/data/benchmark-kb.jsonl
|
| 263 |
+
|
| 264 |
+
# Academic benchmarks on saved ensemble
|
| 265 |
+
uv run --package slm-evals slm-lm-eval \
|
| 266 |
+
--config research/evals/configs/lm_eval_minicpm5.yaml \
|
| 267 |
+
--model ./models/ensemble/jepa-lesson-pretrain \
|
| 268 |
+
--experiment-name ensemble-jepa__lm-eval
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
### Optional post-finetune hook (minimal, Phase 1.5)
|
| 272 |
+
|
| 273 |
+
Add to [`finetune.py`](research/finetune.py) only after CLI stabilizes:
|
| 274 |
+
|
| 275 |
+
- `--lm-eval-after` + `--lm-eval-config`: run `slm-lm-eval` as a subprocess on the output checkpoint
|
| 276 |
+
- Append `lm_eval_summary` path to `training_results.json`
|
| 277 |
+
|
| 278 |
+
Defer `--eval-baseline` auto-compare until Phase 2 `slm-compare` exists; use `--compare-to` on lm-eval outputs in the meantime.
|
| 279 |
+
|
| 280 |
+
---
|
| 281 |
+
|
| 282 |
+
## Phase 7: Documentation
|
| 283 |
+
|
| 284 |
+
Update:
|
| 285 |
+
- [`research/evals/USAGE.md`](research/evals/USAGE.md): lm-eval section, PEFT notes, task profiles
|
| 286 |
+
- [`research/USAGE.md`](research/USAGE.md): unified "verify finetune / ensemble" checklist
|
| 287 |
+
- [`.env.example`](.env.example): optional `LM_EVAL_TASKS`, `LM_EVAL_SEED` (low priority)
|
| 288 |
+
|
| 289 |
+
Include the verification checklist from your guide (seeds, fair comparison, no val-set leakage) mapped to concrete commands.
|
| 290 |
+
|
| 291 |
+
---
|
| 292 |
+
|
| 293 |
+
## File change summary
|
| 294 |
+
|
| 295 |
+
| File | Change |
|
| 296 |
+
|------|--------|
|
| 297 |
+
| [`pyproject.toml`](pyproject.toml) | `lm-eval` dependency group |
|
| 298 |
+
| [`research/evals/pyproject.toml`](research/evals/pyproject.toml) | Optional `lm-eval[hf]` extra; `slm-lm-eval` script |
|
| 299 |
+
| `research/evals/src/slm_evals/lm_eval/__init__.py` | Package init |
|
| 300 |
+
| `research/evals/src/slm_evals/lm_eval/preset_resolver.py` | Preset → model_args |
|
| 301 |
+
| `research/evals/src/slm_evals/lm_eval/ensemble_lm.py` | Custom `ensemble-lm` backend |
|
| 302 |
+
| `research/evals/src/slm_evals/run_lm_eval.py` | Main CLI |
|
| 303 |
+
| `research/evals/configs/lm_eval_*.yaml` | Baseline, smoke, compare templates |
|
| 304 |
+
| [`research/finetune.py`](research/finetune.py) | Optional `--lm-eval-after` (after CLI stable) |
|
| 305 |
+
| [`research/evals/USAGE.md`](research/evals/USAGE.md), [`research/USAGE.md`](research/USAGE.md) | Workflows |
|
| 306 |
+
|
| 307 |
+
---
|
| 308 |
+
|
| 309 |
+
## Testing plan
|
| 310 |
+
|
| 311 |
+
1. **Import smoke**: `uv run --package slm-evals python -c "import lm_eval; import slm_evals.lm_eval.ensemble_lm"`
|
| 312 |
+
2. **HF base**: `slm-lm-eval --config lm_eval_smoke.yaml --preset minicpm5-1b --limit 10`
|
| 313 |
+
3. **LoRA**: same with `--preset minicpm5-1b-lesson-lora` (or `--adapter ./models/finetuned/...`)
|
| 314 |
+
4. **Ensemble**: `slm-lm-eval --model ./models/ensemble/... --config lm_eval_smoke.yaml`
|
| 315 |
+
5. **Compare**: baseline + candidate runs produce `--compare-to` delta without task/seed mismatch warnings
|
| 316 |
+
6. **Regression**: existing `slm-benchmark` and `jepa_harness --toy` still pass unchanged
|
| 317 |
+
|
| 318 |
+
---
|
| 319 |
+
|
| 320 |
+
## Phase 2 (deferred; existing verification plan)
|
| 321 |
+
|
| 322 |
+
After lm-eval stabilizes, implement from [model_verification_pipeline plan](.cursor/plans/model_verification_pipeline_ed9d35ab.plan.md):
|
| 323 |
+
|
| 324 |
+
- `slm-compare` with paired bootstrap across **both** lm-eval and slm-benchmark JSON
|
| 325 |
+
- `education_qa` domain benchmark in slm-evals
|
| 326 |
+
- PEFT loading in `slm-benchmark` model_loader (reuse `preset_resolver`)
|
| 327 |
+
- Harness JSON export (`jepa_harness --output-dir`)
|
| 328 |
+
|
| 329 |
+
This gives you: **academic generalization** (lm-eval) + **agentic capability** (slm-evals) + **ensemble component proof** (jepa harness) with shared statistical comparison.
|
| 330 |
+
|
| 331 |
+
---
|
| 332 |
+
|
| 333 |
+
## Expected limitations (document, don't hide)
|
| 334 |
+
|
| 335 |
+
- **1B models** will score low on gsm8k/mmlu; use them for **relative** baseline vs finetune deltas, not SOTA claims
|
| 336 |
+
- **Ensemble loglikelihood** uses the underlying LLM, not the full selector stack; report both modes in docs
|
| 337 |
+
- **MiniCPM / Gemma** need `trust_remote_code=True`
|
| 338 |
+
- **First lm-eval run** downloads datasets; pin `HF_HOME` / cache for reproducibility
|
| 339 |
+
- **Multi-seed training** (3-5 seeds) is manual until Phase 2 compare aggregates runs
|
.cursor/plans/model_verification_pipeline_ed9d35ab.plan.md
ADDED
|
@@ -0,0 +1,280 @@
|
| 1 |
+
---
|
| 2 |
+
name: Model Verification Pipeline
|
| 3 |
+
overview: Operationalize model verification for both fine-tuned SLMs and ensemble research by adding comparison tooling, preset-aware eval, a lesson-domain benchmark, reproducible experiment configs, and optional post-finetune eval; all grounded in your existing `slm-evals` and ensemble harness infrastructure.
|
| 4 |
+
todos:
|
| 5 |
+
- id: eval-infra
|
| 6 |
+
content: Add preset resolution, PEFT loading, eval_seed sample pinning, and education_qa benchmark to slm-evals
|
| 7 |
+
status: pending
|
| 8 |
+
- id: compare-cli
|
| 9 |
+
content: Implement slm-compare with paired bootstrap, delta tables, and subset mismatch warnings
|
| 10 |
+
status: pending
|
| 11 |
+
- id: finetune-hook
|
| 12 |
+
content: Add --eval-after / --eval-baseline flags to finetune.py with subprocess eval + compare
|
| 13 |
+
status: pending
|
| 14 |
+
- id: ensemble-json
|
| 15 |
+
content: Persist jepa/world harness results to JSON (--output-dir) using shared stats helpers
|
| 16 |
+
status: pending
|
| 17 |
+
- id: docs-configs
|
| 18 |
+
content: Add baseline_minicpm5.yaml template and verification checklist to research/ USAGE docs
|
| 19 |
+
status: pending
|
| 20 |
+
isProject: false
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
# Model Verification Pipeline for research/
|
| 24 |
+
|
| 25 |
+
## Current state
|
| 26 |
+
|
| 27 |
+
Your repo already has **three parallel eval tracks** with no unified comparison:
|
| 28 |
+
|
| 29 |
+
| Track | Tool | Metrics | Results | Stats |
|
| 30 |
+
|-------|------|---------|---------|-------|
|
| 31 |
+
| Fine-tuned SLM | [`slm-benchmark`](research/evals/src/slm_evals/run_benchmark.py) | BFCL, τ-bench, GAIA, SWE pass rate | `results/<experiment>/results.json` | None |
|
| 32 |
+
| Finetune training | [`finetune.py`](research/finetune.py) | eval_loss, perplexity, `result_score` | `training_results.json` | None |
|
| 33 |
+
| Ensemble | [`jepa_harness`](research/ensemble/src/ensemble/eval/jepa_harness.py) / [`world_harness`](research/ensemble/src/ensemble/eval/world_harness.py) | EM, F1, ablation ladder | stdout only | [`paired_bootstrap`](research/ensemble/src/ensemble/eval/metrics.py) |
|
| 34 |
+
|
| 35 |
+
The manual before/after loop in [`research/evals/USAGE.md`](research/evals/USAGE.md) works but lacks significance testing, preset resolution, and a domain-aligned benchmark for lesson fine-tuning.
|
| 36 |
+
|
| 37 |
+
```mermaid
|
| 38 |
+
flowchart TB
|
| 39 |
+
subgraph today [Today]
|
| 40 |
+
baseline[slm-benchmark baseline]
|
| 41 |
+
finetune[finetune.py]
|
| 42 |
+
reeval[slm-benchmark new model]
|
| 43 |
+
manual[Manual diff results.csv]
|
| 44 |
+
baseline --> manual
|
| 45 |
+
reeval --> manual
|
| 46 |
+
finetune --> reeval
|
| 47 |
+
end
|
| 48 |
+
|
| 49 |
+
subgraph target [Target pipeline]
|
| 50 |
+
preset[models.yaml preset]
|
| 51 |
+
evalRun[slm-benchmark + eval_seed]
|
| 52 |
+
compare[slm-compare]
|
| 53 |
+
report[comparison_report.md]
|
| 54 |
+
preset --> evalRun
|
| 55 |
+
evalRun --> compare
|
| 56 |
+
compare --> report
|
| 57 |
+
end
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
|
| 62 |
+
## Verification strategy (mapped to your repo)
|
| 63 |
+
|
| 64 |
+
### 1. Fair comparison checklist (enforce via shared eval config)
|
| 65 |
+
|
| 66 |
+
Create a **single YAML experiment config** per comparison study (copy from [`experiment_001.yaml`](research/evals/configs/experiment_001.yaml)) and reuse it for baseline + candidate runs:
|
| 67 |
+
|
| 68 |
+
- Same `benchmarks`, `max_samples`, `benchmark_overrides`, `temperature: 0.0`, `max_new_tokens`
|
| 69 |
+
- Same `eval_seed` (new field) so sample subsets are identical across models
|
| 70 |
+
- Baseline = base preset from [`models.yaml`](models.yaml) (e.g. `minicpm5-1b`); candidate = LoRA/merged preset (e.g. `minicpm5-1b-lesson-lora`)
|
| 71 |
+
|
| 72 |
+
**Do not** compare `training_results.json` `result_score` against BFCL pass rate; they measure different things.
|
| 73 |
+
|
| 74 |
+
### 2. Benchmark selection for your use cases
|
| 75 |
+
|
| 76 |
+
| Claim | Benchmarks in this repo |
|
| 77 |
+
|-------|-------------------------|
|
| 78 |
+
| Agentic tool use (general) | `bfcl`, `tau_bench` (already in slm-evals) |
|
| 79 |
+
| End-to-end assistant | `gaia` (levels 1-2 for small models) |
|
| 80 |
+
| Code | `swe_bench` (keep `full_eval: false` unless Docker installed) |
|
| 81 |
+
| Lesson/education domain (finetune target) | **New** `education_qa` benchmark on [`research/data/benchmark-qa.jsonl`](research/data/benchmark-qa.jsonl) |
|
| 82 |
+
| Ensemble component value | Existing C1-C4 ablation ladder in JEPA harness |
|
| 83 |
+
| JEPA critic beats random | Existing selector comparison + `paired_bootstrap` |
|
| 84 |
+
|
| 85 |
+
Run **in-distribution** (`education_qa`) + **out-of-distribution** (`bfcl`, `tau_bench`) to show generalization, not just lesson memorization.
|
| 86 |
+
|
| 87 |
+
### 3. Statistical validation
|
| 88 |
+
|
| 89 |
+
Reuse existing `paired_bootstrap` from [`metrics.py`](research/ensemble/src/ensemble/eval/metrics.py) in a new shared module. For each benchmark:
|
| 90 |
+
|
| 91 |
+
- Align per-sample `passed` (or `score`) arrays by `samples[].id`
|
| 92 |
+
- Report: Δscore, win rate, `P(candidate > baseline)` from bootstrap, 95% CI via bootstrap percentiles
|
| 93 |
+
- Flag `P(candidate > baseline) > 0.95` as "significant improvement" (same threshold as JEPA harness)
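
The statistics listed above can be sketched as follows. This is a minimal illustration, not the repo's `paired_bootstrap` from `metrics.py`: the function name `paired_bootstrap_sketch`, its signature, and the default of 2000 resamples are assumptions made for the example.

```python
import random


def paired_bootstrap_sketch(baseline, candidate, n_resamples=2000, seed=0):
    """Bootstrap the mean paired difference (candidate - baseline).

    Hypothetical helper: returns (mean_delta, p_candidate_better, (ci_low, ci_high)).
    """
    if len(baseline) != len(candidate) or not baseline:
        raise ValueError("expected two non-empty score lists of equal length")
    rng = random.Random(seed)
    n = len(baseline)
    diffs = [c - b for b, c in zip(baseline, candidate)]
    resampled_means = []
    for _ in range(n_resamples):
        # Resample per-sample differences with replacement, keeping pairs intact.
        total = sum(diffs[rng.randrange(n)] for _ in range(n))
        resampled_means.append(total / n)
    resampled_means.sort()
    # Fraction of resamples in which the candidate beats the baseline.
    p_better = sum(m > 0 for m in resampled_means) / n_resamples
    # 95% confidence interval from bootstrap percentiles.
    ci_low = resampled_means[int(0.025 * (n_resamples - 1))]
    ci_high = resampled_means[int(0.975 * (n_resamples - 1))]
    return sum(diffs) / n, p_better, (ci_low, ci_high)
```

The inputs are assumed to be per-sample scores already aligned by sample `id`, so index `i` in both lists refers to the same question.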
|
| 94 |
+
|
| 95 |
+
For multi-seed training runs: run eval once per checkpoint seed, aggregate mean ± std in comparison report (optional `--runs` glob in compare CLI).
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## Implementation plan
|
| 100 |
+
|
| 101 |
+
### Phase A: Shared eval infrastructure (`research/evals/`)
|
| 102 |
+
|
| 103 |
+
**A1. Preset resolution in config loader**
|
| 104 |
+
|
| 105 |
+
Extend [`config_loader.py`](research/evals/src/slm_evals/utils/config_loader.py) and CLI in [`run_benchmark.py`](research/evals/src/slm_evals/run_benchmark.py):
|
| 106 |
+
|
| 107 |
+
- Add `--preset` flag and YAML field `preset:` (alternative to `model_path`)
|
| 108 |
+
- Resolve via existing [`inference.config.get_model_config`](libs/inference/src/inference/config.py): extract `model_id` + optional `adapter_path`
|
| 109 |
+
- Reject non-`transformers` / multimodal presets with clear error (same rule as finetune)
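
A rough sketch of the resolution rule described above, assuming a plain dict stands in for the parsed `models.yaml` entries. `ResolvedPreset` and `resolve_preset` are hypothetical names for illustration; the real loader is expected to go through `inference.config.get_model_config` instead.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ResolvedPreset:
    """Hypothetical result type: what the eval loader needs from a preset."""

    model_id: str
    adapter_path: str | None = None


def resolve_preset(name: str, presets: dict[str, dict]) -> ResolvedPreset:
    """Map a preset key to model_id + optional adapter_path, rejecting unsupported ones."""
    if name not in presets:
        raise ValueError(f"unknown preset {name!r}")
    entry = presets[name]
    if entry.get("backend") != "transformers":
        raise ValueError(f"preset {name!r} is not a transformers preset")
    if entry.get("multimodal", False):
        raise ValueError(f"preset {name!r} is multimodal and not supported for eval")
    model_id = entry.get("model_id")
    if not model_id:
        raise ValueError(f"preset {name!r} has no model_id")
    return ResolvedPreset(model_id=model_id, adapter_path=entry.get("adapter_path"))
```

A merged checkpoint resolves with `adapter_path=None`, so the same return type covers base, LoRA, and merged presets.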
|
| 110 |
+
|
| 111 |
+
**A2. PEFT adapter loading**
|
| 112 |
+
|
| 113 |
+
Extend [`model_loader.py`](research/evals/src/slm_evals/utils/model_loader.py):
|
| 114 |
+
|
| 115 |
+
- If `adapter_path` is set: load base from `model_id`, attach LoRA via `peft.PeftModel.from_pretrained`
|
| 116 |
+
- Support merged checkpoints (adapter_path absent) unchanged
|
| 117 |
+
- Record `base_model`, `adapter_path`, `param_count` in results metadata
|
| 118 |
+
|
| 119 |
+
**A3. Reproducible sample subsets**
|
| 120 |
+
|
| 121 |
+
Extend [`BaseBenchmark`](research/evals/src/slm_evals/benchmarks/base.py):
|
| 122 |
+
|
| 123 |
+
- Accept `eval_seed` + `max_samples` from config
|
| 124 |
+
- After loading dataset: `rng = random.Random(eval_seed); indices = rng.sample(range(len(data)), min(max_samples, len(data)))`
|
| 125 |
+
- Persist `eval_seed`, `sample_ids` list in `results.json` so compare can verify identical subsets
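
The seeded subset selection above might look like the following sketch. The function name `select_sample_indices` is hypothetical, and sorting the indices is a design choice added here (it keeps the evaluation order stable); the plan itself only specifies the `random.Random(eval_seed)` sampling.

```python
import random


def select_sample_indices(n_items, max_samples, eval_seed):
    """Pick a reproducible subset of dataset indices for a given eval_seed."""
    rng = random.Random(eval_seed)
    k = min(max_samples, n_items)
    # A dedicated Random instance avoids touching the global RNG state,
    # so model-side sampling cannot change which items are selected.
    return sorted(rng.sample(range(n_items), k))


if __name__ == "__main__":
    indices = select_sample_indices(n_items=1000, max_samples=5, eval_seed=42)
    print(indices)
```

Two runs with the same dataset length, `max_samples`, and `eval_seed` select identical indices, which is what lets a later comparison step verify that baseline and candidate saw the same subset.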
|
| 126 |
+
|
| 127 |
+
**A4. Lesson-domain benchmark**
|
| 128 |
+
|
| 129 |
+
New file `research/evals/src/slm_evals/benchmarks/education_qa.py`:
|
| 130 |
+
|
| 131 |
+
- Load [`research/data/benchmark-qa.jsonl`](research/data/benchmark-qa.jsonl)
|
| 132 |
+
- Prompt: `"Answer briefly.\nQ: {question}\nA:"`
|
| 133 |
+
- Score: token-overlap F1 + normalized substring EM (reuse logic from ensemble `metrics.py`: extract to shared `slm_evals/utils/scoring.py`, or import from ensemble if the dependency is acceptable)
|
| 134 |
+
- Register as `education_qa` in `BENCHMARK_REGISTRY`
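
For orientation, here is one way the two scores could be computed. It is a sketch under assumptions, not the logic in the ensemble `metrics.py`: the SQuAD-style normalization (lowercasing, stripping punctuation and articles) and the names `normalize`, `substring_em`, and `token_f1` are choices made for this example.

```python
import re
import string
from collections import Counter


def normalize(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def substring_em(prediction, gold):
    """1.0 if the normalized gold answer appears inside the normalized prediction."""
    gold_norm = normalize(gold)
    return 1.0 if gold_norm and gold_norm in normalize(prediction) else 0.0


def token_f1(prediction, gold):
    """Token-overlap F1 between normalized prediction and gold answer."""
    pred_tokens = normalize(prediction).split()
    gold_tokens = normalize(gold).split()
    if not pred_tokens or not gold_tokens:
        return float(pred_tokens == gold_tokens)
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
```

Substring EM is lenient toward verbose answers, while token F1 penalizes extra tokens, so reporting both gives a fuller picture for short-answer prompts.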
|
| 135 |
+
|
| 136 |
+
**A5. Comparison CLI: `slm-compare`**
|
| 137 |
+
|
| 138 |
+
New module `research/evals/src/slm_evals/compare_runs.py` + console script in [`pyproject.toml`](research/evals/pyproject.toml):
|
| 139 |
+
|
| 140 |
+
```bash
|
| 141 |
+
uv run --package slm-evals slm-compare \
|
| 142 |
+
--baseline results/minicpm5-1b__baseline/results.json \
|
| 143 |
+
--candidate results/minicpm5-1b-lora__v1/results.json \
|
| 144 |
+
--output results/comparisons/minicpm5-lora-vs-base.md
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
Outputs:
|
| 148 |
+
- Per-benchmark delta table (score, passed/total, latency)
|
| 149 |
+
- Paired bootstrap `P(candidate > baseline)` per benchmark
|
| 150 |
+
- Per-sample win/loss/tie counts (joined on sample `id`)
|
| 151 |
+
- Warnings if `eval_seed`, `max_samples`, or benchmark sets differ
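
Two of the outputs above, the mismatch warnings and the per-sample win/loss/tie counts, can be sketched as below. The `results.json` layout used here (top-level `eval_seed` and `max_samples`, a `benchmarks` mapping, samples with `id` and `score`) is an assumption for illustration, as are the function names.

```python
def config_warnings(baseline, candidate, keys=("eval_seed", "max_samples")):
    """Return human-readable warnings when two runs are not directly comparable."""
    warnings = []
    for key in keys:
        if baseline.get(key) != candidate.get(key):
            warnings.append(
                f"{key} differs: {baseline.get(key)!r} vs {candidate.get(key)!r}"
            )
    base_names = set(baseline.get("benchmarks", {}))
    cand_names = set(candidate.get("benchmarks", {}))
    if base_names != cand_names:
        # Symmetric difference: benchmarks present in only one of the runs.
        warnings.append(f"benchmark sets differ: {sorted(base_names ^ cand_names)}")
    return warnings


def win_loss_tie(baseline_samples, candidate_samples, key="score"):
    """Count per-sample outcomes for the candidate, joined on sample id."""
    base_scores = {s["id"]: s[key] for s in baseline_samples}
    wins = losses = ties = 0
    for sample in candidate_samples:
        if sample["id"] not in base_scores:
            continue  # Unpaired samples are skipped, not counted as losses.
        diff = sample[key] - base_scores[sample["id"]]
        if diff > 0:
            wins += 1
        elif diff < 0:
            losses += 1
        else:
            ties += 1
    return wins, losses, ties
```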
|
| 152 |
+
|
| 153 |
+
**A6. Experiment config templates**
|
| 154 |
+
|
| 155 |
+
Add two configs under `research/evals/configs/`:
|
| 156 |
+
|
| 157 |
+
- `baseline_minicpm5.yaml`: preset `minicpm5-1b`, benchmarks `[education_qa, bfcl, tau_bench]`, `max_samples: 100`, `eval_seed: 42`
|
| 158 |
+
- `compare_study.yaml`: documents baseline + candidate preset keys, shared eval settings, output naming convention
|
| 159 |
+
|
| 160 |
+
Update [`research/evals/USAGE.md`](research/evals/USAGE.md) and [`research/USAGE.md`](research/USAGE.md) with the verification checklist from your research guide.
|
| 161 |
+
|
| 162 |
+
---
|
| 163 |
+
|
| 164 |
+
### Phase B: Finetune integration (`research/finetune.py`)
|
| 165 |
+
|
| 166 |
+
**B1. Post-finetune eval hook**
|
| 167 |
+
|
| 168 |
+
Add optional flags to [`finetune.py`](research/finetune.py):
|
| 169 |
+
|
| 170 |
+
- `--eval-after`: run slm-benchmark after training completes
|
| 171 |
+
- `--eval-config PATH`: YAML with benchmark settings (defaults to `baseline_minicpm5.yaml` structure)
|
| 172 |
+
- `--eval-baseline PRESET`: also eval base preset for side-by-side comparison
|
| 173 |
+
|
| 174 |
+
On completion:
|
| 175 |
+
1. Write `training_results.json` (existing)
|
| 176 |
+
2. Run eval on output checkpoint
|
| 177 |
+
3. If `--eval-baseline` set, run baseline eval + invoke `slm-compare`
|
| 178 |
+
4. Append comparison summary path to `training_results.json` under `"post_eval": {...}`
|
| 179 |
+
|
| 180 |
+
Implementation: subprocess call to `uv run --package slm-evals slm-benchmark` (avoids circular imports).
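
A minimal sketch of that subprocess call, assuming the `--config`, `--preset`, and `--experiment-name` flags shown in the workflow commands of this plan. The helper names `build_benchmark_command` and `run_post_eval` are hypothetical.

```python
import subprocess


def build_benchmark_command(config_path, preset, experiment_name):
    """Assemble the slm-benchmark invocation as an argument list (no shell)."""
    return [
        "uv", "run", "--package", "slm-evals", "slm-benchmark",
        "--config", str(config_path),
        "--preset", preset,
        "--experiment-name", experiment_name,
    ]


def run_post_eval(config_path, preset, experiment_name):
    """Run the eval in a child process; raises CalledProcessError on failure."""
    cmd = build_benchmark_command(config_path, preset, experiment_name)
    completed = subprocess.run(cmd, check=True)
    return completed.returncode
```

Passing a list instead of a shell string avoids quoting problems with paths, and `check=True` makes a failed eval surface as an exception instead of a silently missing results file.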
|
| 181 |
+
|
| 182 |
+
---
|
| 183 |
+
|
| 184 |
+
### Phase C: Ensemble track unification (`research/ensemble/`)
|
| 185 |
+
|
| 186 |
+
**C1. Persist harness results to JSON**
|
| 187 |
+
|
| 188 |
+
Extend [`jepa_harness.py`](research/ensemble/src/ensemble/eval/jepa_harness.py) and [`world_harness.py`](research/ensemble/src/ensemble/eval/world_harness.py):
|
| 189 |
+
|
| 190 |
+
- Add `--output-dir` flag
|
| 191 |
+
- Write `results.json` matching slm-evals schema where possible:
|
| 192 |
+
- `benchmarks.C4_full_jepa.samples[]` with per-question EM/F1
|
| 193 |
+
- `benchmarks.C1_base` through `C4_full_jepa` aggregate scores
|
| 194 |
+
- `significance[]` with paired bootstrap between ladder steps
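
As a rough illustration of the JSON layout above, the sketch below writes a `results.json` with per-config samples, aggregate scores, and a `significance` list. The field names inside each benchmark entry (`score`, `em`, `total`) and the function name `write_harness_results` are assumptions; only the `benchmarks.<config>.samples[]` and `significance[]` keys come from the plan.

```python
import json
from pathlib import Path


def write_harness_results(output_dir, per_config_samples, significance):
    """Persist ablation results; samples are dicts with 'id', 'em', and 'f1'."""
    benchmarks = {}
    for config_name, samples in per_config_samples.items():
        total = len(samples)
        benchmarks[config_name] = {
            # Aggregate scores are plain means over the per-question values.
            "score": sum(s["f1"] for s in samples) / total if total else 0.0,
            "em": sum(s["em"] for s in samples) / total if total else 0.0,
            "total": total,
            "samples": samples,
        }
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    path = out_dir / "results.json"
    payload = {"benchmarks": benchmarks, "significance": significance}
    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    return path
```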
|
| 195 |
+
|
| 196 |
+
**C2. Compare ensemble configs**
|
| 197 |
+
|
| 198 |
+
`slm-compare` accepts ensemble JSON (detect by presence of ablation config keys) and renders ablation table + bootstrap block.
|
| 199 |
+
|
| 200 |
+
**C3. Shared metrics module**
|
| 201 |
+
|
| 202 |
+
Move or re-export `em_score`, `f1_score`, `paired_bootstrap` to `research/evals/src/slm_evals/utils/stats.py` (ensemble imports from there, or duplicate minimally to avoid cross-package coupling; prefer a tiny shared module under `research/` if both packages need it).
|
| 203 |
+
|
| 204 |
+
---
|
| 205 |
+
|
| 206 |
+
### Phase D: End-to-end workflow documentation
|
| 207 |
+
|
| 208 |
+
Add a **Verification Checklist** section to [`research/README.md`](research/README.md):
|
| 209 |
+
|
| 210 |
+
```bash
|
| 211 |
+
# 1. Baseline (preset-aware, pinned seed)
|
| 212 |
+
uv run --package slm-evals slm-benchmark \
|
| 213 |
+
--config research/evals/configs/baseline_minicpm5.yaml \
|
| 214 |
+
--preset minicpm5-1b \
|
| 215 |
+
--experiment-name minicpm5-1b__baseline
|
| 216 |
+
|
| 217 |
+
# 2. Fine-tune
|
| 218 |
+
uv run python research/finetune.py --preset minicpm5-1b --mode lora --epochs 3
|
| 219 |
+
|
| 220 |
+
# 3. Candidate eval (same config, different preset/path)
|
| 221 |
+
uv run --package slm-evals slm-benchmark \
|
| 222 |
+
--config research/evals/configs/baseline_minicpm5.yaml \
|
| 223 |
+
--preset minicpm5-1b-lesson-lora \
|
| 224 |
+
--experiment-name minicpm5-1b-lora__v1
|
| 225 |
+
|
| 226 |
+
# 4. Statistical comparison
|
| 227 |
+
uv run --package slm-evals slm-compare \
|
| 228 |
+
--baseline results/minicpm5-1b__baseline/results.json \
|
| 229 |
+
--candidate results/minicpm5-1b-lora__v1/results.json
|
| 230 |
+
|
| 231 |
+
# 5. Ensemble ablation (domain QA)
|
| 232 |
+
uv run --package ensemble python -m ensemble.eval.jepa_harness \
|
| 233 |
+
--llm openbmb/MiniCPM5-1B \
|
| 234 |
+
--qa research/data/benchmark-qa.jsonl \
|
| 235 |
+
--kb research/data/benchmark-kb.jsonl \
|
| 236 |
+
--output-dir results/ensemble/jepa-ablation-v1
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
**Claiming "better than baseline" requires:**
|
| 240 |
+
- Δscore > 0 on target benchmarks
|
| 241 |
+
- `P(candidate > baseline) > 0.95` on paired per-sample scores (or document inconclusive)
|
| 242 |
+
- No regression > 2pp on OOD benchmarks unless tradeoff is explicitly claimed
|
| 243 |
+
- Ablation shows which components contribute (ensemble) or which eval split improved (finetune)
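
The first three criteria above are mechanical enough to check in code; the ablation criterion still needs human judgment. The sketch below is one possible encoding, assuming deltas are expressed in percentage points. The function name and parameter names are hypothetical.

```python
def supports_improvement_claim(
    target_delta_pp,
    p_better,
    ood_deltas_pp,
    p_threshold=0.95,
    max_regression_pp=2.0,
):
    """Return (ok, reasons) for a 'better than baseline' claim.

    target_delta_pp: candidate - baseline on the target benchmark, in pp.
    p_better: bootstrap P(candidate > baseline) on paired per-sample scores.
    ood_deltas_pp: mapping of OOD benchmark name -> delta in pp.
    """
    reasons = []
    if target_delta_pp <= 0:
        reasons.append("no positive delta on target benchmark")
    if p_better <= p_threshold:
        reasons.append(f"P(candidate > baseline) = {p_better:.2f} is inconclusive")
    for name, delta in sorted(ood_deltas_pp.items()):
        if delta < -max_regression_pp:
            reasons.append(f"{name} regressed by {-delta:.1f}pp")
    return (not reasons), reasons
```

An empty `reasons` list means the mechanical checks passed; a non-empty list can be copied into the comparison report to document an inconclusive or traded-off result.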
|
| 244 |
+
|
| 245 |
+
---
|
| 246 |
+
|
| 247 |
+
## File change summary
|
| 248 |
+
|
| 249 |
+
| File | Change |
|
| 250 |
+
|------|--------|
|
| 251 |
+
| [`research/evals/src/slm_evals/utils/config_loader.py`](research/evals/src/slm_evals/utils/config_loader.py) | `preset`, `eval_seed`, adapter fields |
|
| 252 |
+
| [`research/evals/src/slm_evals/utils/model_loader.py`](research/evals/src/slm_evals/utils/model_loader.py) | PEFT adapter loading |
|
| 253 |
+
| [`research/evals/src/slm_evals/benchmarks/base.py`](research/evals/src/slm_evals/benchmarks/base.py) | Seeded sample selection |
|
| 254 |
+
| `research/evals/src/slm_evals/benchmarks/education_qa.py` | New lesson-domain benchmark |
|
| 255 |
+
| `research/evals/src/slm_evals/compare_runs.py` | New comparison CLI |
|
| 256 |
+
| `research/evals/src/slm_evals/utils/stats.py` | Bootstrap + scoring helpers |
|
| 257 |
+
| [`research/evals/pyproject.toml`](research/evals/pyproject.toml) | Register `slm-compare` script |
|
| 258 |
+
| [`research/finetune.py`](research/finetune.py) | `--eval-after`, `--eval-config`, `--eval-baseline` |
|
| 259 |
+
| [`research/ensemble/src/ensemble/eval/jepa_harness.py`](research/ensemble/src/ensemble/eval/jepa_harness.py) | `--output-dir`, JSON persistence |
|
| 260 |
+
| [`research/ensemble/src/ensemble/eval/world_harness.py`](research/ensemble/src/ensemble/eval/world_harness.py) | Same |
|
| 261 |
+
| [`research/evals/configs/baseline_minicpm5.yaml`](research/evals/configs/baseline_minicpm5.yaml) | New template |
|
| 262 |
+
| [`research/USAGE.md`](research/USAGE.md), [`research/evals/USAGE.md`](research/evals/USAGE.md) | Verification workflow docs |
|
| 263 |
+
|
| 264 |
+
---
|
| 265 |
+
|
| 266 |
+
## Testing plan
|
| 267 |
+
|
| 268 |
+
- **Unit**: `paired_bootstrap` on synthetic paired arrays; sample-id alignment in compare; preset resolution for `minicpm5-1b` and `minicpm5-1b-lesson-lora`
|
| 269 |
+
- **Smoke**: `slm-benchmark --preset minicpm5-1b --benchmarks education_qa --max-samples 5 --device cpu`
|
| 270 |
+
- **Integration**: baseline → finetune (1 epoch) → re-eval → `slm-compare` produces report with no subset mismatch warnings
|
| 271 |
+
- **Ensemble**: `jepa_harness --toy --output-dir /tmp/jepa-test` writes valid JSON readable by `slm-compare`
|
| 272 |
+
|
| 273 |
+
---
|
| 274 |
+
|
| 275 |
+
## Out of scope (defer)
|
| 276 |
+
|
| 277 |
+
- W&B / MLflow experiment tracking
|
| 278 |
+
- lm-evaluation-harness integration (GSM8K/MMLU; different from agentic suite, add later as separate benchmark group if needed)
|
| 279 |
+
- Running ensemble inference through BFCL (requires adapter wrapper, not just raw HF model)
|
| 280 |
+
- CI regression gates (can add after compare CLI stabilizes)
|
.env.example
CHANGED
|
@@ -28,3 +28,28 @@ ALLOW_MODEL_SWITCH=false
|
|
| 28 |
# Optional: local fine-tuned merged weights
|
| 29 |
# ACTIVE_MODEL=gemma-merged-local
|
| 30 |
# MODEL_ID=./gemma_merged_model
|
|
| 28 |
# Optional: local fine-tuned merged weights
|
| 29 |
# ACTIVE_MODEL=gemma-merged-local
|
| 30 |
# MODEL_ID=./gemma_merged_model
|
| 31 |
+
|
| 32 |
+
# --- Fine-tuning (research/finetune.py) ---
|
| 33 |
+
# FINETUNE_PRESET=minicpm5-1b
|
| 34 |
+
# FINETUNE_MODEL=openbmb/MiniCPM5-1B
|
| 35 |
+
# FINETUNE_DATASET=./research/data/education-lesson-chat.jsonl
|
| 36 |
+
# FINETUNE_DATASET=tatsu-lab/alpaca
|
| 37 |
+
# FINETUNE_DATASET_CONFIG=
|
| 38 |
+
# FINETUNE_DATASET_SPLIT=train
|
| 39 |
+
# FINETUNE_MAX_SAMPLES=500
|
| 40 |
+
# FINETUNE_OUT=./models/finetuned/minicpm5-1b-lora
|
| 41 |
+
# FINETUNE_FORMAT=chat
|
| 42 |
+
# After training, point Gradio at the adapter preset:
|
| 43 |
+
# ACTIVE_MODEL=minicpm5-1b-lesson-lora
|
| 44 |
+
|
| 45 |
+
# --- Ensemble research (research/ensemble/) ---
|
| 46 |
+
# Base LLM resolution (first match wins): ENSEMBLE_LLM, LLM_PATH, BASE, MODEL_ID, ACTIVE_MODEL
|
| 47 |
+
# LLM_PATH=./models/finetuned/minicpm5-1b-lora-merged
|
| 48 |
+
# ENSEMBLE_LLM=Qwen/Qwen2.5-0.5B-Instruct
|
| 49 |
+
# ENSEMBLE_PRESET=minicpm5-1b
|
| 50 |
+
# ENSEMBLE_OUT=./models/ensemble/minicpm5-1b-jepa-pretrain
|
| 51 |
+
# ENSEMBLE_QA=./research/data/benchmark-qa.jsonl
|
| 52 |
+
# ENSEMBLE_KB=./research/data/benchmark-kb.jsonl
|
| 53 |
+
# ENSEMBLE_CKPT=./models/ensemble/jepa-lesson-pretrain
|
| 54 |
+
|
| 55 |
+
BASE=openbmb/MiniCPM5-1B
|
.gitignore
CHANGED
|
@@ -10,4 +10,6 @@ models/
|
|
| 10 |
dist/
|
| 11 |
build/
|
| 12 |
|
| 13 |
-
outputs/traces
|
|
|
|
|
|
|
|
|
| 10 |
dist/
|
| 11 |
build/
|
| 12 |
|
| 13 |
+
outputs/traces
|
| 14 |
+
|
| 15 |
+
/results
|
README.md
CHANGED
|
@@ -46,8 +46,11 @@ apps/gradio-space/ # Gradio tabs (Lesson slides + Chat debug)
|
|
| 46 |
libs/agent/ # Skill agent runner, tools, trace recorder
|
| 47 |
libs/inference/ # Transformers + llama.cpp backends
|
| 48 |
skills/ # SKILL.md task definitions
|
|
|
|
| 49 |
```
|
| 50 |
|
|
|
|
|
|
|
| 51 |
## Environment variables
|
| 52 |
|
| 53 |
| Variable | Default | Description |
|
|
|
|
| 46 |
libs/agent/ # Skill agent runner, tools, trace recorder
|
| 47 |
libs/inference/ # Transformers + llama.cpp backends
|
| 48 |
skills/ # SKILL.md task definitions
|
| 49 |
+
research/ # Fine-tune, ensemble experiments, agentic evals (optional)
|
| 50 |
```
|
| 51 |
|
| 52 |
+
Optional research tooling (not required for the Space): see [research/USAGE.md](research/USAGE.md).
|
| 53 |
+
|
| 54 |
## Environment variables
|
| 55 |
|
| 56 |
| Variable | Default | Description |
|
libs/inference/pyproject.toml
CHANGED
|
@@ -14,6 +14,7 @@ dependencies = [
|
|
| 14 |
"pyyaml>=6.0.2",
|
| 15 |
"torch>=2.5.0",
|
| 16 |
"torchvision>=0.20.0",
|
|
|
|
| 17 |
"transformers>=5.7.0",
|
| 18 |
]
|
| 19 |
|
|
|
|
| 14 |
"pyyaml>=6.0.2",
|
| 15 |
"torch>=2.5.0",
|
| 16 |
"torchvision>=0.20.0",
|
| 17 |
+
"peft>=0.14.0",
|
| 18 |
"transformers>=5.7.0",
|
| 19 |
]
|
| 20 |
|
libs/inference/src/inference/config.py
CHANGED
|
@@ -23,6 +23,7 @@ class ModelConfig:
|
|
| 23 |
model_file: str | None = None
|
| 24 |
model_path: str | None = None
|
| 25 |
model_id: str | None = None
|
|
|
|
| 26 |
trust_remote_code: bool = False
|
| 27 |
multimodal: bool = False
|
| 28 |
n_ctx: int = 4096
|
|
@@ -37,6 +38,7 @@ class ModelConfig:
|
|
| 37 |
self.model_file,
|
| 38 |
self.model_path,
|
| 39 |
self.model_id,
|
|
|
|
| 40 |
self.trust_remote_code,
|
| 41 |
self.multimodal,
|
| 42 |
self.n_ctx,
|
|
@@ -60,6 +62,9 @@ class ModelConfig:
|
|
| 60 |
if self.model_id and self.model_id.startswith(("./", "../")):
|
| 61 |
updates["model_id"] = str((base_dir / self.model_id).resolve())
|
| 62 |
|
|
|
|
|
|
|
|
|
|
| 63 |
return replace(self, **updates) if updates else self
|
| 64 |
|
| 65 |
|
|
@@ -149,6 +154,7 @@ def _parse_model_entry(key: str, raw: dict[str, Any]) -> ModelConfig:
|
|
| 149 |
model_file=raw.get("model_file"),
|
| 150 |
model_path=raw.get("model_path"),
|
| 151 |
model_id=raw.get("model_id"),
|
|
|
|
| 152 |
trust_remote_code=bool(raw.get("trust_remote_code", False)),
|
| 153 |
multimodal=bool(raw.get("multimodal", False)),
|
| 154 |
n_ctx=int(raw.get("n_ctx", 4096)),
|
|
|
|
| 23 |
model_file: str | None = None
|
| 24 |
model_path: str | None = None
|
| 25 |
model_id: str | None = None
|
| 26 |
+
adapter_path: str | None = None
|
| 27 |
trust_remote_code: bool = False
|
| 28 |
multimodal: bool = False
|
| 29 |
n_ctx: int = 4096
|
|
|
|
| 38 |
self.model_file,
|
| 39 |
self.model_path,
|
| 40 |
self.model_id,
|
| 41 |
+
self.adapter_path,
|
| 42 |
self.trust_remote_code,
|
| 43 |
self.multimodal,
|
| 44 |
self.n_ctx,
|
|
|
|
| 62 |
if self.model_id and self.model_id.startswith(("./", "../")):
|
| 63 |
updates["model_id"] = str((base_dir / self.model_id).resolve())
|
| 64 |
|
| 65 |
+
if self.adapter_path and self.adapter_path.startswith(("./", "../")):
|
| 66 |
+
updates["adapter_path"] = str((base_dir / self.adapter_path).resolve())
|
| 67 |
+
|
| 68 |
return replace(self, **updates) if updates else self
|
| 69 |
|
| 70 |
|
|
|
|
| 154 |
model_file=raw.get("model_file"),
|
| 155 |
model_path=raw.get("model_path"),
|
| 156 |
model_id=raw.get("model_id"),
|
| 157 |
+
adapter_path=raw.get("adapter_path"),
|
| 158 |
trust_remote_code=bool(raw.get("trust_remote_code", False)),
|
| 159 |
multimodal=bool(raw.get("multimodal", False)),
|
| 160 |
n_ctx=int(raw.get("n_ctx", 4096)),
|
libs/inference/src/inference/transformers.py
CHANGED
|
@@ -76,6 +76,18 @@ class TransformersBackend:
|
|
| 76 |
**model_kwargs,
|
| 77 |
)
|
| 78 |
|
| 79 |
if device == "cpu":
|
| 80 |
self._model.to(device)
|
| 81 |
|
|
|
|
| 76 |
**model_kwargs,
|
| 77 |
)
|
| 78 |
|
| 79 |
+
if self._config.adapter_path:
|
| 80 |
+
from pathlib import Path
|
| 81 |
+
|
| 82 |
+
from peft import PeftModel
|
| 83 |
+
|
| 84 |
+
adapter = Path(self._config.adapter_path)
|
| 85 |
+
if not adapter.is_dir():
|
| 86 |
+
raise FileNotFoundError(
|
| 87 |
+
f"LoRA adapter not found for preset {self._config.key!r}: {adapter}"
|
| 88 |
+
)
|
| 89 |
+
self._model = PeftModel.from_pretrained(self._model, str(adapter))
|
| 90 |
+
|
| 91 |
if device == "cpu":
|
| 92 |
self._model.to(device)
|
| 93 |
|
models.yaml
CHANGED
|
@@ -2,9 +2,8 @@
|
|
| 2 |
# Select active preset with ACTIVE_MODEL; override any field via .env (see .env.example).
|
| 3 |
|
| 4 |
defaults:
|
| 5 |
-
active_model: minicpm-v-4.6
|
| 6 |
-
|
| 7 |
-
# active_model: minicpm5-1b
|
| 8 |
|
| 9 |
# Dev: set ALLOW_MODEL_SWITCH=true in .env to expose a dropdown in Gradio.
|
| 10 |
# Space: keep false so visitors use one pinned model.
|
|
@@ -55,3 +54,22 @@ models:
|
|
| 55 |
label: Fine-tuned LoRA adapter (local path)
|
| 56 |
backend: transformers
|
| 57 |
model_id: ./gemma_finetuned_model
|
|
|
| 2 |
# Select active preset with ACTIVE_MODEL; override any field via .env (see .env.example).
|
| 3 |
|
| 4 |
defaults:
|
| 5 |
+
# active_model: minicpm-v-4.6
|
| 6 |
+
active_model: minicpm5-1b
|
|
|
|
| 7 |
|
| 8 |
# Dev: set ALLOW_MODEL_SWITCH=true in .env to expose a dropdown in Gradio.
|
| 9 |
# Space: keep false so visitors use one pinned model.
|
|
|
|
| 54 |
label: Fine-tuned LoRA adapter (local path)
|
| 55 |
backend: transformers
|
| 56 |
model_id: ./gemma_finetuned_model
|
| 57 |
+
|
| 58 |
+
minicpm5-1b-lesson-lora:
|
| 59 |
+
label: MiniCPM5 1B lesson LoRA (local)
|
| 60 |
+
backend: transformers
|
| 61 |
+
model_id: openbmb/MiniCPM5-1B
|
| 62 |
+
adapter_path: ./models/finetuned/minicpm5-1b-lora
|
| 63 |
+
trust_remote_code: true
|
| 64 |
+
|
| 65 |
+
minicpm5-1b-lesson-merged:
|
| 66 |
+
label: MiniCPM5 1B lesson merged (local)
|
| 67 |
+
backend: transformers
|
| 68 |
+
model_id: ./models/finetuned/minicpm5-1b-lora-merged
|
| 69 |
+
trust_remote_code: true
|
| 70 |
+
|
| 71 |
+
jepa-ensemble-lesson:
|
| 72 |
+
label: JEPA ensemble (LLM + emb + JEPA) lesson pretrain
|
| 73 |
+
backend: transformers
|
| 74 |
+
model_id: ./models/ensemble/jepa-lesson-pretrain
|
| 75 |
+
trust_remote_code: true
|
pyproject.toml
CHANGED
|
@@ -6,6 +6,7 @@ readme = "README.md"
|
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
"agent",
|
|
|
|
| 9 |
"gradio-space",
|
| 10 |
"inference",
|
| 11 |
]
|
|
@@ -15,14 +16,34 @@ dev = [
|
|
| 15 |
"ruff>=0.9.0",
|
| 16 |
"pytest>=8.0.0",
|
| 17 |
]
|
| 18 |
|
| 19 |
[tool.uv.workspace]
|
| 20 |
members = [
|
| 21 |
"apps/*",
|
| 22 |
"libs/*",
|
|
|
|
|
|
|
| 23 |
]
|
| 24 |
|
| 25 |
[tool.uv.sources]
|
| 26 |
agent = { workspace = true }
|
|
|
|
| 27 |
gradio-space = { workspace = true }
|
| 28 |
inference = { workspace = true }
|
|
|
|
|
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
"agent",
|
| 9 |
+
"ensemble",
|
| 10 |
"gradio-space",
|
| 11 |
"inference",
|
| 12 |
]
|
|
|
|
| 16 |
"ruff>=0.9.0",
|
| 17 |
"pytest>=8.0.0",
|
| 18 |
]
|
| 19 |
+
finetune = [
|
| 20 |
+
"bitsandbytes>=0.46.1",
|
| 21 |
+
"datasets>=3.0.0",
|
| 22 |
+
"peft>=0.14.0",
|
| 23 |
+
]
|
| 24 |
+
ensemble = [
|
| 25 |
+
"accelerate>=1.2.0",
|
| 26 |
+
"peft>=0.14.0",
|
| 27 |
+
"transformers>=5.7.0",
|
| 28 |
+
]
|
| 29 |
+
evals = [
|
| 30 |
+
"slm-evals",
|
| 31 |
+
]
|
| 32 |
+
lm-eval = [
|
| 33 |
+
"slm-evals[lm-eval]",
|
| 34 |
+
]
|
| 35 |
|
| 36 |
[tool.uv.workspace]
|
| 37 |
members = [
|
| 38 |
"apps/*",
|
| 39 |
"libs/*",
|
| 40 |
+
"research/ensemble",
|
| 41 |
+
"research/evals",
|
| 42 |
]
|
| 43 |
|
| 44 |
[tool.uv.sources]
|
| 45 |
agent = { workspace = true }
|
| 46 |
+
ensemble = { workspace = true }
|
| 47 |
gradio-space = { workspace = true }
|
| 48 |
inference = { workspace = true }
|
| 49 |
+
slm-evals = { workspace = true }
|
research/README.md
ADDED
|
@@ -0,0 +1,48 @@
|
| 1 |
+
# Research
|
| 2 |
+
|
| 3 |
+
Experimental code for **fine-tuning**, **ensemble architectures**, and **agentic benchmarks**. Nothing here is wired into the Gradio Lesson Agent by default; use it to train models, probe JEPA/world-model ideas, and score checkpoints against public benchmarks.
|
| 4 |
+
|
| 5 |
+
| Path | Purpose |
|
| 6 |
+
| ---- | ------- |
|
| 7 |
+
| [`finetune.py`](finetune.py) | LoRA / QLoRA / full fine-tune on chat or instruction data |
|
| 8 |
+
| [`ensemble/`](ensemble/) | JEPA + world-model ensemble experiments (uv package `ensemble`) |
|
| 9 |
+
| [`evals/`](evals/) | SLM agentic benchmark suite: BFCL, τ-bench, GAIA, SWE-bench (uv package `slm-evals`) |
|
| 10 |
+
| [`data/`](data/) | Shared JSONL datasets for finetune and ensemble harnesses |
|
| 11 |
+
|
| 12 |
+
## Quick links
|
| 13 |
+
|
| 14 |
+
- **[USAGE.md](USAGE.md)**: install groups, commands, and typical workflows
|
| 15 |
+
- **[docs/overview.md](docs/overview.md)**: how the pieces fit together
|
| 16 |
+
- **[ensemble/README.md](ensemble/README.md)**: ensemble smoke tests and harnesses
|
| 17 |
+
- **[evals/USAGE.md](evals/USAGE.md)**: benchmark CLI, configs, and results
|
| 18 |
+
- **[evals/docs/benchmarks.md](evals/docs/benchmarks.md)**: what each benchmark measures
|
| 19 |
+
|
| 20 |
+
## Install (from repo root)
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
# Everything you need for research scripts
|
| 24 |
+
uv sync --group finetune --group ensemble --group evals
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
Individual groups:
|
| 28 |
+
|
| 29 |
+
| Group | Command | Enables |
|
| 30 |
+
| ----- | ------- | ------- |
|
| 31 |
+
| `finetune` | `uv sync --group finetune` | `research/finetune.py` (LoRA, QLoRA, merge) |
|
| 32 |
+
| `ensemble` | `uv sync --group ensemble` | `research/ensemble/` package |
|
| 33 |
+
| `evals` | `uv sync --group evals` | `research/evals/` package (`slm-benchmark`) |
|
| 34 |
+
|
| 35 |
+
## Typical workflow
|
| 36 |
+
|
| 37 |
+
```text
|
| 38 |
+
research/data/education-lesson-chat.jsonl
|
| 39 |
+
        │
|
| 40 |
+
        ▼
|
| 41 |
+
research/finetune.py ──► models/finetuned/<preset>-lora/
|
| 42 |
+
        │
|
| 43 |
+
        ├──► research/evals/ (BFCL, τ-bench, GAIA, SWE-bench)
|
| 44 |
+
        │
|
| 45 |
+
        └──► research/ensemble/ (JEPA / world-model ablations)
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
See [USAGE.md](USAGE.md) for copy-paste commands.
|
research/USAGE.md
ADDED
|
@@ -0,0 +1,311 @@
|
| 1 |
+
# Research usage
|
| 2 |
+
|
| 3 |
+
How to run fine-tuning, ensemble experiments, and agentic benchmarks under `research/`. All commands assume the **repo root** as the working directory unless noted.
|
| 4 |
+
|
| 5 |
+
The Lesson Agent app lives in `apps/gradio-space/`; see root [USAGE.md](../USAGE.md). Research code is optional and isolated here.
|
| 6 |
+
|
| 7 |
+
## Prerequisites
|
| 8 |
+
|
| 9 |
+
- [uv](https://docs.astral.sh/uv/) and Python 3.12
|
| 10 |
+
- GPU recommended for real-model runs (CPU works for smoke tests and `tiny` backends)
|
| 11 |
+
- Hugging Face Hub access for model downloads and some benchmark datasets
|
| 12 |
+
|
| 13 |
+
## Install dependency groups
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
# All research tooling
|
| 17 |
+
uv sync --group finetune --group ensemble --group evals --group lm-eval
|
| 18 |
+
|
| 19 |
+
# Or one at a time
|
| 20 |
+
uv sync --group finetune
|
| 21 |
+
uv sync --group ensemble
|
| 22 |
+
uv sync --group evals
|
| 23 |
+
uv sync --group lm-eval
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
| Group | Package / script | What it adds |
|
| 27 |
+
| ----- | ---------------- | ------------ |
|
| 28 |
+
| `finetune` | `research/finetune.py` | `peft`, `datasets`, `bitsandbytes` (QLoRA) |
|
| 29 |
+
| `ensemble` | `ensemble` workspace member | JEPA / world-model ensemble + harnesses |
|
| 30 |
+
| `evals` | `slm-evals` workspace member | `slm-benchmark` CLI |
|
| 31 |
+
| `lm-eval` | `slm-evals[lm-eval]` | `slm-lm-eval` CLI (GSM8K, ARC, HellaSwag, …) |
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## 1. Fine-tuning (`research/finetune.py`)
|
| 36 |
+
|
| 37 |
+
Single script for **full**, **LoRA**, and **QLoRA** training. Defaults to the lesson-agent chat dataset at `research/data/education-lesson-chat.jsonl` and writes checkpoints under `models/finetuned/`.
|
| 38 |
+
|
| 39 |
+
### Model resolution (first match wins)
|
| 40 |
+
|
| 41 |
+
1. `--model <hf-id-or-path>`
|
| 42 |
+
2. `--preset <key>` from root `models.yaml`
|
| 43 |
+
3. Env: `FINETUNE_MODEL`, `MODEL_ID`, or `BASE`
|
| 44 |
+
4. `ACTIVE_MODEL` preset from `.env`
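
The resolution order above can be sketched as a small pure function. This is only an illustration, not the code in `research/finetune.py`: the name `resolve_model`, its parameters, and the `presets` dict (standing in for the parsed `models.yaml`) are assumptions, and the three environment variables in step 3 are checked in the order they are listed, which the text does not guarantee.

```python
import os


def resolve_model(cli_model=None, cli_preset=None, presets=None, env=None):
    """Return a model id or path using the documented first-match-wins order.

    Hypothetical helper: `presets` maps preset keys to dicts with a
    `model_id` entry, mirroring the shape of models.yaml.
    """
    env = os.environ if env is None else env
    presets = presets or {}

    # 1. An explicit --model value wins outright.
    if cli_model:
        return cli_model

    # 2. --preset is looked up in models.yaml.
    if cli_preset:
        return presets[cli_preset]["model_id"]

    # 3. Environment overrides (listed order assumed to be the priority).
    for name in ("FINETUNE_MODEL", "MODEL_ID", "BASE"):
        if env.get(name):
            return env[name]

    # 4. Fall back to the ACTIVE_MODEL preset.
    active = env.get("ACTIVE_MODEL")
    if active:
        return presets[active]["model_id"]

    raise ValueError("no model could be resolved from CLI flags or environment")
```

Passing `env` and `presets` explicitly keeps the function easy to test without touching the real environment or the YAML file.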
|
| 45 |
+
|
| 46 |
+
### Quick start
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
# LoRA on default lesson chat data, 1 epoch
|
| 50 |
+
uv run python research/finetune.py --preset minicpm5-1b --mode lora --epochs 1
|
| 51 |
+
|
| 52 |
+
# Smoke run (50 steps)
|
| 53 |
+
uv run python research/finetune.py --mode lora --max_steps 50
|
| 54 |
+
|
| 55 |
+
# QLoRA on a Hub instruction dataset
|
| 56 |
+
uv run python research/finetune.py \
|
| 57 |
+
--model Qwen/Qwen2.5-0.5B-Instruct \
|
| 58 |
+
--dataset tatsu-lab/alpaca --format alpaca \
|
| 59 |
+
--mode qlora --epochs 1
|
| 60 |
+
|
| 61 |
+
# Merge LoRA adapter into standalone weights
|
| 62 |
+
uv run python research/finetune.py \
|
| 63 |
+
--merge ./models/finetuned/minicpm5-1b-lora \
|
| 64 |
+
--out ./models/finetuned/minicpm5-1b-lora-merged
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### Dataset formats (`--format`)
|
| 68 |
+
|
| 69 |
+
| Format | Expected columns |
|
| 70 |
+
| ------ | ---------------- |
|
| 71 |
+
| `chat` | `messages`: `[{"role": "...", "content": "..."}]` |
|
| 72 |
+
| `alpaca` | `instruction`, optional `input`, `output` |
|
| 73 |
+
| `prompt` | `prompt` / `completion` (or `response`) |
|
| 74 |
+
| `text` | `text`, or a plain `.txt` file |
|
| 75 |
+
|
| 76 |
+
Local files: `.json`, `.jsonl`, `.csv`, `.txt`. Hub ids: any `datasets` repo id.
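
As a rough illustration of what the four formats mean, the sketch below maps one row of each kind onto either a chat-style `messages` list or raw `text`. The function name `normalize_row` and the way `instruction` and `input` are joined are assumptions made for this example; the actual prompt templates in `research/finetune.py` may differ.

```python
def normalize_row(row, fmt):
    """Map one dataset row to {"messages": [...]} or {"text": "..."}.

    Hypothetical helper covering the documented --format values.
    """
    if fmt == "chat":
        # Already a list of {"role", "content"} dicts.
        return {"messages": row["messages"]}

    if fmt == "alpaca":
        user = row["instruction"]
        if row.get("input"):  # `input` is optional
            user = f"{user}\n\n{row['input']}"
        return {"messages": [
            {"role": "user", "content": user},
            {"role": "assistant", "content": row["output"]},
        ]}

    if fmt == "prompt":
        # Accept `completion`, falling back to `response`.
        answer = row["completion"] if "completion" in row else row["response"]
        return {"messages": [
            {"role": "user", "content": row["prompt"]},
            {"role": "assistant", "content": answer},
        ]}

    if fmt == "text":
        return {"text": row["text"]}

    raise ValueError(f"unknown format: {fmt!r}")
```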
|
| 77 |
+
|
| 78 |
+
### Outputs
|
| 79 |
+
|
| 80 |
+
Training writes to `<out>/` (default `./models/finetuned/<preset>-<mode>/`):
|
| 81 |
+
|
| 82 |
+
- Adapter or full weights
|
| 83 |
+
- `training_results.json`: train/eval loss, perplexity, `result_score` (0–100)
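
For orientation, perplexity is conventionally the exponential of the mean per-token cross-entropy loss. The sketch below assumes the script follows that convention with a loss measured in nats; the formula behind `result_score` is not documented here and is not reproduced.

```python
import math


def perplexity(mean_loss):
    """Perplexity for a mean per-token cross-entropy loss in nats."""
    return math.exp(mean_loss)
```

Under this convention a loss of 0 corresponds to a perplexity of 1, and lower is better for both.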
|
| 84 |
+
|
| 85 |
+
### Env vars
|
| 86 |
+
|
| 87 |
+
| Variable | Description |
|
| 88 |
+
| -------- | ----------- |
|
| 89 |
+
| `FINETUNE_PRESET` | Preset key from `models.yaml` |
|
| 90 |
+
| `FINETUNE_DATASET` | Override dataset path or Hub id |
|
| 91 |
+
| `FINETUNE_DATASET_CONFIG` | Hub config name |
|
| 92 |
+
| `FINETUNE_DATASET_SPLIT` | Hub split (e.g. `train[:500]`) |
|
| 93 |
+
| `ACTIVE_MODEL` | Fallback preset when `--preset` omitted |
|
| 94 |
+
|
| 95 |
+
---
|
| 96 |
+
|
| 97 |
+
## 2. Ensemble experiments (`research/ensemble/`)
|
| 98 |
+
|
| 99 |
+
JEPA and world-model ensemble prototypes: small LLM + embedding memory + latent predictors + energy-based draft selection. **Not connected to the Gradio app.**
|
| 100 |
+
|
| 101 |
+
Install: `uv sync --group ensemble`
|
| 102 |
+
|
| 103 |
+
### Tier 1: CPU smoke (no Hub download)
|
| 104 |
+
|
| 105 |
+
```bash
|
| 106 |
+
uv run --package ensemble python -m ensemble.jepa_ensemble tiny
|
| 107 |
+
uv run --package ensemble python -m ensemble.world_ensemble tiny
|
| 108 |
+
bash research/ensemble/scripts/smoke.sh
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
### Tier 2: Real small model
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
uv run --package ensemble python -m ensemble.jepa_ensemble Qwen/Qwen2.5-0.5B-Instruct
|
| 115 |
+
uv run --package ensemble python -m ensemble.world_ensemble Qwen/Qwen2.5-0.5B-Instruct
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
### Pretrain + save (LLM + emb + JEPA)
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
# Default LLM: ENSEMBLE_LLM → LLM_PATH → BASE → MODEL_ID → ACTIVE_MODEL (models.yaml)
|
| 122 |
+
uv run --package ensemble ensemble-pretrain --steps 200
|
| 123 |
+
|
| 124 |
+
# Or override
|
| 125 |
+
uv run --package ensemble ensemble-pretrain \
|
| 126 |
+
--llm Qwen/Qwen2.5-0.5B-Instruct \
|
| 127 |
+
--steps 200
|
| 128 |
+
|
| 129 |
+
# Benchmark saved ensemble with slm-evals (compare to base HF model)
|
| 130 |
+
uv run --package slm-evals slm-benchmark \
|
| 131 |
+
--model ./models/ensemble/jepa-lesson-pretrain \
|
| 132 |
+
--model-type ensemble \
|
| 133 |
+
--benchmarks bfcl tau_bench --max-samples 20
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
Checkpoint files: `manifest.json`, `aux.pt`, `llm/` (PEFT adapters), optional `store.pt`.
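
Before benchmarking a saved ensemble, it can help to confirm the directory has the layout listed above. The check below is a minimal sketch: the function name is hypothetical, and it only tests that the entries exist, not that their contents are valid.

```python
from pathlib import Path

# Entries named in the checkpoint description; store.pt is optional.
REQUIRED_ENTRIES = ("manifest.json", "aux.pt", "llm")
OPTIONAL_ENTRIES = ("store.pt",)


def missing_checkpoint_entries(ckpt_dir):
    """Return the required entries that are absent from ckpt_dir."""
    root = Path(ckpt_dir)
    return [name for name in REQUIRED_ENTRIES if not (root / name).exists()]
```

An empty list means every required entry is present, for example `missing_checkpoint_entries("./models/ensemble/jepa-lesson-pretrain")`.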
|
| 137 |
+
|
| 138 |
+
### Tier 3: Benchmark harnesses
|
| 139 |
+
|
| 140 |
+
Uses `research/data/benchmark-qa.jsonl` (questions) and `benchmark-kb.jsonl` (retrieval snippets).
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
# JEPA track: toy
|
| 144 |
+
uv run --package ensemble python -m ensemble.eval.jepa_harness \
|
| 145 |
+
--llm tiny --toy --limit 20 --n_drafts 8
|
| 146 |
+
|
| 147 |
+
# JEPA track: education QA
|
| 148 |
+
uv run --package ensemble python -m ensemble.eval.jepa_harness \
|
| 149 |
+
--llm Qwen/Qwen2.5-0.5B-Instruct \
|
| 150 |
+
--qa research/data/benchmark-qa.jsonl \
|
| 151 |
+
--kb research/data/benchmark-kb.jsonl \
|
| 152 |
+
--limit 50 --n_drafts 8
|
| 153 |
+
|
| 154 |
+
# World-model track
|
| 155 |
+
uv run --package ensemble python -m ensemble.eval.world_harness \
|
| 156 |
+
--llm tiny --toy --limit 20 --n_drafts 8
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
More detail: [ensemble/README.md](ensemble/README.md), [docs/overview.md](docs/overview.md).
|
| 160 |
+
|
| 161 |
+
### Legacy shims
|
| 162 |
+
|
| 163 |
+
Top-level files re-export the package for old scripts:
|
| 164 |
+
|
| 165 |
+
- `research/llm_emb_jepa_ensemble_pluggable.py` → `ensemble.jepa_ensemble`
|
| 166 |
+
- `research/world_model_ensemble.py` → `ensemble.world_ensemble`
|
| 167 |
+
- `research/eval_harness.py` → `ensemble.eval.jepa_harness`
|
| 168 |
+
|
| 169 |
+
Prefer `uv run --package ensemble python -m ensemble.<module>`.
|
| 170 |
+
|
| 171 |
+
---
|
| 172 |
+
|
| 173 |
+
## 3. Agentic benchmarks (`research/evals/`)
|
| 174 |
+
|
| 175 |
+
Evaluate local Hugging Face checkpoints on BFCL, τ-bench, GAIA, and SWE-bench Verified.
|
| 176 |
+
|
| 177 |
+
Install: `uv sync --group evals`
|
| 178 |
+
|
| 179 |
+
```bash
|
| 180 |
+
# Smoke test (20 samples, two benchmarks)
|
| 181 |
+
uv run --package slm-evals slm-benchmark \
|
| 182 |
+
--model openbmb/MiniCPM5-1B \
|
| 183 |
+
--benchmarks bfcl tau_bench \
|
| 184 |
+
--max-samples 20
|
| 185 |
+
|
| 186 |
+
# Full config-driven run
|
| 187 |
+
uv run --package slm-evals slm-benchmark \
|
| 188 |
+
--config research/evals/configs/experiment_001.yaml
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
Full reference: [evals/USAGE.md](evals/USAGE.md).
|
| 192 |
+
|
| 193 |
+
---
|
| 194 |
+
|
| 195 |
+
## 4. Academic benchmarks (`slm-lm-eval`)
|
| 196 |
+
|
| 197 |
+
Standard lm-evaluation-harness tasks (ARC, HellaSwag, GSM8K, …) for base presets, LoRA adapters, merged checkpoints, and ensemble manifests.
|
| 198 |
+
|
| 199 |
+
Install: `uv sync --group lm-eval`
|
| 200 |
+
|
| 201 |
+
Profile guide: [evals/docs/eval_profiles.md](evals/docs/eval_profiles.md)
|
| 202 |
+
|
| 203 |
+
```bash
|
| 204 |
+
# List claim-matched profiles (reasoning, code, understanding, …)
|
| 205 |
+
uv run --package slm-evals slm-lm-eval --list-profiles
|
| 206 |
+
|
| 207 |
+
# Run by profile name
|
| 208 |
+
uv run --package slm-evals slm-lm-eval \
|
| 209 |
+
--profile reasoning \
|
| 210 |
+
--preset minicpm5-1b \
|
| 211 |
+
--experiment-name minicpm5-1b__reasoning-baseline
|
| 212 |
+
|
| 213 |
+
# Smoke (25 samples, arc_easy + hellaswag)
|
| 214 |
+
uv run --package slm-evals slm-lm-eval \
|
| 215 |
+
--profile smoke \
|
| 216 |
+
--preset minicpm5-1b \
|
| 217 |
+
--experiment-name minicpm5-1b__smoke
|
| 218 |
+
|
| 219 |
+
# Full profile
|
| 220 |
+
uv run --package slm-evals slm-lm-eval \
|
| 221 |
+
--config research/evals/configs/lm_eval_minicpm5.yaml \
|
| 222 |
+
--preset minicpm5-1b-lesson-lora \
|
| 223 |
+
--experiment-name minicpm5-1b-lora__v1 \
|
| 224 |
+
--compare-to results/lm_eval/minicpm5-1b__baseline/results.json
|
| 225 |
+
|
| 226 |
+
# Ensemble checkpoint
|
| 227 |
+
uv run --package slm-evals slm-lm-eval \
|
| 228 |
+
--config research/evals/configs/lm_eval_smoke.yaml \
|
| 229 |
+
--model ./models/ensemble/jepa-lesson-pretrain \
|
| 230 |
+
--experiment-name ensemble-jepa__lm-eval
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
Post-training hook:
|
| 234 |
+
|
| 235 |
+
```bash
|
| 236 |
+
uv run python research/finetune.py \
|
| 237 |
+
--preset minicpm5-1b --mode lora --max_steps 50 \
|
| 238 |
+
--lm-eval-after \
|
| 239 |
+
--lm-eval-baseline minicpm5-1b
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
Full reference: [evals/USAGE.md](evals/USAGE.md#lm-evaluation-harness-slm-lm-eval).
|
| 243 |
+
|
| 244 |
+
---
|
| 245 |
+
|
| 246 |
+
## Shared data (`research/data/`)
|
| 247 |
+
|
| 248 |
+
| File | Used by | Format |
|
| 249 |
+
| ---- | ------- | ------ |
|
| 250 |
+
| `education-lesson-chat.jsonl` | `finetune.py` default | Chat messages for lesson agent |
|
| 251 |
+
| `benchmark-qa.jsonl` | Ensemble harnesses | `question`, `answer`, `domain` |
|
| 252 |
+
| `benchmark-kb.jsonl` | Ensemble harnesses | Retrieval snippets for memory routing |
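
The files in the table are plain JSON Lines, so a small loader is enough to inspect them or to catch a malformed row before a harness run. This is an illustrative helper, not the loader used by the `ensemble` package; the name `load_jsonl` and its `required` parameter are assumptions.

```python
import json


def load_jsonl(path, required=()):
    """Read a JSON Lines file and check that each row has the required keys."""
    rows = []
    with open(path, encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            row = json.loads(line)
            missing = [key for key in required if key not in row]
            if missing:
                raise ValueError(f"{path}:{lineno} is missing keys {missing}")
            rows.append(row)
    return rows
```

For example, `load_jsonl("research/data/benchmark-qa.jsonl", required=("question", "answer", "domain"))` checks the QA columns listed above.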
|
| 253 |
+
|
| 254 |
+
---
|
| 255 |
+
|
| 256 |
+
## Suggested end-to-end pipeline
|
| 257 |
+
|
| 258 |
+
1. **Baseline lm-eval** (academic benchmarks on the base preset, pinned seed):
|
| 259 |
+
```bash
|
| 260 |
+
uv run --package slm-evals slm-lm-eval \
|
| 261 |
+
--config research/evals/configs/lm_eval_compare_study.yaml \
|
| 262 |
+
--preset minicpm5-1b \
|
| 263 |
+
--experiment-name minicpm5-1b__baseline
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
2. **Baseline agentic eval** (optional):
|
| 267 |
+
```bash
|
| 268 |
+
uv run --package slm-evals slm-benchmark \
|
| 269 |
+
--model openbmb/MiniCPM5-1B --benchmarks bfcl --max-samples 50
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
3. **Fine-tune** on lesson data:
|
| 273 |
+
```bash
|
| 274 |
+
uv run python research/finetune.py --preset minicpm5-1b --mode lora --epochs 1
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
4. **Re-eval candidate** with the same lm-eval config:
|
| 278 |
+
```bash
|
| 279 |
+
uv run --package slm-evals slm-lm-eval \
|
| 280 |
+
--config research/evals/configs/lm_eval_compare_study.yaml \
|
| 281 |
+
--preset minicpm5-1b-lesson-lora \
|
| 282 |
+
--experiment-name minicpm5-1b-lora__v1 \
|
| 283 |
+
--compare-to results/lm_eval/minicpm5-1b__baseline/results.json
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
5. **Probe ensemble ideas** (optional) on the same QA/KB files:
|
| 287 |
+
```bash
|
| 288 |
+
bash research/ensemble/scripts/smoke.sh
|
| 289 |
+
```
|
| 290 |
+
|
| 291 |
+
### Verification checklist
|
| 292 |
+
|
| 293 |
+
- Use the **same** lm-eval YAML (`tasks`, `num_fewshot`, `limit`, `seed`) for baseline and candidate runs.
|
| 294 |
+
- Compare lm-eval `results.json` files with `--compare-to`; do not compare `training_results.json` `result_score` to lm-eval accuracy.
|
| 295 |
+
- For LoRA checkpoints, prefer `--preset minicpm5-1b-lesson-lora` (base + adapter) over passing the adapter dir alone to `--model`.
|
| 296 |
+
- Report mean ± std only after multiple training seeds; single-seed deltas are indicative, not conclusive.
|
| 297 |
+
- Ensemble `loglikelihood` tasks score the underlying LLM head; generative tasks (`gsm8k`) use the full JEPA+RAG stack.
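
The multi-seed point in the checklist can be made concrete with a few lines of standard-library Python. The helper below is a sketch with a hypothetical name; it takes one score per training seed (how those scores are read from each run's `results.json` is left out) and refuses to report a spread from a single run.

```python
from statistics import mean, stdev


def summarize_seeds(scores):
    """Return (mean, sample standard deviation) across per-seed scores."""
    if len(scores) < 2:
        raise ValueError("need scores from at least two seeds to report a spread")
    return mean(scores), stdev(scores)
```

`stdev` is the sample standard deviation (n - 1 in the denominator), which is the usual choice for a handful of seeds.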
|
| 298 |
+
|
| 299 |
+
---
|
| 300 |
+
|
| 301 |
+
## Troubleshooting
|
| 302 |
+
|
| 303 |
+
| Symptom | Fix |
|
| 304 |
+
| ------- | --- |
|
| 305 |
+
| `No module named 'ensemble'` | `uv sync --group ensemble` |
|
| 306 |
+
| `slm-benchmark: command not found` | `uv sync --group evals` |
|
| 307 |
+
| `slm-lm-eval: command not found` | `uv sync --group lm-eval` |
|
| 308 |
+
| CUDA OOM during finetune | Use `--mode qlora` or reduce batch size in script args |
|
| 309 |
+
| BFCL / GAIA download slow | Set `max_samples` low first; cache HF datasets under `~/.cache/huggingface` |
|
| 310 |
+
| SWE-bench Docker errors | Keep `full_eval: false` in YAML unless `swebench` + Docker are installed |
|
| 311 |
+
| τ-bench API costs | Keep `use_llm_user: false` (rule-based user simulator) |
|
research/data/benchmark-kb.jsonl
ADDED
|
@@ -0,0 +1,10 @@
|
| 1 |
+
{"text": "Photosynthesis: plants use sunlight, water, and CO2 to make glucose and release oxygen."}
|
| 2 |
+
{"text": "Fractions: the top number counts parts and the bottom number shows total equal parts."}
|
| 3 |
+
{"text": "Water cycle: evaporation turns liquid water into vapor; condensation forms clouds."}
|
| 4 |
+
{"text": "Precipitation returns water to land and bodies of water as rain, snow, sleet, or hail."}
|
| 5 |
+
{"text": "AI agents observe state, decide actions, and update plans using tools and memory."}
|
| 6 |
+
{"text": "American Revolution: colonists protested taxation without representation and trade restrictions."}
|
| 7 |
+
{"text": "Boston Tea Party: colonists dumped tea to protest the Tea Act in 1773."}
|
| 8 |
+
{"text": "Enlightenment ideas about rights and consent spread through pamphlets and speeches."}
|
| 9 |
+
{"text": "One-half means one of two equal parts; one-fourth means one of four equal parts."}
|
| 10 |
+
{"text": "Runoff carries water from land back toward rivers, lakes, and oceans."}
|
research/data/benchmark-qa.jsonl
ADDED
|
@@ -0,0 +1,10 @@
|
| 1 |
+
{"question": "What is photosynthesis?", "answer": "Plants make food using sunlight, water, and carbon dioxide.", "domain": "science"}
|
| 2 |
+
{"question": "What do plants produce during photosynthesis?", "answer": "Glucose and oxygen.", "domain": "science"}
|
| 3 |
+
{"question": "What is a fraction?", "answer": "A fraction shows equal parts of a whole.", "domain": "math"}
|
| 4 |
+
{"question": "What does one-half mean?", "answer": "One of two equal parts.", "domain": "math"}
|
| 5 |
+
{"question": "What is evaporation in the water cycle?", "answer": "Heat turns liquid water into water vapor.", "domain": "science"}
|
| 6 |
+
{"question": "What is precipitation?", "answer": "Water falls as rain, snow, sleet, or hail.", "domain": "science"}
|
| 7 |
+
{"question": "What is an AI agent?", "answer": "Software that takes goals and acts to reach them.", "domain": "technology"}
|
| 8 |
+
{"question": "What is a colonial grievance before the American Revolution?", "answer": "Taxation without representation.", "domain": "history"}
|
| 9 |
+
{"question": "What happened at the Boston Tea Party?", "answer": "Colonists protested the Tea Act.", "domain": "history"}
|
| 10 |
+
{"question": "Why does photosynthesis matter?", "answer": "It feeds nearly all life on Earth and keeps the atmosphere balanced.", "domain": "science"}
|
research/data/education-lesson-chat.jsonl
ADDED
|
@@ -0,0 +1,8 @@
|
| 1 |
+
{"messages": [{"role": "system", "content": "You are a lesson-planning assistant for teachers.\nFollow the skill workflow below and output ONLY valid JSON (no markdown fences).\n\nSkill workflow:\n1. Ask for topic, audience grade, and slide count (3–8 content slides).\n2. Produce a JSON outline with `title` and `slides` (each slide has `title`, `bullets`, `speaker_note`).\n3. Call `create_pptx` with the validated outline.\n4. Return a download link and markdown preview for the teacher.\n\nJSON schema:\n{\n \"title\": \"string – presentation title\",\n \"slides\": [\n {\n \"title\": \"string – slide heading\",\n \"bullets\": [\"string\", \"...\"],\n \"speaker_note\": \"string – one sentence for the teacher\"\n }\n ]\n}\n\nRules:\n- Use exactly the requested number of content slides (title slide is added separately by the tool).\n- Bullets should be short, age-appropriate, and factual.\n- speaker_note is optional but helpful for each slide."}, {"role": "user", "content": "Topic: Photosynthesis\nGrade level: 6\nNumber of content slides: 4\nReturn JSON only."}, {"role": "assistant", "content": "{\"title\": \"Photosynthesis for 6th Grade\", \"slides\": [{\"title\": \"What Is Photosynthesis?\", \"bullets\": [\"Plants make their own food using sunlight\", \"The process happens mainly in leaves\"], \"speaker_note\": \"Ask students what they already know about how plants eat.\"}, {\"title\": \"Ingredients Plants Need\", \"bullets\": [\"Sunlight provides energy\", \"Water comes from roots and carbon dioxide from air\"], \"speaker_note\": \"Draw a simple leaf diagram while you explain inputs.\"}, {\"title\": \"What Plants Produce\", \"bullets\": [\"Glucose is the sugar plants use for energy\", \"Oxygen is released into the air\"], \"speaker_note\": \"Connect oxygen release to why forests matter for breathing.\"}, {\"title\": \"Why It Matters\", \"bullets\": [\"Photosynthesis feeds nearly all life on Earth\", \"It helps keep our atmosphere balanced\"], \"speaker_note\": \"End with a quick pair-share: one reason photosynthesis matters.\"}]}"}]}
|
| 2 |
+
{"messages": [{"role": "system", "content": "You are a lesson-planning assistant for teachers.\nFollow the skill workflow below and output ONLY valid JSON (no markdown fences).\n\nSkill workflow:\n1. Ask for topic, audience grade, and slide count (3–8 content slides).\n2. Produce a JSON outline with `title` and `slides` (each slide has `title`, `bullets`, `speaker_note`).\n3. Call `create_pptx` with the validated outline.\n4. Return a download link and markdown preview for the teacher.\n\nJSON schema:\n{\n \"title\": \"string – presentation title\",\n \"slides\": [\n {\n \"title\": \"string – slide heading\",\n \"bullets\": [\"string\", \"...\"],\n \"speaker_note\": \"string – one sentence for the teacher\"\n }\n ]\n}\n\nRules:\n- Use exactly the requested number of content slides (title slide is added separately by the tool).\n- Bullets should be short, age-appropriate, and factual.\n- speaker_note is optional but helpful for each slide."}, {"role": "user", "content": "Topic: Fractions\nGrade level: 4\nNumber of content slides: 3\nReturn JSON only."}, {"role": "assistant", "content": "{\"title\": \"Understanding Fractions\", \"slides\": [{\"title\": \"Parts of a Whole\", \"bullets\": [\"A fraction shows equal parts of a whole\", \"The top number counts parts; the bottom shows total equal parts\"], \"speaker_note\": \"Use a paper pizza cut into equal slices as a visual.\"}, {\"title\": \"Reading Common Fractions\", \"bullets\": [\"One-half means one of two equal parts\", \"One-fourth means one of four equal parts\"], \"speaker_note\": \"Have students name fractions they see in the room.\"}, {\"title\": \"Comparing Fractions\", \"bullets\": [\"Larger denominators can mean smaller pieces\", \"Use drawings to compare two simple fractions\"], \"speaker_note\": \"Give two fraction cards and ask which shows more pizza.\"}]}"}]}
|
| 3 |
+
{"messages": [{"role": "system", "content": "You are a lesson-planning assistant for teachers.\nFollow the skill workflow below and output ONLY valid JSON (no markdown fences).\n\nSkill workflow:\n1. Ask for topic, audience grade, and slide count (3–8 content slides).\n2. Produce a JSON outline with `title` and `slides` (each slide has `title`, `bullets`, `speaker_note`).\n3. Call `create_pptx` with the validated outline.\n4. Return a download link and markdown preview for the teacher.\n\nJSON schema:\n{\n \"title\": \"string – presentation title\",\n \"slides\": [\n {\n \"title\": \"string – slide heading\",\n \"bullets\": [\"string\", \"...\"],\n \"speaker_note\": \"string – one sentence for the teacher\"\n }\n ]\n}\n\nRules:\n- Use exactly the requested number of content slides (title slide is added separately by the tool).\n- Bullets should be short, age-appropriate, and factual.\n- speaker_note is optional but helpful for each slide."}, {"role": "user", "content": "Topic: The water cycle\nGrade level: 5\nNumber of content slides: 5\nReturn JSON only."}, {"role": "assistant", "content": "{\"title\": \"The Water Cycle\", \"slides\": [{\"title\": \"Water Moves Constantly\", \"bullets\": [\"Water travels between Earth, air, and living things\", \"The same water has been recycled for millions of years\"], \"speaker_note\": \"Show a simple cycle diagram before diving into each stage.\"}, {\"title\": \"Evaporation\", \"bullets\": [\"Heat turns liquid water into water vapor\", \"Oceans, lakes, and puddles lose water to the air\"], \"speaker_note\": \"Ask what happens to a wet sidewalk on a sunny day.\"}, {\"title\": \"Condensation\", \"bullets\": [\"Water vapor cools and forms tiny droplets\", \"Clouds form when droplets gather in the sky\"], \"speaker_note\": \"Hold a cold cup and point out the outside moisture.\"}, {\"title\": \"Precipitation\", \"bullets\": [\"Droplets combine and fall as rain, snow, sleet, or hail\", \"Precipitation returns water to land and water bodies\"], \"speaker_note\": \"Let students name types of precipitation they have seen.\"}, {\"title\": \"Collection and Runoff\", \"bullets\": [\"Water collects in rivers, lakes, and groundwater\", \"Runoff carries water back toward oceans\"], \"speaker_note\": \"Trace a raindrop's path on your diagram to close the loop.\"}]}"}]}
|
| 4 |
+
{"messages": [{"role": "system", "content": "You are a lesson-planning assistant for teachers.\nFollow the skill workflow below and output ONLY valid JSON (no markdown fences).\n\nSkill workflow:\n1. Ask for topic, audience grade, and slide count (3–8 content slides).\n2. Produce a JSON outline with `title` and `slides` (each slide has `title`, `bullets`, `speaker_note`).\n3. Call `create_pptx` with the validated outline.\n4. Return a download link and markdown preview for the teacher.\n\nJSON schema:\n{\n \"title\": \"string – presentation title\",\n \"slides\": [\n {\n \"title\": \"string – slide heading\",\n \"bullets\": [\"string\", \"...\"],\n \"speaker_note\": \"string – one sentence for the teacher\"\n }\n ]\n}\n\nRules:\n- Use exactly the requested number of content slides (title slide is added separately by the tool).\n- Bullets should be short, age-appropriate, and factual.\n- speaker_note is optional but helpful for each slide."}, {"role": "user", "content": "Topic: AI agents\nGrade level: 8\nNumber of content slides: 4\nReturn JSON only."}, {"role": "assistant", "content": "{\"title\": \"Introduction to AI Agents\", \"slides\": [{\"title\": \"What Is an AI Agent?\", \"bullets\": [\"An AI agent is software that takes goals and acts to reach them\", \"Agents can use tools, memory, and feedback loops\"], \"speaker_note\": \"Contrast a chatbot that only replies with an agent that plans steps.\"}, {\"title\": \"How Agents Work\", \"bullets\": [\"They observe input or environment state\", \"They decide actions and update their plan\"], \"speaker_note\": \"Walk through a simple example like booking a meeting.\"}, {\"title\": \"Everyday Examples\", \"bullets\": [\"Navigation apps reroute based on traffic\", \"Email filters can sort and draft responses\"], \"speaker_note\": \"Invite students to suggest other agent-like tools they use.\"}, {\"title\": \"Benefits and Risks\", \"bullets\": [\"Agents can save time on repetitive tasks\", \"They need clear rules and human oversight\"], \"speaker_note\": \"Close with a short discussion on when automation helps or hurts.\"}]}"}]}
|
| 5 |
+
{"messages": [{"role": "system", "content": "You are a lesson-planning assistant for teachers.\nFollow the skill workflow below and output ONLY valid JSON (no markdown fences).\n\nSkill workflow:\n1. Ask for topic, audience grade, and slide count (3–8 content slides).\n2. Produce a JSON outline with `title` and `slides` (each slide has `title`, `bullets`, `speaker_note`).\n3. Call `create_pptx` with the validated outline.\n4. Return a download link and markdown preview for the teacher.\n\nJSON schema:\n{\n \"title\": \"string – presentation title\",\n \"slides\": [\n {\n \"title\": \"string – slide heading\",\n \"bullets\": [\"string\", \"...\"],\n \"speaker_note\": \"string – one sentence for the teacher\"\n }\n ]\n}\n\nRules:\n- Use exactly the requested number of content slides (title slide is added separately by the tool).\n- Bullets should be short, age-appropriate, and factual.\n- speaker_note is optional but helpful for each slide."}, {"role": "user", "content": "Topic: The American Revolution\nGrade level: 8\nNumber of content slides: 5\nReturn JSON only."}, {"role": "assistant", "content": "{\"title\": \"Causes of the American Revolution\", \"slides\": [{\"title\": \"Colonial Grievances\", \"bullets\": [\"Colonists protested taxation without representation\", \"Trade restrictions limited colonial economies\"], \"speaker_note\": \"Connect taxes on everyday goods to student purchases today.\"}, {\"title\": \"Key Events\", \"bullets\": [\"The Boston Tea Party protested the Tea Act\", \"The Intolerable Acts tightened British control\"], \"speaker_note\": \"Place events on a timeline as you introduce each one.\"}, {\"title\": \"Ideas That Spread\", \"bullets\": [\"Enlightenment ideas emphasized rights and consent\", \"Pamphlets and speeches spread revolutionary arguments\"], \"speaker_note\": \"Read a short excerpt from a primary source if time allows.\"}, {\"title\": \"Leaders and Groups\", \"bullets\": [\"Committees and congresses organized resistance\", \"Leaders like Adams and Washington gained prominence\"], \"speaker_note\": \"Highlight diverse roles beyond famous military figures.\"}, {\"title\": \"Road to Independence\", \"bullets\": [\"The Declaration explained the break from Britain\", \"War followed when compromise failed\"], \"speaker_note\": \"End by asking what principle students would defend most strongly.\"}]}"}]}
|
| 6 |
+
{"messages": [{"role": "system", "content": "You are a lesson-planning assistant for teachers.\nFollow the skill workflow below and output ONLY valid JSON (no markdown fences).\n\nSkill workflow:\n1. Ask for topic, audience grade, and slide count (3β8 content slides).\n2. Produce a JSON outline with `title` and `slides` (each slide has `title`, `bullets`, `speaker_note`).\n3. Call `create_pptx` with the validated outline.\n4. Return a download link and markdown preview for the teacher.\n\nJSON schema:\n{\n \"title\": \"string β presentation title\",\n \"slides\": [\n {\n \"title\": \"string β slide heading\",\n \"bullets\": [\"string\", \"...\"],\n \"speaker_note\": \"string β one sentence for the teacher\"\n }\n ]\n}\n\nRules:\n- Use exactly the requested number of content slides (title slide is added separately by the tool).\n- Bullets should be short, age-appropriate, and factual.\n- speaker_note is optional but helpful for each slide."}, {"role": "user", "content": "Topic: Kindness and friendship\nGrade level: 2\nNumber of content slides: 3\nReturn JSON only."}, {"role": "assistant", "content": "{\"title\": \"Being a Good Friend\", \"slides\": [{\"title\": \"What Friends Do\", \"bullets\": [\"Friends listen and share\", \"Friends help when someone feels sad\"], \"speaker_note\": \"Ask students to share one kind thing a friend did for them.\"}, {\"title\": \"Using Kind Words\", \"bullets\": [\"Please and thank you show respect\", \"Compliments can make others feel happy\"], \"speaker_note\": \"Practice turning an unkind sentence into a kind one together.\"}, {\"title\": \"Solving Small Problems\", \"bullets\": [\"Take turns when you disagree\", \"Ask a grown-up if you need help\"], \"speaker_note\": \"Role-play a simple sharing conflict and a peaceful solution.\"}]}"}]}
|
| 7 |
+
{"messages": [{"role": "system", "content": "You are a lesson-planning assistant for teachers.\nFollow the skill workflow below and output ONLY valid JSON (no markdown fences).\n\nSkill workflow:\n1. Ask for topic, audience grade, and slide count (3β8 content slides).\n2. Produce a JSON outline with `title` and `slides` (each slide has `title`, `bullets`, `speaker_note`).\n3. Call `create_pptx` with the validated outline.\n4. Return a download link and markdown preview for the teacher.\n\nJSON schema:\n{\n \"title\": \"string β presentation title\",\n \"slides\": [\n {\n \"title\": \"string β slide heading\",\n \"bullets\": [\"string\", \"...\"],\n \"speaker_note\": \"string β one sentence for the teacher\"\n }\n ]\n}\n\nRules:\n- Use exactly the requested number of content slides (title slide is added separately by the tool).\n- Bullets should be short, age-appropriate, and factual.\n- speaker_note is optional but helpful for each slide."}, {"role": "user", "content": "Topic: Linear equations\nGrade level: 9\nNumber of content slides: 4\nReturn JSON only."}, {"role": "assistant", "content": "{\"title\": \"Introduction to Linear Equations\", \"slides\": [{\"title\": \"What Makes an Equation Linear?\", \"bullets\": [\"A linear equation graphs as a straight line\", \"The highest power of the variable is one\"], \"speaker_note\": \"Show y = 2x + 1 on a coordinate plane.\"}, {\"title\": \"Slope-Intercept Form\", \"bullets\": [\"y = mx + b uses slope m and y-intercept b\", \"m tells how steep the line is\"], \"speaker_note\": \"Have students identify m and b from one example.\"}, {\"title\": \"Graphing from an Equation\", \"bullets\": [\"Plot the y-intercept first\", \"Use the slope to find another point\"], \"speaker_note\": \"Work one example slowly on the board.\"}, {\"title\": \"Real-World Connections\", \"bullets\": [\"Phone plans with a flat fee plus per-minute cost are linear\", \"Distance at constant speed changes linearly over time\"], \"speaker_note\": \"Ask 
students to invent their own linear situation.\"}]}"}]}
|
| 8 |
+
{"messages": [{"role": "system", "content": "You are a lesson-planning assistant for teachers.\nFollow the skill workflow below and output ONLY valid JSON (no markdown fences).\n\nSkill workflow:\n1. Ask for topic, audience grade, and slide count (3β8 content slides).\n2. Produce a JSON outline with `title` and `slides` (each slide has `title`, `bullets`, `speaker_note`).\n3. Call `create_pptx` with the validated outline.\n4. Return a download link and markdown preview for the teacher.\n\nJSON schema:\n{\n \"title\": \"string β presentation title\",\n \"slides\": [\n {\n \"title\": \"string β slide heading\",\n \"bullets\": [\"string\", \"...\"],\n \"speaker_note\": \"string β one sentence for the teacher\"\n }\n ]\n}\n\nRules:\n- Use exactly the requested number of content slides (title slide is added separately by the tool).\n- Bullets should be short, age-appropriate, and factual.\n- speaker_note is optional but helpful for each slide."}, {"role": "user", "content": "Topic: The solar system\nGrade level: 3\nNumber of content slides: 4\nReturn JSON only."}, {"role": "assistant", "content": "{\"title\": \"Our Solar System\", \"slides\": [{\"title\": \"The Sun at the Center\", \"bullets\": [\"The Sun is a star that gives Earth light and heat\", \"Planets orbit, or travel around, the Sun\"], \"speaker_note\": \"Use a lamp and balls to model orbits if you have them.\"}, {\"title\": \"Inner Planets\", \"bullets\": [\"Mercury, Venus, Earth, and Mars are rocky planets\", \"Earth is the only one known to have life\"], \"speaker_note\": \"Compare sizes with classroom objects.\"}, {\"title\": \"Outer Planets\", \"bullets\": [\"Jupiter and Saturn are gas giants\", \"Uranus and Neptune are icy giants farther from the Sun\"], \"speaker_note\": \"Mention that Jupiter is the largest planet.\"}, {\"title\": \"Other Objects\", \"bullets\": [\"Asteroids and comets also orbit the Sun\", \"The Moon orbits Earth, not the Sun directly\"], \"speaker_note\": \"Clarify that moons are not 
planets.\"}]}"}]}
|
research/docs/overview.md
ADDED
|
@@ -0,0 +1,114 @@
# Research overview

How `research/` relates to the main hackathon repo and what each component does.

## Position in the repo

```text
small-model-hackathon/
├── apps/gradio-space/     ← shipped Lesson Agent UI
├── libs/agent/            ← skill loop, tools, traces
├── libs/inference/        ← transformers + llama.cpp backends
├── models.yaml            ← model presets (shared with finetune)
└── research/              ← experiments (this tree)
    ├── finetune.py
    ├── data/
    ├── ensemble/          ← uv workspace package
    └── evals/             ← uv workspace package
```

Research code is a **uv workspace sibling** of `apps/*` and `libs/*`. The root `pyproject.toml` declares optional dependency groups (`finetune`, `ensemble`, `evals`), so the Docker Space image does not need to install torch-heavy extras unless you opt in locally.

## Three tracks

### Fine-tuning

`research/finetune.py` adapts a small HF causal LM on instruction or chat data. It reuses root `models.yaml` presets and the shared inference config loader, so the same `minicpm5-1b` preset used in the Gradio app can be fine-tuned without duplicating model metadata.

Outputs land in `models/finetuned/`. You can register a new preset in `models.yaml` pointing at merged weights for the **Well-Tuned** hackathon badge.

### Ensemble (JEPA / world model)

`research/ensemble/` explores a modular stack inspired by LeCun-style architectures:

```text
Input ──► Embedder + VectorStore (retrieval memory)
              │
              ▼
         JEPA encoder ──► latent state
              │
              ├──► World model (multi-step latent rollout)
              │
              └──► Energy model (scores LLM draft continuations)
                        │
                        ▼
              Small LLM generates N drafts → pick lowest energy
```
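
The last step of the pipeline, generating N drafts and keeping the one with the lowest energy, can be sketched in a few lines. This is a minimal illustration and not the package's implementation: `select_lowest_energy` and the toy `length_energy` scorer are hypothetical names, and a real energy model would score latent states rather than string lengths.

```python
from typing import Callable, Sequence


def select_lowest_energy(
    drafts: Sequence[str], energy_fn: Callable[[str], float]
) -> tuple[int, str, float]:
    """Return (index, draft, energy) for the draft with the lowest energy."""
    if not drafts:
        raise ValueError("need at least one draft")
    energies = [energy_fn(d) for d in drafts]
    # min() keeps the first index on ties, so selection is deterministic.
    best = min(range(len(drafts)), key=energies.__getitem__)
    return best, drafts[best], energies[best]


def length_energy(draft: str, target_len: int = 20) -> float:
    """Toy stand-in for an energy model: distance from a target length."""
    return float(abs(len(draft) - target_len))


drafts = [
    "short",
    "a draft of about twenty",
    "a much longer draft that rambles on and on",
]
index, choice, energy = select_lowest_energy(drafts, length_energy)
print(index, choice, energy)
```

Swapping `length_energy` for a learned critic leaves the selection logic unchanged, which is the point of keeping the scorer pluggable.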

Two ensemble entry points:

| Track | Module | Critic |
| ----- | ------ | ------ |
| JEPA track | `ensemble.jepa_ensemble` | JEPA latent prediction |
| World track | `ensemble.world_ensemble` | Energy model over world-model rollouts |

`TinyBackend` runs on CPU with random weights for smoke tests. `HFBackend` loads real Hub models via `transformers` plus optional `peft` LoRA banks.

Eval harnesses (`ensemble.eval.jepa_harness`, `ensemble.eval.world_harness`) measure draft-selection accuracy on `research/data/benchmark-qa.jsonl`, with optional KB retrieval from `benchmark-kb.jsonl`.
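
One way to read "draft-selection accuracy" is as the fraction of questions where the critic's chosen draft matches the reference answer, compared against an oracle that counts a question as solved if any draft matches. The sketch below assumes exact match after whitespace and case normalization. The `Example` record and the `selection_accuracy` helper are hypothetical and are not taken from `ensemble.eval.metrics`.

```python
from dataclasses import dataclass


@dataclass
class Example:
    reference: str       # gold answer for the question
    drafts: list[str]    # N candidate answers from the LLM
    selected: int        # index of the draft the critic picked


def _norm(text: str) -> str:
    """Lowercase and collapse whitespace before comparing answers."""
    return " ".join(text.lower().split())


def selection_accuracy(examples: list[Example]) -> dict[str, float]:
    """Compare the critic's picks with an any-draft-correct oracle."""
    if not examples:
        return {"selected": 0.0, "oracle": 0.0}
    hits = sum(_norm(e.drafts[e.selected]) == _norm(e.reference) for e in examples)
    oracle = sum(
        any(_norm(d) == _norm(e.reference) for d in e.drafts) for e in examples
    )
    n = len(examples)
    return {"selected": hits / n, "oracle": oracle / n}


examples = [
    Example("Paris", ["paris", "Lyon"], selected=0),
    Example("4", ["5", "4"], selected=0),
]
scores = selection_accuracy(examples)
print(scores)
```

The gap between the two numbers shows how much of the error comes from the selector rather than from the drafts themselves.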

### Agentic evals

`research/evals/` (`slm-evals` package) scores **whole models** on public agent benchmarks: function calling, multi-turn tool use, GAIA tasks, and SWE-bench patches. This complements the ensemble harnesses: evals test end-to-end model behavior, while the ensemble harnesses test internal selection mechanisms on a small custom QA set.

## Data flow

```mermaid
flowchart LR
  subgraph data [research/data]
    lesson[education-lesson-chat.jsonl]
    qa[benchmark-qa.jsonl]
    kb[benchmark-kb.jsonl]
  end

  subgraph train [finetune.py]
    ckpt[models/finetuned/]
  end

  subgraph evals [slm-evals]
    bfcl[BFCL]
    tau[tau-bench]
    gaia[GAIA]
    swe[SWE-bench]
  end

  subgraph ens [ensemble]
    jepa[JEPA harness]
    world[World harness]
  end

  lesson --> train
  train --> ckpt
  ckpt --> evals
  qa --> jepa
  kb --> jepa
  qa --> world
  kb --> world
```

## When to use which tool

| Goal | Tool |
| ---- | ---- |
| Improve lesson slide quality on your data | `finetune.py` + optional eval before/after |
| Compare base vs LoRA on public agent tasks | `slm-benchmark` |
| Prototype latent draft selection | `ensemble` smoke → harness |
| Ship in Gradio Space | `apps/gradio-space` only; wire new weights via `models.yaml` |

## Workspace packages

Both subpackages are listed in root `[tool.uv.workspace] members`:

- `research/ensemble`: import name `ensemble`
- `research/evals`: import name `slm_evals`, CLI `slm-benchmark`

Run with `uv run --package <name>` from the repo root so uv resolves workspace paths and shared lockfile versions.
research/ensemble/README.md
ADDED
|
@@ -0,0 +1,113 @@
# Ensemble research package

JEPA and world-model ensemble experiments. The package stays under `research/` and is not wired into the Gradio agent.

See also: [../USAGE.md](../USAGE.md) · [../docs/overview.md](../docs/overview.md)

## Install

```bash
uv sync --group ensemble
```

## Tier 1: Smoke (CPU, no HF download)

```bash
uv run --package ensemble python -m ensemble.jepa_ensemble tiny
uv run --package ensemble python -m ensemble.world_ensemble tiny
bash research/ensemble/scripts/smoke.sh
```

## Tier 2: Micro demo (real small model)

```bash
uv run --package ensemble python -m ensemble.jepa_ensemble Qwen/Qwen2.5-0.5B-Instruct
uv run --package ensemble python -m ensemble.world_ensemble Qwen/Qwen2.5-0.5B-Instruct
```

## Pretrain + save (LLM + emb + JEPA)

Joint training writes a full checkpoint to `models/ensemble/<name>/`:

```bash
# CPU smoke (tiny backend, no HF download)
uv run --package ensemble ensemble-pretrain \
  --llm tiny --steps 50 --no-kb \
  --out models/ensemble/jepa-smoke

# Uses ACTIVE_MODEL / BASE / LLM_PATH from .env + models.yaml by default
uv run --package ensemble ensemble-pretrain \
  --data research/data/education-lesson-chat.jsonl \
  --kb research/data/benchmark-kb.jsonl \
  --steps 200

# Override base LLM explicitly
uv run --package ensemble ensemble-pretrain \
  --llm Qwen/Qwen2.5-0.5B-Instruct --steps 200
```

Checkpoint layout: `manifest.json`, `aux.pt` (emb/jepa/bridge/router), `llm/` (PEFT adapters).
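
Before benchmarking, it can help to confirm that a checkpoint directory contains the three entries listed above. The helper below is a hypothetical convenience, not part of the `ensemble` package. It only checks that the names exist and does not validate their contents, so treat a reported gap as a prompt to inspect the directory rather than proof of a broken checkpoint.

```python
from __future__ import annotations

from pathlib import Path

# Entries named in the checkpoint layout; a trailing "/" marks a directory.
EXPECTED_ENTRIES = ("manifest.json", "aux.pt", "llm/")


def missing_checkpoint_entries(checkpoint_dir: str | Path) -> list[str]:
    """Return the expected entries that are absent from checkpoint_dir."""
    root = Path(checkpoint_dir)
    missing = []
    for entry in EXPECTED_ENTRIES:
        path = root / entry.rstrip("/")
        present = path.is_dir() if entry.endswith("/") else path.is_file()
        if not present:
            missing.append(entry)
    return missing


if __name__ == "__main__":
    # Path taken from the CPU smoke command; adjust to your own --out value.
    print(missing_checkpoint_entries("models/ensemble/jepa-smoke"))
```

An empty list means all three entries are present.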

Benchmark the saved ensemble with **slm-evals** (auto-detects `manifest.json`):

```bash
uv run --package slm-evals slm-benchmark \
  --model ./models/ensemble/jepa-lesson-pretrain \
  --model-type ensemble \
  --benchmarks bfcl tau_bench --max-samples 20

# Or use the template config
uv run --package slm-evals slm-benchmark \
  --config research/evals/configs/ensemble_jepa_lesson.yaml
```

Compare against a base HF model by running the same config with `model_type: hf` and `model_path: openbmb/MiniCPM5-1B`.

## Tier 3: Benchmark

### JEPA ablation ladder

```bash
# Toy (no download)
uv run --package ensemble python -m ensemble.eval.jepa_harness \
  --llm tiny --toy --limit 20 --n_drafts 8

# Education QA set
uv run --package ensemble python -m ensemble.eval.jepa_harness \
  --llm Qwen/Qwen2.5-0.5B-Instruct \
  --qa research/data/benchmark-qa.jsonl \
  --kb research/data/benchmark-kb.jsonl \
  --limit 50 --n_drafts 8
```

### World-model energy selector

```bash
uv run --package ensemble python -m ensemble.eval.world_harness \
  --llm tiny --toy --limit 20 --n_drafts 8

uv run --package ensemble python -m ensemble.eval.world_harness \
  --llm Qwen/Qwen2.5-0.5B-Instruct \
  --qa research/data/benchmark-qa.jsonl \
  --kb research/data/benchmark-kb.jsonl \
  --limit 50 --n_drafts 8
```

## Layout

```
research/ensemble/
  src/ensemble/
    backends.py        # TinyBackend, HFBackend, TinyLLM, HFLLM
    memory.py          # Embedder, VectorStore, Router
    jepa.py            # JEPA latent predictor
    bridge.py          # LLM hidden -> latent alignment
    world_model.py     # Latent dynamics + rollout
    energy.py          # Energy-based critic
    jepa_ensemble.py   # Ensemble (JEPA track)
    world_ensemble.py  # WorldEnsemble
    eval/
      metrics.py
      jepa_harness.py
      world_harness.py
```
research/ensemble/pyproject.toml
ADDED
|
@@ -0,0 +1,16 @@
[project]
name = "ensemble"
version = "0.1.0"
description = "JEPA and world-model ensemble research package"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "torch>=2.5.0",
]

[project.scripts]
ensemble-pretrain = "ensemble.pretrain:main"

[build-system]
requires = ["uv_build>=0.8.13,<0.9.0"]
build-backend = "uv_build"
research/ensemble/scripts/smoke.sh
ADDED
|
@@ -0,0 +1,35 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT="$(cd "$(dirname "$0")/../../.." && pwd)"
cd "$ROOT"

echo "== JEPA ensemble demo (tiny) =="
uv run --package ensemble python -m ensemble.jepa_ensemble tiny

echo ""
echo "== World ensemble demo (tiny) =="
uv run --package ensemble python -m ensemble.world_ensemble tiny

echo ""
echo "== JEPA harness (toy) =="
uv run --package ensemble python -m ensemble.eval.jepa_harness \
  --llm tiny --toy --limit 10 --n_drafts 4

echo "== Pretrain smoke + checkpoint roundtrip =="
uv run --package ensemble ensemble-pretrain \
  --llm tiny --steps 20 --no-kb \
  --out models/ensemble/jepa-smoke
uv run --package ensemble python -c "
from ensemble.checkpoint import load_checkpoint
ens = load_checkpoint('models/ensemble/jepa-smoke')
print('loaded ensemble, adapters:', ens.adapter_names)
"

echo ""
echo "== World harness (toy) =="
uv run --package ensemble python -m ensemble.eval.world_harness \
  --llm tiny --toy --limit 10 --n_drafts 4

echo ""
echo "All smoke checks passed."
research/ensemble/src/ensemble/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
"""Research ensemble package: JEPA and world-model tracks."""

__all__ = ["Ensemble", "WorldEnsemble"]


def __getattr__(name: str):
    # Module-level __getattr__ (PEP 562): import each track only on first
    # access, so `import ensemble` does not pull in both ensembles eagerly.
    if name == "Ensemble":
        from ensemble.jepa_ensemble import Ensemble

        return Ensemble
    if name == "WorldEnsemble":
        from ensemble.world_ensemble import WorldEnsemble

        return WorldEnsemble
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
research/ensemble/src/ensemble/backends.py
ADDED
|
@@ -0,0 +1,418 @@
"""LLM backends: toy fallbacks and HuggingFace + LoRA loaders."""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class LLMBackend(nn.Module):
    """Contract for JEPA ensemble backends."""

    vocab_size: int
    hidden_size: int


class HFBackend(LLMBackend):
    """HuggingFace causal LM with PEFT LoRA adapter bank."""

    def __init__(
        self,
        model_path: str,
        *,
        load_in_4bit: bool = False,
        lora_r: int = 16,
        lora_alpha: int = 32,
        target_modules=("q_proj", "v_proj"),
        device: str | None = None,
        torch_dtype=None,
    ):
        super().__init__()
        # Imported lazily so the tiny backend works without peft/transformers.
        from peft import LoraConfig, get_peft_model
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.device_ = torch.device(
            device or ("cuda" if torch.cuda.is_available() else "cpu")
        )

        kwargs = {}
        if load_in_4bit:
            from transformers import BitsAndBytesConfig

            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
            )
        if torch_dtype is not None:
            kwargs["torch_dtype"] = torch_dtype

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
        if not load_in_4bit:
            base.to(self.device_)

        # Freeze the base model; only LoRA adapter weights are trainable.
        for p in base.parameters():
            p.requires_grad_(False)

        self._lora_cfg = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=0.05,
            target_modules=list(target_modules),
            task_type="CAUSAL_LM",
        )
        self.model = get_peft_model(base, self._lora_cfg, adapter_name="general")
        self._adapters = {"general"}

        self.vocab_size = self.model.config.vocab_size
        self.hidden_size = self.model.config.hidden_size

    def add_adapter(self, name: str):
        if name not in self._adapters:
            self.model.add_adapter(name, self._lora_cfg)
            self._adapters.add(name)

    def set_adapter(self, name: str):
        self.model.set_adapter(name)

    def trainable_parameters(self):
        return (p for p in self.model.parameters() if p.requires_grad)

    def forward(self, ids):
        # Returns (logits, last-layer hidden states).
        out = self.model(
            input_ids=ids.to(self.device_), output_hidden_states=True
        )
        return out.logits, out.hidden_states[-1]

    @torch.no_grad()
    def generate(self, ids, n_new=64, temperature=0.8):
        gen_kwargs: dict = dict(
            input_ids=ids.to(self.device_),
            max_new_tokens=n_new,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # temperature <= 0 means greedy decoding.
        if temperature <= 0:
            gen_kwargs["do_sample"] = False
        else:
            gen_kwargs.update(do_sample=True, temperature=temperature)
        out = self.model.generate(**gen_kwargs)
        return out

    def encode_text(self, text: str):
        return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device_)

    def decode(self, ids):
        return self.tokenizer.decode(ids[0], skip_special_tokens=True)

    @property
    def device(self):
        return self.device_


class TinyBackend(LLMBackend):
    """Toy transformer with LoRA adapters (no transformers dependency)."""

    VOCAB, D_MODEL, N_LAYERS, N_HEADS, SEQ_LEN, LORA_R = 1000, 128, 2, 4, 32, 8

    class _LoRALinear(nn.Module):
        def __init__(self, d_in, d_out, r):
            super().__init__()
            self.base = nn.Linear(d_in, d_out)
            self.base.weight.requires_grad_(False)
            self.base.bias.requires_grad_(False)
            self.adapters, self.active, self.r = nn.ModuleDict(), None, r

        def add_adapter(self, name):
            A = nn.Linear(self.base.in_features, self.r, bias=False)
            B = nn.Linear(self.r, self.base.out_features, bias=False)
            # Zero-init B so a new adapter starts as a no-op on the base layer.
            nn.init.zeros_(B.weight)
            self.adapters[name] = nn.Sequential(A, B)

        def forward(self, x):
            y = self.base(x)
            if self.active and self.active in self.adapters:
                y = y + self.adapters[self.active](x)
            return y

    class _Block(nn.Module):
        def __init__(self, D, H, R):
            super().__init__()
            L = TinyBackend._LoRALinear
            self.ln1 = nn.LayerNorm(D)
            self.attn = nn.MultiheadAttention(D, H, batch_first=True)
            self.ln2 = nn.LayerNorm(D)
            self.up, self.down = L(D, 4 * D, R), L(4 * D, D, R)

        def forward(self, x, mask):
            h = self.ln1(x)
            a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
            x = x + a
            return x + self.down(F.gelu(self.up(self.ln2(x))))

    def __init__(self):
        super().__init__()
        D, V = self.D_MODEL, self.VOCAB
        self.tok = nn.Embedding(V, D)
        self.pos = nn.Embedding(self.SEQ_LEN * 4, D)
        self.blocks = nn.ModuleList(
            [self._Block(D, self.N_HEADS, self.LORA_R) for _ in range(self.N_LAYERS)]
        )
        self.ln_f, self.head = nn.LayerNorm(D), nn.Linear(D, V, bias=False)
        self.vocab_size, self.hidden_size = V, D
        self.add_adapter("general")
        self.set_adapter("general")

    def add_adapter(self, name):
        for b in self.blocks:
            b.up.add_adapter(name)
            b.down.add_adapter(name)

    def set_adapter(self, name):
        for b in self.blocks:
            b.up.active = name
            b.down.active = name

    def trainable_parameters(self):
        return (p for p in self.parameters() if p.requires_grad)

    def forward(self, ids):
        B, T = ids.shape
        x = self.tok(ids) + self.pos(torch.arange(T, device=ids.device))
        # Causal mask: -inf above the diagonal blocks attention to future tokens.
        mask = torch.triu(
            torch.full((T, T), float("-inf"), device=ids.device), 1
        )
        for b in self.blocks:
            x = b(x, mask)
        h = self.ln_f(x)
        return self.head(h), h

    @torch.no_grad()
    def generate(self, ids, n_new=16, temperature=1.0):
        for _ in range(n_new):
            # Condition on at most the last SEQ_LEN tokens.
            logits, _ = self(ids[:, -self.SEQ_LEN :])
            if temperature <= 0:
                nxt = logits[:, -1].argmax(dim=-1, keepdim=True)
            else:
                nxt = torch.multinomial(
                    F.softmax(logits[:, -1] / temperature, -1), 1
                )
            ids = torch.cat([ids, nxt], dim=1)
        return ids

    def encode_text(self, text: str):
        # Character-level toy tokenizer: code point modulo vocab size.
        vals = [ord(c) % self.vocab_size for c in text[: self.SEQ_LEN]]
        if not vals:
            vals = [0]
        return torch.tensor([vals], dtype=torch.long)

    def decode(self, ids):
        return " ".join(str(int(t)) for t in ids[0].tolist())

    @property
    def device(self):
        return next(self.parameters()).device


def make_backend(llm: str, **kw) -> LLMBackend:
    """'tiny' -> toy model; anything else -> HF hub id or local path."""
    return TinyBackend() if llm == "tiny" else HFBackend(llm, **kw)


def load_hf_backend_from_checkpoint(
    base_llm: str,
    adapter_dir: str | None,
    *,
    adapter_names: tuple[str, ...] = ("general",),
    device: str | None = None,
    load_in_4bit: bool = False,
    lora_r: int = 16,
    lora_alpha: int = 32,
) -> HFBackend:
    """Load a frozen base LM + saved PEFT adapters (ensemble checkpoint llm/)."""
    from pathlib import Path

    from peft import LoraConfig, PeftModel, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def _discover_adapter_dirs(root: Path) -> dict[str, Path]:
        if (root / "adapter_config.json").is_file():
            return {"general": root}
        discovered: dict[str, Path] = {}
        for child in sorted(root.iterdir()):
            if child.is_dir() and (child / "adapter_config.json").is_file():
                discovered[child.name] = child
        return discovered

    resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir or base_llm)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    kwargs: dict = {}
    if load_in_4bit:
        from transformers import BitsAndBytesConfig

        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
        )
    elif resolved_device != "cpu":
        kwargs["torch_dtype"] = torch.bfloat16

    base = AutoModelForCausalLM.from_pretrained(base_llm, **kwargs)
    if not load_in_4bit and resolved_device != "cpu":
        base.to(resolved_device)
    for p in base.parameters():
        p.requires_grad_(False)

    if adapter_dir:
        adapter_dirs = _discover_adapter_dirs(Path(adapter_dir))
        if not adapter_dirs:
            raise ValueError(
                f"No PEFT adapters found under {adapter_dir} "
                "(expected adapter_config.json or <name>/adapter_config.json)"
            )
        preferred = [name for name in adapter_names if name in adapter_dirs]
        load_order = preferred + [
            name for name in adapter_dirs if name not in preferred
        ]
        first_name = load_order[0]
        model = PeftModel.from_pretrained(
            base,
            str(adapter_dirs[first_name]),
            adapter_name=first_name,
            is_trainable=False,
        )
        for name in load_order[1:]:
            model.load_adapter(str(adapter_dirs[name]), adapter_name=name)
        adapters = set(load_order)
    else:
        lora_cfg = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
lora_dropout=0.05,
|
| 299 |
+
target_modules=["q_proj", "v_proj"],
|
| 300 |
+
task_type="CAUSAL_LM",
|
| 301 |
+
)
|
| 302 |
+
model = get_peft_model(base, lora_cfg, adapter_name="general")
|
| 303 |
+
adapters = {"general"}
|
| 304 |
+
|
| 305 |
+
backend = HFBackend.__new__(HFBackend)
|
| 306 |
+
nn.Module.__init__(backend)
|
| 307 |
+
backend.device_ = torch.device(resolved_device)
|
| 308 |
+
backend.tokenizer = tokenizer
|
| 309 |
+
backend.model = model
|
| 310 |
+
backend._lora_cfg = None
|
| 311 |
+
backend._adapters = adapters
|
| 312 |
+
backend.vocab_size = model.config.vocab_size
|
| 313 |
+
backend.hidden_size = model.config.hidden_size
|
| 314 |
+
if adapter_names:
|
| 315 |
+
backend.set_adapter(adapter_names[0])
|
| 316 |
+
return backend
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class TinyLLM(nn.Module):
|
| 320 |
+
"""Simpler toy LLM for the world-model track (no adapter bank)."""
|
| 321 |
+
|
| 322 |
+
VOCAB, D, L, H, T = 1000, 128, 2, 4, 32
|
| 323 |
+
|
| 324 |
+
def __init__(self):
|
| 325 |
+
super().__init__()
|
| 326 |
+
self.tok = nn.Embedding(self.VOCAB, self.D)
|
| 327 |
+
self.pos = nn.Embedding(self.T * 4, self.D)
|
| 328 |
+
layer = nn.TransformerEncoderLayer(
|
| 329 |
+
self.D, self.H, 4 * self.D, batch_first=True, norm_first=True
|
| 330 |
+
)
|
| 331 |
+
self.blocks = nn.TransformerEncoder(layer, self.L)
|
| 332 |
+
self.head = nn.Linear(self.D, self.VOCAB, bias=False)
|
| 333 |
+
self.vocab_size, self.hidden_size = self.VOCAB, self.D
|
| 334 |
+
|
| 335 |
+
def forward(self, ids):
|
| 336 |
+
Tn = ids.size(1)
|
| 337 |
+
x = self.tok(ids) + self.pos(torch.arange(Tn, device=ids.device))
|
| 338 |
+
mask = torch.triu(
|
| 339 |
+
torch.full((Tn, Tn), float("-inf"), device=ids.device), 1
|
| 340 |
+
)
|
| 341 |
+
h = self.blocks(x, mask=mask)
|
| 342 |
+
return self.head(h), h
|
| 343 |
+
|
| 344 |
+
@torch.no_grad()
|
| 345 |
+
def generate(self, ids, n_new=16, temperature=1.0):
|
| 346 |
+
for _ in range(n_new):
|
| 347 |
+
logits, _ = self(ids[:, -self.T :])
|
| 348 |
+
nxt = torch.multinomial(
|
| 349 |
+
F.softmax(logits[:, -1] / temperature, -1), 1
|
| 350 |
+
)
|
| 351 |
+
ids = torch.cat([ids, nxt], 1)
|
| 352 |
+
return ids
|
| 353 |
+
|
| 354 |
+
def trainable_parameters(self):
|
| 355 |
+
return self.parameters()
|
| 356 |
+
|
| 357 |
+
@property
|
| 358 |
+
def device(self):
|
| 359 |
+
return next(self.parameters()).device
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class HFLLM(nn.Module):
|
| 363 |
+
"""Small HF model with single LoRA stack (world-model track)."""
|
| 364 |
+
|
| 365 |
+
def __init__(self, path, lora_r=16):
|
| 366 |
+
super().__init__()
|
| 367 |
+
from peft import LoraConfig, get_peft_model
|
| 368 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 369 |
+
|
| 370 |
+
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
| 371 |
+
if self.tokenizer.pad_token is None:
|
| 372 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 373 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 374 |
+
path,
|
| 375 |
+
torch_dtype=torch.bfloat16
|
| 376 |
+
if torch.cuda.is_available()
|
| 377 |
+
else torch.float32,
|
| 378 |
+
device_map="auto" if torch.cuda.is_available() else None,
|
| 379 |
+
)
|
| 380 |
+
for p in base.parameters():
|
| 381 |
+
p.requires_grad_(False)
|
| 382 |
+
cfg = LoraConfig(
|
| 383 |
+
r=lora_r,
|
| 384 |
+
lora_alpha=2 * lora_r,
|
| 385 |
+
lora_dropout=0.05,
|
| 386 |
+
target_modules=["q_proj", "v_proj"],
|
| 387 |
+
task_type="CAUSAL_LM",
|
| 388 |
+
)
|
| 389 |
+
self.model = get_peft_model(base, cfg)
|
| 390 |
+
self.vocab_size = self.model.config.vocab_size
|
| 391 |
+
self.hidden_size = self.model.config.hidden_size
|
| 392 |
+
|
| 393 |
+
def forward(self, ids):
|
| 394 |
+
out = self.model(
|
| 395 |
+
input_ids=ids.to(self.device), output_hidden_states=True
|
| 396 |
+
)
|
| 397 |
+
return out.logits, out.hidden_states[-1]
|
| 398 |
+
|
| 399 |
+
@torch.no_grad()
|
| 400 |
+
def generate(self, ids, n_new=32, temperature=0.8):
|
| 401 |
+
return self.model.generate(
|
| 402 |
+
input_ids=ids.to(self.device),
|
| 403 |
+
max_new_tokens=n_new,
|
| 404 |
+
do_sample=True,
|
| 405 |
+
temperature=temperature,
|
| 406 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
def trainable_parameters(self):
|
| 410 |
+
return (p for p in self.model.parameters() if p.requires_grad)
|
| 411 |
+
|
| 412 |
+
@property
|
| 413 |
+
def device(self):
|
| 414 |
+
return next(self.model.parameters()).device
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def load_llm(spec: str):
|
| 418 |
+
return TinyLLM() if spec == "tiny" else HFLLM(spec)
|
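The adapter discovery and load-order steps inside `load_hf_backend_from_checkpoint` can be tried without PEFT or a model download. The sketch below re-implements only the directory scan and the ordering with the standard library. The names `discover_adapter_dirs`, `plan_load_order`, `demo`, and the temporary directory layout are illustrative assumptions, not part of the commit.

```python
import tempfile
from pathlib import Path


def discover_adapter_dirs(root: Path) -> dict[str, Path]:
    """Map adapter name -> directory, mirroring the scan in the diff."""
    # Single-adapter layout: adapter_config.json sits directly in root.
    if (root / "adapter_config.json").is_file():
        return {"general": root}
    # Multi-adapter layout: one sub-directory per named adapter.
    return {
        child.name: child
        for child in sorted(root.iterdir())
        if child.is_dir() and (child / "adapter_config.json").is_file()
    }


def plan_load_order(adapter_names, adapter_dirs):
    """Requested adapters first (in requested order), then extras found on disk."""
    preferred = [name for name in adapter_names if name in adapter_dirs]
    return preferred + [name for name in adapter_dirs if name not in preferred]


def demo():
    # Hypothetical checkpoint layout built in a throwaway directory.
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for name in ("general", "math", "code"):
            (root / name).mkdir()
            (root / name / "adapter_config.json").write_text("{}")
        (root / "notes").mkdir()  # no adapter_config.json, so it is skipped
        found = discover_adapter_dirs(root)
        return sorted(found), plan_load_order(("math", "missing"), found)


found_names, load_order = demo()
print(found_names, load_order)
```

In this sketch a requested name that is absent on disk (`"missing"`) is simply dropped from the load order, while adapters that were not requested are still loaded after the preferred ones.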
research/ensemble/src/ensemble/bridge.py
ADDED
@@ -0,0 +1,28 @@
+"""Bridge: align LLM hidden states with JEPA latent space."""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Bridge(nn.Module):
+    def __init__(self, d_llm_hidden: int, d_latent: int):
+        super().__init__()
+        self.proj = nn.Sequential(
+            nn.Linear(d_llm_hidden, d_latent),
+            nn.GELU(),
+            nn.Linear(d_latent, d_latent),
+        )
+
+    def forward(self, llm_hidden):
+        return self.proj(llm_hidden.float().mean(dim=1))
+
+    def info_nce(self, z1, z2, tau=0.07):
+        z1, z2 = F.normalize(z1, dim=-1), F.normalize(z2, dim=-1)
+        logits = z1 @ z2.t() / tau
+        labels = torch.arange(z1.size(0), device=z1.device)
+        return 0.5 * (
+            F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)
+        )
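To make the symmetric InfoNCE objective in `Bridge.info_nce` concrete, here is a dependency-free sketch of the same computation on plain Python lists. It is an illustration under assumptions, not the committed implementation: the helper names `l2_normalize` and `cross_entropy` and the toy vectors are invented for this example.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def cross_entropy(logits, labels):
    """Mean cross-entropy over rows; labels[i] is the target column of row i."""
    total = 0.0
    for row, target in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
    return total / len(logits)


def info_nce(z1, z2, tau=0.07):
    z1 = [l2_normalize(v) for v in z1]
    z2 = [l2_normalize(v) for v in z2]
    # logits[i][j] = cosine similarity of z1[i] and z2[j], scaled by 1/tau
    logits = [[sum(a * b for a, b in zip(u, v)) / tau for v in z2] for u in z1]
    labels = list(range(len(z1)))  # pair i in z1 matches pair i in z2
    transposed = [list(col) for col in zip(*logits)]
    # Average the z1->z2 and z2->z1 directions, as in the diff.
    return 0.5 * (cross_entropy(logits, labels) + cross_entropy(transposed, labels))


a = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
aligned = info_nce(a, a)                    # each row matches its own column
shuffled = info_nce(a, [a[1], a[2], a[0]])  # positives moved off the diagonal
print(aligned < shuffled)
```

The loss is small when row `i` of `z1` is most similar to row `i` of `z2` and grows when the matching pairs are shuffled, which is the signal that pulls the projected LLM states toward their paired JEPA latents.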
research/ensemble/src/ensemble/checkpoint.py
ADDED
@@ -0,0 +1,149 @@
+"""Save and load JEPA ensemble checkpoints under models/ensemble/."""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any
+
+import torch
+
+from ensemble.backends import TinyBackend, load_hf_backend_from_checkpoint
+from ensemble.jepa_ensemble import Ensemble
+
+MANIFEST_FILE = "manifest.json"
+AUX_FILE = "aux.pt"
+STORE_FILE = "store.pt"
+LLM_DIR = "llm"
+TINY_LLM_FILE = "tiny_llm.pt"
+
+CHECKPOINT_VERSION = 1
+
+
+def _aux_state_dict(ens: Ensemble) -> dict[str, torch.Tensor]:
+    return {
+        "emb": ens.emb.state_dict(),
+        "jepa": ens.jepa.state_dict(),
+        "bridge": ens.bridge.state_dict(),
+        "router": ens.router.state_dict(),
+    }
+
+
+def _store_payload(ens: Ensemble) -> dict[str, Any]:
+    return {
+        "keys": [k for k in ens.store.keys],
+        "values": [v for v in ens.store.values],
+    }
+
+
+def save_checkpoint(
+    ens: Ensemble,
+    out_dir: str | Path,
+    *,
+    base_llm: str,
+    training_meta: dict[str, Any] | None = None,
+) -> Path:
+    """Persist ensemble (LLM adapters + emb + JEPA + bridge + router + store)."""
+    root = Path(out_dir).resolve()
+    root.mkdir(parents=True, exist_ok=True)
+
+    backend = "tiny" if isinstance(ens.llm, TinyBackend) else "hf"
+    manifest: dict[str, Any] = {
+        "version": CHECKPOINT_VERSION,
+        "track": "jepa",
+        "backend": backend,
+        "base_llm": base_llm,
+        "adapter_names": list(ens.adapter_names),
+        "d_emb": ens.emb.d_emb,
+        "d_jepa": ens.jepa.d_latent,
+        "training": training_meta or {},
+    }
+
+    torch.save(_aux_state_dict(ens), root / AUX_FILE)
+    store = _store_payload(ens)
+    if store["keys"]:
+        torch.save(store, root / STORE_FILE)
+
+    if backend == "hf":
+        llm_path = root / LLM_DIR
+        llm_path.mkdir(exist_ok=True)
+        ens.llm.model.save_pretrained(llm_path)
+        ens.llm.tokenizer.save_pretrained(llm_path)
+    else:
+        torch.save(ens.llm.state_dict(), root / TINY_LLM_FILE)
+
+    with open(root / MANIFEST_FILE, "w") as f:
+        json.dump(manifest, f, indent=2)
+
+    return root
+
+
+def is_ensemble_checkpoint(path: str | Path) -> bool:
+    return (Path(path) / MANIFEST_FILE).is_file()
+
+
+def load_checkpoint(
+    ckpt_dir: str | Path,
+    *,
+    device: str | None = None,
+    load_in_4bit: bool = False,
+) -> Ensemble:
+    """Restore a saved JEPA ensemble from models/ensemble/<name>/."""
+    root = Path(ckpt_dir).resolve()
+    manifest_path = root / MANIFEST_FILE
+    if not manifest_path.is_file():
+        raise FileNotFoundError(
+            f"Not an ensemble checkpoint (missing {MANIFEST_FILE}): {root}"
+        )
+
+    with open(manifest_path) as f:
+        manifest = json.load(f)
+
+    base_llm = manifest["base_llm"]
+    backend = manifest.get("backend", "hf")
+    adapter_names = tuple(manifest.get("adapter_names", ["general"]))
+    d_emb = manifest.get("d_emb", 64)
+    d_jepa = manifest.get("d_jepa", 64)
+
+    if backend == "tiny":
+        ens = Ensemble(
+            llm="tiny",
+            adapter_names=adapter_names,
+            d_emb=d_emb,
+            d_jepa=d_jepa,
+        )
+        tiny_state = torch.load(
+            root / TINY_LLM_FILE, map_location="cpu", weights_only=True
+        )
+        ens.llm.load_state_dict(tiny_state)
+    else:
+        llm_dir = root / LLM_DIR
+        llm_backend = load_hf_backend_from_checkpoint(
+            base_llm,
+            str(llm_dir) if llm_dir.is_dir() else None,
+            adapter_names=adapter_names,
+            device=device,
+            load_in_4bit=load_in_4bit,
+        )
+        ens = Ensemble(
+            llm=base_llm,
+            adapter_names=adapter_names,
+            d_emb=d_emb,
+            d_jepa=d_jepa,
+            llm_backend=llm_backend,
+        )
+
+    aux = torch.load(root / AUX_FILE, map_location="cpu", weights_only=True)
+    ens.emb.load_state_dict(aux["emb"])
+    ens.jepa.load_state_dict(aux["jepa"])
+    ens.bridge.load_state_dict(aux["bridge"])
+    ens.router.load_state_dict(aux["router"])
+
+    store_path = root / STORE_FILE
+    if store_path.is_file():
+        store = torch.load(store_path, map_location="cpu", weights_only=True)
+        ens.store.keys = list(store["keys"])
+        ens.store.values = list(store["values"])
+
+    ens.eval()
+    return ens
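The manifest handling in `load_checkpoint` is the part most likely to matter when a checkpoint was written by an older or hand-edited run, so a small standard-library sketch of just that step may help. It assumes nothing beyond `json` and `pathlib`; `read_manifest` and `demo` are illustrative names, and the one-key manifest is a made-up minimal case rather than something `save_checkpoint` would write.

```python
import json
import tempfile
from pathlib import Path

MANIFEST_FILE = "manifest.json"


def is_ensemble_checkpoint(path):
    return (Path(path) / MANIFEST_FILE).is_file()


def read_manifest(ckpt_dir):
    """Read manifest.json and apply the same fallbacks the diff uses."""
    root = Path(ckpt_dir)
    if not is_ensemble_checkpoint(root):
        raise FileNotFoundError(
            f"Not an ensemble checkpoint (missing {MANIFEST_FILE}): {root}"
        )
    manifest = json.loads((root / MANIFEST_FILE).read_text())
    return {
        "base_llm": manifest["base_llm"],  # required: KeyError if absent
        "backend": manifest.get("backend", "hf"),
        "adapter_names": tuple(manifest.get("adapter_names", ["general"])),
        "d_emb": manifest.get("d_emb", 64),
        "d_jepa": manifest.get("d_jepa", 64),
    }


def demo():
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        before = is_ensemble_checkpoint(root)
        # Deliberately minimal manifest: only the required key is present.
        (root / MANIFEST_FILE).write_text(json.dumps({"base_llm": "tiny"}))
        return before, read_manifest(root)


was_checkpoint, settings = demo()
print(was_checkpoint, settings)
```

Only `base_llm` is mandatory here; every other field falls back to a default, so a manifest missing `backend` is treated as an HF checkpoint with a single `general` adapter and 64-dimensional embedding and JEPA latents.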
research/ensemble/src/ensemble/config.py
ADDED
@@ -0,0 +1,163 @@
+"""Resolve base LLM for ensemble from .env and models.yaml (same order as finetune)."""
+
+from __future__ import annotations
+
+import os
+import sys
+from pathlib import Path
+
+_REPO_ROOT = Path(__file__).resolve().parents[4]
+_FALLBACK_PRESET = "minicpm5-1b"
+
+_ENV_LLM_KEYS = (
+    "ENSEMBLE_LLM",
+    "LLM_PATH",
+    "BASE",
+    "FINETUNE_MODEL",
+    "MODEL_ID",
+)
+
+
+def repo_root() -> Path:
+    return _REPO_ROOT
+
+
+def load_dotenv() -> None:
+    """Load KEY=VALUE pairs from repo .env without overriding existing env vars."""
+    path = _REPO_ROOT / ".env"
+    if not path.is_file():
+        return
+    for line in path.read_text().splitlines():
+        line = line.strip()
+        if not line or line.startswith("#") or "=" not in line:
+            continue
+        key, _, value = line.partition("=")
+        key = key.strip()
+        value = value.strip().strip('"').strip("'")
+        if key:
+            os.environ.setdefault(key, value)
+
+
+def _ensure_inference_on_path() -> None:
+    libs = _REPO_ROOT / "libs" / "inference" / "src"
+    if str(libs) not in sys.path:
+        sys.path.insert(0, str(libs))
+
+
+def _is_ensemble_llm_preset(model) -> bool:
+    return model.backend == "transformers" and not model.multimodal and bool(
+        model.model_id
+    )
+
+
+def _llm_from_local_path(raw: str) -> str | None:
+    path = Path(raw)
+    if not path.is_absolute():
+        path = (_REPO_ROOT / path).resolve()
+    if path.suffix == ".gguf":
+        return None
+    if path.is_dir() and (path / "config.json").is_file():
+        return str(path)
+    if path.is_file():
+        return None
+    return None
+
+
+def _llm_from_env_paths() -> str | None:
+    for key in ("LLM_PATH", "MODEL_PATH"):
+        raw = os.environ.get(key)
+        if raw:
+            resolved = _llm_from_local_path(raw)
+            if resolved:
+                return resolved
+    return None
+
+
+def resolve_llm(
+    *,
+    llm_arg: str | None = None,
+    preset_arg: str | None = None,
+) -> tuple[str, str | None]:
+    """
+    Return (hub_id_or_local_path, preset_key) for ensemble HF backends.
+
+    Priority when llm_arg is None or ``auto``:
+      1. ENSEMBLE_LLM, LLM_PATH (local HF dir), BASE, FINETUNE_MODEL, MODEL_ID
+      2. MODEL_PATH if it points at a HuggingFace model directory (not .gguf)
+      3. ENSEMBLE_PRESET, FINETUNE_PRESET, or ACTIVE_MODEL from models.yaml
+      4. First fine-tunable transformers preset (default minicpm5-1b)
+    """
+    if llm_arg and llm_arg not in ("auto",):
+        return llm_arg, preset_arg
+
+    for env_name in _ENV_LLM_KEYS:
+        raw = os.environ.get(env_name)
+        if raw:
+            local = _llm_from_local_path(raw)
+            return local or raw, preset_arg
+
+    local = _llm_from_env_paths()
+    if local:
+        return local, preset_arg
+
+    _ensure_inference_on_path()
+    from inference.config import get_app_config, get_model_config
+
+    app_config = get_app_config(reload=True)
+    preset_key = (
+        preset_arg
+        or os.environ.get("ENSEMBLE_PRESET")
+        or os.environ.get("FINETUNE_PRESET")
+        or os.environ.get("ACTIVE_MODEL")
+    )
+
+    if preset_key and preset_key in app_config.models:
+        model = get_model_config(preset_key)
+        if not _is_ensemble_llm_preset(model):
+            preset_key = None
+
+    if preset_key is None:
+        for candidate in (_FALLBACK_PRESET, *app_config.models):
+            if candidate not in app_config.models:
+                continue
+            model = get_model_config(candidate)
+            if _is_ensemble_llm_preset(model):
+                preset_key = candidate
+                break
+
+    if not preset_key:
+        raise SystemExit(
+            "No transformers LLM found for ensemble. Pass --llm, set LLM_PATH/BASE/"
+            "MODEL_ID in .env, or ACTIVE_MODEL in models.yaml."
+        )
+
+    model = get_model_config(preset_key)
+    if not _is_ensemble_llm_preset(model):
+        raise SystemExit(
+            f"Preset {preset_key!r} cannot back an ensemble "
+            f"(backend={model.backend}, multimodal={model.multimodal})."
+        )
+    return model.model_id, preset_key
+
+
+def default_ensemble_out(preset_key: str | None) -> str:
+    label = preset_key or "custom"
+    return str((_REPO_ROOT / "models" / "ensemble" / f"{label}-jepa-pretrain").resolve())
+
+
+def resolve_llm_cli(
+    llm: str | None,
+    *,
+    toy: bool = False,
+    preset: str | None = None,
+) -> str:
+    """CLI helper: explicit tiny, else .env / models.yaml unless --toy without --llm."""
+    if llm == "tiny":
+        return "tiny"
+    if llm is None or llm == "auto":
+        if toy:
+            return "tiny"
+        load_dotenv()
+        resolved, _ = resolve_llm(preset_arg=preset)
+        return resolved
+    return llm
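The `.env` parsing rules in `load_dotenv` (skip comments and lines without `=`, split on the first `=`, strip surrounding quotes, never override a variable that is already set) are easy to check on a string. The sketch below applies the same rules to a plain dict instead of `os.environ` so it has no side effects. `parse_dotenv`, `apply_env`, and the sample values such as `org/model` are hypothetical names chosen for this example.

```python
def parse_dotenv(text):
    """Parse KEY=VALUE lines with the same rules as the diff, returning a dict."""
    pairs = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")  # split on the first '=' only
        key = key.strip()
        value = value.strip().strip('"').strip("'")
        if key:
            pairs.setdefault(key, value)  # first occurrence wins
    return pairs


def apply_env(pairs, environ):
    """Copy parsed pairs into environ without overriding existing keys."""
    for key, value in pairs.items():
        environ.setdefault(key, value)
    return environ


sample = """
# comment
ENSEMBLE_LLM="org/model"
LLM_PATH = 'models/local'
BROKEN LINE
MODEL_ID=a=b
"""

parsed = parse_dotenv(sample)
env = apply_env(parsed, {"MODEL_ID": "already-set"})
print(env)
```

Because existing values are kept, a variable exported in the shell takes precedence over the same key in `.env`, which matters for the `ENSEMBLE_LLM` / `LLM_PATH` / `BASE` lookup order in `resolve_llm`.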
research/ensemble/src/ensemble/energy.py
ADDED
@@ -0,0 +1,45 @@
+"""Energy model: score candidate latents against world state."""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class EnergyModel(nn.Module):
+    def __init__(self, d_latent: int):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(2 * d_latent, 2 * d_latent),
+            nn.GELU(),
+            nn.Linear(2 * d_latent, d_latent),
+            nn.GELU(),
+            nn.Linear(d_latent, 1),
+        )
+        self.d_latent = d_latent
+
+    def energy(self, s, z):
+        return self.net(torch.cat([s, z], -1)).squeeze(-1)
+
+    def contrastive_loss(self, s, z_pos, z_negs=None, tau=0.5):
+        B = s.size(0)
+        s_rep = s.unsqueeze(1).expand(B, B, self.d_latent).reshape(
+            B * B, self.d_latent
+        )
+        z_rep = z_pos.unsqueeze(0).expand(B, B, self.d_latent).reshape(
+            B * B, self.d_latent
+        )
+        E = self.energy(s_rep, z_rep).view(B, B)
+        if z_negs is not None:
+            En = self.energy(
+                s.repeat_interleave(z_negs.size(1), 0),
+                z_negs.reshape(-1, self.d_latent),
+            ).view(B, -1)
+            E = torch.cat([E, En], dim=1)
+        labels = torch.arange(B, device=s.device)
+        return F.cross_entropy(-E / tau, labels)
+
+    @torch.no_grad()
+    def rank(self, s, candidates):
+        return self.energy(s.expand(candidates.size(0), -1), candidates)
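The in-batch contrastive loss in `EnergyModel.contrastive_loss` builds a B x B energy matrix and treats `-E / tau` as logits whose correct class for row `i` is column `i`. The following sketch shows that idea without PyTorch. It swaps the learned MLP for a squared-distance energy, which is purely an assumption made so the example runs on its own; the toy states are invented as well.

```python
import math


def energy(s, z):
    # Stand-in energy: squared Euclidean distance.
    # The commit uses a learned MLP over the concatenation [s, z] instead.
    return sum((a - b) ** 2 for a, b in zip(s, z))


def contrastive_loss(states, positives, tau=0.5):
    """Cross-entropy over -E/tau; row i should select column i (its positive)."""
    total = 0.0
    for i, s in enumerate(states):
        logits = [-energy(s, z) / tau for z in positives]
        m = max(logits)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_z - logits[i]
    return total / len(states)


def rank(state, candidates):
    """Energy of one state against each candidate; lower is a better match."""
    return [energy(state, z) for z in candidates]


states = [[0.0, 0.0], [3.0, 0.0], [0.0, 3.0]]
matched = contrastive_loss(states, states)
mismatched = contrastive_loss(states, [states[1], states[2], states[0]])
energies = rank(states[1], states)
best = min(range(len(energies)), key=lambda i: energies[i])
print(matched < mismatched, best)
```

Low energy on the diagonal gives a loss near zero, while moving the positives off the diagonal makes the loss large. Like the committed `rank`, the helper here returns raw energies, so the caller picks the candidate with the smallest value.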
research/ensemble/src/ensemble/eval/__init__.py
ADDED
@@ -0,0 +1 @@
+"""Evaluation harnesses for JEPA and world-model ensembles."""
research/ensemble/src/ensemble/eval/jepa_harness.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Ablation ladder + JEPA best-of-N benchmark for the ensemble."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from ensemble.eval.metrics import em_score, f1_score, paired_bootstrap
|
| 15 |
+
from ensemble.backends import TinyBackend
|
| 16 |
+
from ensemble.checkpoint import load_checkpoint
|
| 17 |
+
from ensemble.config import load_dotenv, resolve_llm_cli
|
| 18 |
+
from ensemble.jepa_ensemble import Ensemble
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@torch.no_grad()
|
| 22 |
+
def generate_plain(ens, q_ids, n_new):
|
| 23 |
+
ens.llm.set_adapter(ens.adapter_names[0])
|
| 24 |
+
t0 = time.time()
|
| 25 |
+
out = ens.llm.generate(q_ids.to(ens.llm.device), n_new=n_new, temperature=0.7)
|
| 26 |
+
return out[:, q_ids.size(1) :], time.time() - t0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@torch.no_grad()
|
| 30 |
+
def generate_config(
|
| 31 |
+
ens, q_ids, n_new, *, use_rag, use_router, use_jepa, n_drafts=1, tau=0.0
|
| 32 |
+
):
|
| 33 |
+
q_emb = ens.emb(q_ids.cpu())
|
| 34 |
+
|
| 35 |
+
if use_router:
|
| 36 |
+
a_idx = ens.router(q_emb).item()
|
| 37 |
+
ens.llm.set_adapter(ens.adapter_names[a_idx])
|
| 38 |
+
else:
|
| 39 |
+
ens.llm.set_adapter(ens.adapter_names[0])
|
| 40 |
+
|
| 41 |
+
ctx = q_ids.cpu()
|
| 42 |
+
if use_rag:
|
| 43 |
+
mems = ens.store.search(q_emb, k=1)
|
| 44 |
+
if mems:
|
| 45 |
+
ctx = torch.cat([mems[0], ctx], dim=1)
|
| 46 |
+
|
| 47 |
+
t0 = time.time()
|
| 48 |
+
if not use_jepa:
|
| 49 |
+
out = ens.llm.generate(
|
| 50 |
+
ctx.to(ens.llm.device), n_new=n_new, temperature=0.7
|
| 51 |
+
)
|
| 52 |
+
return out[:, ctx.size(1) :], time.time() - t0, None
|
| 53 |
+
|
| 54 |
+
z_exp = ens.jepa.predict_next_latent(ctx)
|
| 55 |
+
drafts, scores = [], []
|
| 56 |
+
for _ in range(n_drafts):
|
| 57 |
+
out = ens.llm.generate(
|
| 58 |
+
ctx.to(ens.llm.device), n_new=n_new, temperature=0.9
|
| 59 |
+
)
|
| 60 |
+
new = out[:, ctx.size(1) :].cpu()
|
| 61 |
+
drafts.append(new)
|
| 62 |
+
scores.append(
|
| 63 |
+
F.cosine_similarity(z_exp, ens.jepa.encode(new)).item()
|
| 64 |
+
)
|
| 65 |
+
best = max(range(n_drafts), key=lambda i: scores[i])
|
| 66 |
+
return drafts[best], time.time() - t0, (drafts, scores)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def selector_comparison(drafts_scores_gold, decode_fn, rng):
|
| 70 |
+
res = defaultdict(list)
|
| 71 |
+
for drafts, scores, gold in drafts_scores_gold:
|
| 72 |
+
texts = [decode_fn(d) for d in drafts]
|
| 73 |
+
ems = [em_score(t, gold) for t in texts]
|
| 74 |
+
res["first"].append(ems[0])
|
| 75 |
+
res["random"].append(ems[rng.randrange(len(ems))])
|
| 76 |
+
res["jepa"].append(ems[max(range(len(ems)), key=lambda i: scores[i])])
|
| 77 |
+
res["oracle"].append(max(ems))
|
| 78 |
+
return {k: sum(v) / len(v) for k, v in res.items()}, res
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def load_jsonl(path):
|
| 82 |
+
with open(path) as f:
|
| 83 |
+
return [json.loads(line) for line in f if line.strip()]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def make_toy_data(ens, n_qa=20, vocab=None):
|
| 87 |
+
vocab = vocab or ens.llm.vocab_size
|
| 88 |
+
qa, kb = [], []
|
| 89 |
+
for _ in range(n_qa):
|
| 90 |
+
key = torch.randint(0, vocab, (1, 6))
|
| 91 |
+
ans = torch.randint(0, vocab, (1, 4))
|
| 92 |
+
kb.append(torch.cat([key, ans], dim=1))
|
| 93 |
+
qa.append({"q_ids": key, "answer_ids": ans})
|
| 94 |
+
return qa, kb
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def run(args):
|
| 98 |
+
torch.manual_seed(args.seed)
|
| 99 |
+
rng = random.Random(args.seed)
|
| 100 |
+
|
| 101 |
+
if args.ckpt:
|
| 102 |
+
ens = load_checkpoint(args.ckpt)
|
| 103 |
+
print(f"loaded ensemble checkpoint: {args.ckpt}")
|
| 104 |
+
is_text = not isinstance(ens.llm, TinyBackend)
|
| 105 |
+
else:
|
| 106 |
+
load_dotenv()
|
| 107 |
+
args.llm = resolve_llm_cli(
|
| 108 |
+
args.llm, toy=args.toy, preset=getattr(args, "preset", None)
|
| 109 |
+
)
|
| 110 |
+
print(f"Resolved LLM: {args.llm}")
|
| 111 |
+
ens = Ensemble(llm=args.llm)
|
| 112 |
+
is_text = args.llm != "tiny"
|
| 113 |
+
|
| 114 |
+
if args.toy or not is_text:
|
| 115 |
+
qa, kb = make_toy_data(ens)
|
| 116 |
+
for mem in kb:
|
| 117 |
+
ens.memorize_ids(mem)
|
| 118 |
+
|
| 119 |
+
def to_ids(item):
|
| 120 |
+
return item["q_ids"]
|
| 121 |
+
|
| 122 |
+
def gold_text(item):
|
| 123 |
+
return " ".join(map(str, item["answer_ids"][0].tolist()))
|
| 124 |
+
|
| 125 |
+
def decode(ids):
|
| 126 |
+
return " ".join(map(str, ids[0].tolist()))
|
| 127 |
+
else:
|
| 128 |
+
qa = load_jsonl(args.qa)
|
| 129 |
+
if args.kb:
|
| 130 |
+
for row in load_jsonl(args.kb):
|
| 131 |
+
ens.memorize_text(row["text"])
|
| 132 |
+
|
| 133 |
+
def to_ids(item):
|
| 134 |
+
return ens.llm.encode_text(
|
| 135 |
+
f"Answer briefly.\nQ: {item['question']}\nA:"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def gold_text(item):
|
| 139 |
+
return item["answer"]
|
| 140 |
+
|
| 141 |
+
def decode(ids):
|
| 142 |
+
return ens.llm.decode(ids)
|
| 143 |
+
|
| 144 |
+
qa = qa[: args.limit]
|
| 145 |
+
print(
|
| 146 |
+
f"eval set: {len(qa)} questions | store: {len(ens.store.keys)} memories\n"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
configs = {
|
| 150 |
+
"C1_base": dict(use_rag=False, use_router=False, use_jepa=False),
|
| 151 |
+
"C2_rag": dict(use_rag=True, use_router=False, use_jepa=False),
|
| 152 |
+
"C3_rag_router": dict(use_rag=True, use_router=True, use_jepa=False),
|
| 153 |
+
"C4_full_jepa": dict(
|
| 154 |
+
use_rag=True,
|
| 155 |
+
use_router=True,
|
| 156 |
+
use_jepa=True,
|
| 157 |
+
n_drafts=args.n_drafts,
|
| 158 |
+
),
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
per_q = {}
|
| 162 |
+
summary = {}
|
| 163 |
+
jepa_material = []
|
| 164 |
+
|
| 165 |
+
for name, cfg in configs.items():
|
| 166 |
+
ems, f1s, lats = [], [], []
|
| 167 |
+
for item in qa:
|
| 168 |
+
ids = to_ids(item)
|
| 169 |
+
if name == "C1_base":
|
| 170 |
+
out, dt = generate_plain(ens, ids, args.n_new)
|
| 171 |
+
extra = None
|
| 172 |
+
else:
|
| 173 |
+
out, dt, extra = generate_config(ens, ids, args.n_new, **cfg)
|
| 174 |
+
pred, gold = decode(out), gold_text(item)
|
| 175 |
+
ems.append(em_score(pred, gold))
|
| 176 |
+
f1s.append(f1_score(pred, gold))
|
| 177 |
+
lats.append(dt)
|
| 178 |
+
if name == "C4_full_jepa" and extra is not None:
|
| 179 |
+
jepa_material.append((extra[0], extra[1], gold))
|
| 180 |
+
per_q[name] = ems
|
| 181 |
+
summary[name] = (
|
| 182 |
+
sum(ems) / len(ems),
|
| 183 |
+
sum(f1s) / len(f1s),
|
| 184 |
+
sum(lats) / len(lats),
|
| 185 |
+
)
+
+    print(f"{'config':<16}{'EM':>8}{'F1':>8}{'lat(s)':>9}")
+    for k, (em, f1, lat) in summary.items():
+        print(f"{k:<16}{em:>8.3f}{f1:>8.3f}{lat:>9.3f}")
+
+    print("\ncomponent contributions (paired bootstrap, P(B>A)):")
+    ladder = list(configs.keys())
+    for a, b in zip(ladder, ladder[1:]):
+        d = summary[b][0] - summary[a][0]
+        p = paired_bootstrap(per_q[a], per_q[b])
+        print(f" {b} - {a}: ΔEM={d:+.3f} P(better)={p:.2f}")
+
+    if jepa_material:
+        sel, sel_per_q = selector_comparison(jepa_material, decode, rng)
+        print(
+            f"\nbest-of-N selector comparison (same drafts, N={args.n_drafts}):"
+        )
+        for k in ("first", "random", "jepa", "oracle"):
+            print(f" {k:<8}EM={sel[k]:.3f}")
+        p = paired_bootstrap(sel_per_q["random"], sel_per_q["jepa"])
+        verdict = (
+            "JEPA critic WORKS"
+            if p > 0.95
+            else "inconclusive - critic ~ random"
+        )
+        print(f" P(jepa > random) = {p:.2f} {verdict}")
+        print(f" headroom to oracle: {sel['oracle'] - sel['jepa']:.3f}")
+
+    if args.continual:
+        print(
+            "\ncontinual test: accuracy on task-A questions "
+            "before vs after adding adapters B and C"
+        )
+        ems_before = per_q["C3_rag_router"]
+        ens.new_task_adapter("task_B")
+        ens.new_task_adapter("task_C")
+        ems_after = []
+        for item in qa:
+            out, _, _ = generate_config(
+                ens,
+                to_ids(item),
+                args.n_new,
+                use_rag=True,
+                use_router=True,
+                use_jepa=False,
+            )
+            ems_after.append(em_score(decode(out), gold_text(item)))
+        bt = sum(ems_after) / len(ems_after) - sum(ems_before) / len(
+            ems_before
+        )
+        print(f" backward transfer (≈0 is ideal): {bt:+.3f}")
+
+    return summary
+
+
+def parse_args():
+    p = argparse.ArgumentParser()
+    p.add_argument(
+        "--llm",
+        default=None,
+        help="HF id / path, 'tiny', or omit for LLM_PATH / ACTIVE_MODEL from .env",
+    )
+    p.add_argument("--preset", default=None, help="models.yaml preset override")
+    p.add_argument("--qa", default=None, help="jsonl with question/answer")
+    p.add_argument("--kb", default=None, help="jsonl with text -> vector store")
+    p.add_argument(
+        "--ckpt",
+        default=None,
+        help="saved ensemble directory (models/ensemble/... with manifest.json)",
+    )
+    p.add_argument("--toy", action="store_true", help="synthetic data smoke test")
+    p.add_argument("--limit", type=int, default=100)
+    p.add_argument("--n_new", type=int, default=24)
+    p.add_argument("--n_drafts", type=int, default=8)
+    p.add_argument("--continual", action="store_true")
+    p.add_argument("--seed", type=int, default=0)
+    return p.parse_args()
+
+
+if __name__ == "__main__":
+    run(parse_args())
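The ladder report in `run` compares each configuration only with the one directly before it, so every delta isolates a single added component. A minimal sketch of that pairing follows; the per-question scores and the helper name `ladder_deltas` are hypothetical and are not results or code from this repository.

```python
# Hypothetical per-question EM scores for three rungs of an ablation ladder.
per_q = {
    "C1_base": [0, 0, 1, 0],
    "C2_rag": [1, 0, 1, 0],
    "C3_rag_router": [1, 0, 1, 1],
}


def mean(xs):
    return sum(xs) / len(xs)


def ladder_deltas(per_q):
    """Return (later, earlier, delta_EM) for each consecutive pair of configs."""
    names = list(per_q)  # dicts keep insertion order, so this is the ladder order
    return [
        (b, a, mean(per_q[b]) - mean(per_q[a]))
        for a, b in zip(names, names[1:])
    ]


deltas = ladder_deltas(per_q)
for later, earlier, d in deltas:
    print(f"{later} - {earlier}: dEM={d:+.3f}")
```

Because `zip(names, names[1:])` only yields neighbours, a comparison such as C3 against C1 would need a separate call.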
|
research/ensemble/src/ensemble/eval/metrics.py
ADDED
|
@@ -0,0 +1,42 @@
+"""QA metrics and paired bootstrap significance."""
+
+from __future__ import annotations
+
+import random
+import re
+import string
+from collections import Counter
+
+
+def normalize(s: str) -> str:
+    s = s.lower()
+    s = "".join(c for c in s if c not in string.punctuation)
+    s = re.sub(r"\b(a|an|the)\b", " ", s)
+    return " ".join(s.split())
+
+
+def em_score(pred: str, gold: str) -> float:
+    return float(normalize(gold) in normalize(pred))
+
+
+def f1_score(pred: str, gold: str) -> float:
+    p, g = normalize(pred).split(), normalize(gold).split()
+    if not p or not g:
+        return float(p == g)
+    common = Counter(p) & Counter(g)
+    overlap = sum(common.values())
+    if overlap == 0:
+        return 0.0
+    prec, rec = overlap / len(p), overlap / len(g)
+    return 2 * prec * rec / (prec + rec)
+
+
+def paired_bootstrap(scores_a, scores_b, iters=2000, seed=0):
+    rng = random.Random(seed)
+    n, wins = len(scores_a), 0
+    for _ in range(iters):
+        idx = [rng.randrange(n) for _ in range(n)]
+        da = sum(scores_a[i] for i in idx) / n
+        db = sum(scores_b[i] for i in idx) / n
+        wins += db > da
+    return wins / iters
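The behaviour of these metrics is easiest to see on small inputs. The sketch below restates the three functions from `metrics.py` so that it runs on its own (the bootstrap compares sums instead of means, which gives the same result), then probes two edge cases. The example strings and score lists are made up for illustration.

```python
import random
import re
import string
from collections import Counter


def normalize(s):
    s = s.lower()
    s = "".join(c for c in s if c not in string.punctuation)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def em_score(pred, gold):
    return float(normalize(gold) in normalize(pred))


def f1_score(pred, gold):
    p, g = normalize(pred).split(), normalize(gold).split()
    if not p or not g:
        return float(p == g)
    overlap = sum((Counter(p) & Counter(g)).values())
    if overlap == 0:
        return 0.0
    prec, rec = overlap / len(p), overlap / len(g)
    return 2 * prec * rec / (prec + rec)


def paired_bootstrap(scores_a, scores_b, iters=2000, seed=0):
    rng = random.Random(seed)
    n, wins = len(scores_a), 0
    for _ in range(iters):
        idx = [rng.randrange(n) for _ in range(n)]
        # Same resampled indices for both systems: that is what makes it paired.
        wins += sum(scores_b[i] for i in idx) > sum(scores_a[i] for i in idx)
    return wins / iters


# Containment EM: the gold answer only has to appear inside the prediction.
em_contained = em_score("The answer is Paris.", "Paris")
f1_contained = f1_score("The answer is Paris.", "Paris")

# Containment is a raw substring test, so token boundaries are not respected:
# "1 2" occurs inside "11 23" even though no token matches.
em_substring = em_score("11 23", "1 2")

# Ties count as "not better", so identical score lists give 0.0.
p_tie = paired_bootstrap([0, 1, 0, 1], [0, 1, 0, 1])
p_win = paired_bootstrap([0, 0, 0, 0], [1, 1, 1, 1])

print(em_contained, f1_contained, em_substring, p_tie, p_win)
```

The substring case matters for the toy backend, where predictions and gold answers are space-joined token ids, so a partial id can match by accident.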
|
research/ensemble/src/ensemble/eval/world_harness.py
ADDED
|
@@ -0,0 +1,174 @@
+"""Energy-based draft selector benchmark for the world-model ensemble."""
+
+from __future__ import annotations
+
+import argparse
+import json
+import random
+import time
+from collections import defaultdict
+
+import torch
+
+from ensemble.eval.metrics import em_score, f1_score, paired_bootstrap
+from ensemble.world_ensemble import WorldEnsemble
+
+
+@torch.no_grad()
+def generate_drafts(ens, q_ids, n_new, n_drafts, use_rag=True):
+    q_emb = ens.emb(q_ids.cpu())
+    mems = ens.store.search(q_emb, k=1) if use_rag else []
+    segments = (mems + [q_ids.cpu()]) if mems else [q_ids.cpu()]
+    ctx = torch.cat(segments, dim=1)
+
+    s = ens.world_state(segments)
+    ens.world.rollout(s, horizon=3)
+
+    drafts, energies = [], []
+    t0 = time.time()
+    for _ in range(n_drafts):
+        out = ens.llm.generate(
+            ctx.to(ens.llm.device), n_new=n_new, temperature=0.9
+        )
+        new = out[:, ctx.size(1) :].cpu()
+        drafts.append(new)
+        z = ens.jepa.encode(new)
+        energies.append(ens.energy.rank(s, z).item())
+    return drafts, energies, time.time() - t0
+
+
+def selector_comparison(drafts_energy_gold, decode_fn, rng):
+    res = defaultdict(list)
+    for drafts, energies, gold in drafts_energy_gold:
+        texts = [decode_fn(d) for d in drafts]
+        ems = [em_score(t, gold) for t in texts]
+        res["first"].append(ems[0])
+        res["random"].append(ems[rng.randrange(len(ems))])
+        res["energy"].append(
+            ems[min(range(len(ems)), key=lambda i: energies[i])]
+        )
+        res["oracle"].append(max(ems))
+    return {k: sum(v) / len(v) for k, v in res.items()}, res
+
+
+def load_jsonl(path):
+    with open(path) as f:
+        return [json.loads(line) for line in f if line.strip()]
+
+
+def make_toy_data(ens, n_qa=20, vocab=None):
+    vocab = vocab or ens.llm.vocab_size
+    qa, kb = [], []
+    for _ in range(n_qa):
+        key = torch.randint(0, vocab, (1, 6))
+        ans = torch.randint(0, vocab, (1, 4))
+        kb.append(torch.cat([key, ans], dim=1))
+        qa.append({"q_ids": key, "answer_ids": ans})
+    return qa, kb
+
+
+def run(args):
+    from ensemble.config import load_dotenv, resolve_llm_cli
+
+    torch.manual_seed(args.seed)
+    rng = random.Random(args.seed)
+
+    load_dotenv()
+    args.llm = resolve_llm_cli(
+        args.llm, toy=args.toy, preset=getattr(args, "preset", None)
+    )
+    print(f"Resolved LLM: {args.llm}")
+    ens = WorldEnsemble(args.llm)
+    if args.ckpt:
+        state = torch.load(args.ckpt, map_location="cpu")
+        ens.load_state_dict(state, strict=False)
+        print(f"loaded world ensemble checkpoint: {args.ckpt}")
+
+    is_text = args.llm != "tiny"
+
+    if args.toy or not is_text:
+        qa, kb = make_toy_data(ens)
+        for mem in kb:
+            ens.memorize(mem)
+
+        def to_ids(item):
+            return item["q_ids"]
+
+        def gold_text(item):
+            return " ".join(map(str, item["answer_ids"][0].tolist()))
+
+        def decode(ids):
+            return " ".join(map(str, ids[0].tolist()))
+    else:
+        qa = load_jsonl(args.qa)
+        if args.kb:
+            for row in load_jsonl(args.kb):
+                ids = ens.llm.tokenizer(
+                    row["text"], return_tensors="pt"
+                ).input_ids
+                ens.memorize(ids)
+
+        def to_ids(item):
+            return ens.llm.tokenizer(
+                f"Answer briefly.\nQ: {item['question']}\nA:",
+                return_tensors="pt",
+            ).input_ids
+
+        def gold_text(item):
+            return item["answer"]
+
+        def decode(ids):
+            return ens.llm.tokenizer.decode(ids[0], skip_special_tokens=True)
+
+    qa = qa[: args.limit]
+    print(
+        f"eval set: {len(qa)} questions | store: {len(ens.store.keys)} memories\n"
+    )
+
+    material = []
+    lats = []
+    for item in qa:
+        drafts, energies, dt = generate_drafts(
+            ens, to_ids(item), args.n_new, args.n_drafts
+        )
+        material.append((drafts, energies, gold_text(item)))
+        lats.append(dt)
+
+    sel, sel_per_q = selector_comparison(material, decode, rng)
+    print(f"best-of-N selector comparison (same drafts, N={args.n_drafts}):")
+    for k in ("first", "random", "energy", "oracle"):
+        print(f" {k:<8}EM={sel[k]:.3f}")
+    p = paired_bootstrap(sel_per_q["random"], sel_per_q["energy"])
+    verdict = (
+        "Energy critic WORKS"
+        if p > 0.95
+        else "inconclusive - critic ~ random"
+    )
+    print(f" P(energy > random) = {p:.2f} {verdict}")
+    print(f" headroom to oracle: {sel['oracle'] - sel['energy']:.3f}")
+    print(f" mean latency: {sum(lats) / len(lats):.3f}s")
+
+    return sel
+
+
+def parse_args():
+    p = argparse.ArgumentParser()
+    p.add_argument(
+        "--llm",
+        default=None,
+        help="HF id / path, 'tiny', or omit for LLM_PATH / ACTIVE_MODEL from .env",
+    )
+    p.add_argument("--preset", default=None, help="models.yaml preset override")
+    p.add_argument("--qa", default=None, help="jsonl with question/answer")
+    p.add_argument("--kb", default=None, help="jsonl with text -> vector store")
+    p.add_argument("--ckpt", default=None, help="trained world ensemble .pt")
+    p.add_argument("--toy", action="store_true", help="synthetic data smoke test")
+    p.add_argument("--limit", type=int, default=100)
+    p.add_argument("--n_new", type=int, default=24)
+    p.add_argument("--n_drafts", type=int, default=8)
+    p.add_argument("--seed", type=int, default=0)
+    return p.parse_args()
+
+
+if __name__ == "__main__":
+    run(parse_args())
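The selector comparison in this file can be checked by hand without a model. The following sketch mirrors the four picks (first, random, lowest energy, oracle) on pre-scored drafts; the function name `compare_selectors`, the EM values, and the energies are all invented for illustration.

```python
import random


def compare_selectors(rows, rng):
    """rows: list of (ems, energies), one pair per question, same draft order."""
    picked = {"first": [], "random": [], "energy": [], "oracle": []}
    for ems, energies in rows:
        picked["first"].append(ems[0])
        picked["random"].append(ems[rng.randrange(len(ems))])
        # Lower energy is treated as better, hence min() rather than max().
        lowest = min(range(len(ems)), key=lambda i: energies[i])
        picked["energy"].append(ems[lowest])
        # Oracle: credit the question if any draft was correct.
        picked["oracle"].append(max(ems))
    return {k: sum(v) / len(v) for k, v in picked.items()}


# Hypothetical data: three questions, three drafts each.
rows = [
    ([0, 1, 0], [0.9, 0.1, 0.5]),  # energy picks draft 1, which is correct
    ([1, 0, 0], [0.2, 0.3, 0.1]),  # energy picks draft 2, which is wrong
    ([0, 0, 1], [0.4, 0.6, 0.2]),  # energy picks draft 2, which is correct
]
result = compare_selectors(rows, random.Random(0))
print(result)
```

The gap between `oracle` and `energy` is the headroom the harness prints: how much a perfect selector could still gain on the same drafts.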
|
research/ensemble/src/ensemble/eval_harness.py
ADDED
|
@@ -0,0 +1,309 @@
+"""
+eval_harness.py - Ablation ladder + JEPA best-of-N test for the ensemble
+========================================================================
+Companion to `llm_emb_jepa_ensemble_pluggable.py` (must be importable,
+i.e. in the same directory).
+
+What it runs
+------------
+1. ABLATION LADDER on a QA set:
+     C1  base LLM alone
+     C2  C1 + RAG (embedding retrieval)
+     C3  C2 + router/adapters
+     C4  C3 + JEPA best-of-N critic
+     (C5 = C4 with a bridge-trained checkpoint - just pass --ckpt)
+
+2. BEST-OF-N SELECTOR comparison (the decisive JEPA experiment):
+     first-sample | random-pick | JEPA-score pick | oracle pick
+   All on the SAME N drafts per question, so differences are pure selection.
+
+3. CONTINUAL FORGETTING test (optional, --continual):
+     accuracy on task A before vs after training adapters for B and C.
+
+4. PAIRED BOOTSTRAP significance between any two configs.
+
+Usage
+-----
+  # Smoke test, no GPU/deps beyond torch (toy backend, synthetic QA):
+  python eval_harness.py --llm tiny --toy
+
+  # Real model + your QA file (jsonl: {"question": ..., "answer": ..., "context": optional}):
+  python eval_harness.py --llm Qwen/Qwen2.5-0.5B-Instruct \
+      --qa ./domain_qa.jsonl --kb ./knowledge.jsonl --n_drafts 8
+
+  # With a bridge-trained ensemble checkpoint (C5):
+  python eval_harness.py --llm /models/llama-3.2-1b --qa ./qa.jsonl \
+      --kb ./kb.jsonl --ckpt ./ensemble_bridge.pt
+
+QA file: {"question": str, "answer": str, "domain": optional str}
+KB file: {"text": str} (each line becomes one memory in the vector store)
+"""
+
+import argparse
+import json
+import random
+import re
+import string
+import time
+from collections import Counter, defaultdict
+
+import torch
+
+from llm_emb_jepa_ensemble_pluggable import Ensemble  # same directory
+
+# ----------------------------------------------------------------------------
+# Metrics: normalized exact match + token F1 (SQuAD-style)
+# ----------------------------------------------------------------------------
+def normalize(s: str) -> str:
+    s = s.lower()
+    s = "".join(c for c in s if c not in string.punctuation)
+    s = re.sub(r"\b(a|an|the)\b", " ", s)
+    return " ".join(s.split())
+
+
+def em_score(pred: str, gold: str) -> float:
+    return float(normalize(gold) in normalize(pred))  # containment EM
+
+
+def f1_score(pred: str, gold: str) -> float:
+    p, g = normalize(pred).split(), normalize(gold).split()
+    if not p or not g:
+        return float(p == g)
+    common = Counter(p) & Counter(g)
+    overlap = sum(common.values())
+    if overlap == 0:
+        return 0.0
+    prec, rec = overlap / len(p), overlap / len(g)
+    return 2 * prec * rec / (prec + rec)
+
+
+# ----------------------------------------------------------------------------
+# Paired bootstrap: P(config B beats config A)
+# ----------------------------------------------------------------------------
+def paired_bootstrap(scores_a, scores_b, iters=2000, seed=0):
+    rng = random.Random(seed)
+    n, wins = len(scores_a), 0
+    for _ in range(iters):
+        idx = [rng.randrange(n) for _ in range(n)]
+        da = sum(scores_a[i] for i in idx) / n
+        db = sum(scores_b[i] for i in idx) / n
+        wins += db > da
+    return wins / iters
+
+
+# ----------------------------------------------------------------------------
+# Config runners - each returns per-question dicts
+# ----------------------------------------------------------------------------
+@torch.no_grad()
+def generate_plain(ens, q_ids, n_new):
+    """C1: base adapter, no retrieval, single sample."""
+    ens.llm.set_adapter(ens.adapter_names[0])
+    t0 = time.time()
+    out = ens.llm.generate(q_ids.to(ens.llm.device), n_new=n_new, temperature=0.7)
+    return out[:, q_ids.size(1):], time.time() - t0
+
+
+@torch.no_grad()
+def generate_config(ens, q_ids, n_new, *, use_rag, use_router, use_jepa,
+                    n_drafts=1, tau=0.0):
+    """Unified runner for C2/C3/C4."""
+    q_emb = ens.emb(q_ids.cpu())
+
+    if use_router:
+        a_idx = ens.router(q_emb).item()
+        ens.llm.set_adapter(ens.adapter_names[a_idx])
+    else:
+        ens.llm.set_adapter(ens.adapter_names[0])
+
+    ctx = q_ids.cpu()
+    if use_rag:
+        mems = ens.store.search(q_emb, k=1)
+        if mems:
+            ctx = torch.cat([mems[0], ctx], dim=1)
+
+    t0 = time.time()
+    if not use_jepa:
+        out = ens.llm.generate(ctx.to(ens.llm.device), n_new=n_new, temperature=0.7)
+        return out[:, ctx.size(1):], time.time() - t0, None
+
+    # JEPA best-of-N: sample drafts, keep the one closest to predicted latent
+    z_exp = ens.jepa.predict_next_latent(ctx)
+    drafts, scores = [], []
+    for _ in range(n_drafts):
+        out = ens.llm.generate(ctx.to(ens.llm.device), n_new=n_new, temperature=0.9)
+        new = out[:, ctx.size(1):].cpu()
+        drafts.append(new)
+        scores.append(torch.nn.functional.cosine_similarity(
+            z_exp, ens.jepa.encode(new)).item())
+    best = max(range(n_drafts), key=lambda i: scores[i])
+    return drafts[best], time.time() - t0, (drafts, scores)
+
+
+# ----------------------------------------------------------------------------
+# Best-of-N selector comparison on shared drafts
+# ----------------------------------------------------------------------------
+def selector_comparison(drafts_scores_gold, decode_fn, rng):
+    """drafts_scores_gold: list of (drafts, jepa_scores, gold_answer).
+    Returns EM for: first | random | jepa | oracle - all on the SAME drafts."""
+    res = defaultdict(list)
+    for drafts, scores, gold in drafts_scores_gold:
+        texts = [decode_fn(d) for d in drafts]
+        ems = [em_score(t, gold) for t in texts]
+        res["first"].append(ems[0])
+        res["random"].append(ems[rng.randrange(len(ems))])
+        res["jepa"].append(ems[max(range(len(ems)), key=lambda i: scores[i])])
+        res["oracle"].append(max(ems))  # upper bound of selection
+    return {k: sum(v) / len(v) for k, v in res.items()}, res
+
+
+# ----------------------------------------------------------------------------
+# Data loading
+# ----------------------------------------------------------------------------
+def load_jsonl(path):
+    with open(path) as f:
+        return [json.loads(l) for l in f if l.strip()]
+
+
+def make_toy_data(ens, n_qa=20, vocab=None):
+    """Synthetic QA for the tiny backend: 'answer' token sequence is planted
+    in the KB so RAG can genuinely help even with random weights."""
+    vocab = vocab or ens.llm.vocab_size
+    qa, kb = [], []
+    for i in range(n_qa):
+        key = torch.randint(0, vocab, (1, 6))
+        ans = torch.randint(0, vocab, (1, 4))
+        kb.append(torch.cat([key, ans], dim=1))  # memory = key+answer
+        qa.append({"q_ids": key, "answer_ids": ans})
+    return qa, kb
+
+
+# ----------------------------------------------------------------------------
+# Main evaluation
+# ----------------------------------------------------------------------------
+def run(args):
+    torch.manual_seed(args.seed)
+    rng = random.Random(args.seed)
+
+    ens = Ensemble(llm=args.llm)
+    if args.ckpt:
+        state = torch.load(args.ckpt, map_location="cpu")
+        ens.load_state_dict(state, strict=False)
+        print(f"loaded ensemble checkpoint: {args.ckpt}")
+
+    is_text = args.llm != "tiny"
+
+    # ---- load data and fill the vector store -------------------------------
+    if args.toy or not is_text:
+        qa, kb = make_toy_data(ens)
+        for mem in kb:
+            ens.memorize_ids(mem)
+        def to_ids(item): return item["q_ids"]
+        def gold_of(item): return item["answer_ids"]
+        def decode(ids): return " ".join(map(str, ids[0].tolist()))
+        def gold_text(item): return decode(item["answer_ids"])
+    else:
+        qa = load_jsonl(args.qa)
+        if args.kb:
+            for row in load_jsonl(args.kb):
+                ens.memorize_text(row["text"])
+        def to_ids(item): return ens.llm.encode_text(
+            f"Answer briefly.\nQ: {item['question']}\nA:")
+        def gold_text(item): return item["answer"]
+        def decode(ids): return ens.llm.decode(ids)
+
+    qa = qa[: args.limit]
+    print(f"eval set: {len(qa)} questions | store: {len(ens.store.keys)} memories\n")
+
+    # ---- ablation ladder ----------------------------------------------------
+    configs = {
+        "C1_base": dict(use_rag=False, use_router=False, use_jepa=False),
+        "C2_rag": dict(use_rag=True, use_router=False, use_jepa=False),
+        "C3_rag_router": dict(use_rag=True, use_router=True, use_jepa=False),
+        "C4_full_jepa": dict(use_rag=True, use_router=True, use_jepa=True,
+                             n_drafts=args.n_drafts),
+    }
+
+    per_q = {}  # config -> list of EM scores (for bootstrap)
+    summary = {}
+    jepa_material = []  # (drafts, scores, gold) for selector comparison
+
+    for name, cfg in configs.items():
+        ems, f1s, lats = [], [], []
+        for item in qa:
+            ids = to_ids(item)
+            if name == "C1_base":
+                out, dt = generate_plain(ens, ids, args.n_new)
+                extra = None
+            else:
+                out, dt, extra = generate_config(ens, ids, args.n_new, **cfg)
+            pred, gold = decode(out), gold_text(item)
+            ems.append(em_score(pred, gold))
+            f1s.append(f1_score(pred, gold))
+            lats.append(dt)
+            if name == "C4_full_jepa" and extra is not None:
+                jepa_material.append((extra[0], extra[1], gold))
+        per_q[name] = ems
+        summary[name] = (sum(ems) / len(ems), sum(f1s) / len(f1s),
+                         sum(lats) / len(lats))
+
+    print(f"{'config':<16}{'EM':>8}{'F1':>8}{'lat(s)':>9}")
+    for k, (em, f1, lat) in summary.items():
+        print(f"{k:<16}{em:>8.3f}{f1:>8.3f}{lat:>9.3f}")
+
+    # deltas + significance
+    print("\ncomponent contributions (paired bootstrap, P(B>A)):")
+    ladder = list(configs.keys())
+    for a, b in zip(ladder, ladder[1:]):
+        d = summary[b][0] - summary[a][0]
+        p = paired_bootstrap(per_q[a], per_q[b])
+        print(f" {b} - {a}: ΔEM={d:+.3f} P(better)={p:.2f}")
+
+    # ---- decisive JEPA selector experiment ----------------------------------
+    if jepa_material:
+        sel, sel_per_q = selector_comparison(jepa_material, decode, rng)
+        print("\nbest-of-N selector comparison (same drafts, N="
+              f"{args.n_drafts}):")
+        for k in ("first", "random", "jepa", "oracle"):
+            print(f" {k:<8}EM={sel[k]:.3f}")
+        p = paired_bootstrap(sel_per_q["random"], sel_per_q["jepa"])
+        print(f" P(jepa > random) = {p:.2f} "
+              f"{'JEPA critic WORKS' if p > 0.95 else 'inconclusive - critic ~ random'}")
+        gap = sel["oracle"] - sel["jepa"]
+        print(f" headroom to oracle: {gap:.3f}")
+
+    # ---- continual forgetting (optional) ------------------------------------
+    if args.continual:
+        print("\ncontinual test: accuracy on task-A questions "
+              "before vs after adding adapters B and C")
+        ems_before = per_q["C3_rag_router"]
+        ens.new_task_adapter("task_B")
+        ens.new_task_adapter("task_C")
+        ems_after = []
+        for item in qa:
+            out, _, _ = generate_config(ens, to_ids(item), args.n_new,
+                                        use_rag=True, use_router=True,
+                                        use_jepa=False)
+            ems_after.append(em_score(decode(out), gold_text(item)))
+        bt = sum(ems_after) / len(ems_after) - sum(ems_before) / len(ems_before)
+        print(f" backward transfer (≈0 is ideal): {bt:+.3f}")
+
+    return summary
+
+
+def parse_args():
+    p = argparse.ArgumentParser()
+    p.add_argument("--llm", default="tiny", help="'tiny' | HF id | local path")
+    p.add_argument("--qa", default=None, help="jsonl with question/answer")
+    p.add_argument("--kb", default=None, help="jsonl with text -> vector store")
+    p.add_argument("--ckpt", default=None, help="bridge-trained ensemble .pt (C5)")
+    p.add_argument("--toy", action="store_true", help="synthetic data smoke test")
+    p.add_argument("--limit", type=int, default=100)
+    p.add_argument("--n_new", type=int, default=24)
+    p.add_argument("--n_drafts", type=int, default=8)
+    p.add_argument("--continual", action="store_true")
+    p.add_argument("--seed", type=int, default=0)
+    return p.parse_args()
+
+
+if __name__ == "__main__":
+    run(parse_args())
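The module docstring describes the `--qa` and `--kb` inputs as JSON Lines files. A small sketch of what such files could look like and how they round-trip through a `load_jsonl`-style reader is given below; the rows, file names, and the helper `build_prompt` are hypothetical, and only the keys `question`, `answer`, and `text` plus the prompt template come from the harness.

```python
import json
import os
import tempfile


def load_jsonl(path):
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]


def build_prompt(item):
    # Same template the harness passes to the tokenizer for text models.
    return f"Answer briefly.\nQ: {item['question']}\nA:"


# Hypothetical rows, only to show the expected keys.
qa_rows = [{"question": "What is the capital of France?", "answer": "Paris"}]
kb_rows = [{"text": "Paris is the capital of France."}]

with tempfile.TemporaryDirectory() as tmp:
    qa_path = os.path.join(tmp, "qa.jsonl")
    kb_path = os.path.join(tmp, "kb.jsonl")
    for path, rows in ((qa_path, qa_rows), (kb_path, kb_rows)):
        with open(path, "w") as f:
            for row in rows:
                f.write(json.dumps(row) + "\n")  # one JSON object per line
            f.write("\n")  # trailing blank lines are skipped on load
    qa = load_jsonl(qa_path)
    kb = load_jsonl(kb_path)

prompt = build_prompt(qa[0])
print(prompt)
```

Note that when `--llm` names a text model and `--toy` is not set, `run` calls `load_jsonl(args.qa)` unconditionally, so `--qa` is effectively required in that mode.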
|
research/ensemble/src/ensemble/jepa.py
ADDED
|
@@ -0,0 +1,75 @@
+"""JEPA latent predictor with EMA target encoder."""
+
+from __future__ import annotations
+
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class _SegEncoder(nn.Module):
+    def __init__(self, vocab_size, d):
+        super().__init__()
+        self.tok = nn.Embedding(vocab_size, d)
+        self.enc = nn.GRU(d, d, batch_first=True)
+        self.out = nn.Linear(d, d)
+
+    def forward(self, ids):
+        h, _ = self.enc(self.tok(ids))
+        return self.out(h.mean(dim=1))
+
+
+class JEPA(nn.Module):
+    def __init__(self, vocab_size: int, d_latent: int = 64, ema_m: float = 0.996):
+        super().__init__()
+        self.ctx_enc = _SegEncoder(vocab_size, d_latent)
+        self.tgt_enc = copy.deepcopy(self.ctx_enc)
+        for p in self.tgt_enc.parameters():
+            p.requires_grad_(False)
+        self.predictor = nn.Sequential(
+            nn.Linear(d_latent, 2 * d_latent),
+            nn.GELU(),
+            nn.Linear(2 * d_latent, d_latent),
+        )
+        self.m = ema_m
+        self.d_latent = d_latent
+
+    @property
+    def enc(self):
+        """Alias used by world-model track."""
+        return self.ctx_enc
+
+    @property
+    def tgt(self):
+        return self.tgt_enc
+
+    @property
+    def pred(self):
+        return self.predictor
+
+    @torch.no_grad()
+    def ema_update(self):
+        for p_t, p_c in zip(self.tgt_enc.parameters(), self.ctx_enc.parameters()):
+            p_t.mul_(self.m).add_(p_c.detach(), alpha=1 - self.m)
+
+    def ema(self):
+        """Alias used by world-model track."""
+        self.ema_update()
+
+    def loss(self, seg_ctx, seg_tgt):
+        z_hat = self.predictor(self.ctx_enc(seg_ctx))
+        with torch.no_grad():
+            z_tgt = self.tgt_enc(seg_tgt)
+        pred = F.mse_loss(z_hat, z_tgt)
+        var_reg = F.relu(1.0 - z_hat.std(dim=0)).mean()
+        return pred + 0.5 * var_reg
+
+    @torch.no_grad()
+    def predict_next_latent(self, seg_ctx):
+        return self.predictor(self.ctx_enc(seg_ctx))
+
+    @torch.no_grad()
+    def encode(self, seg):
+        return self.tgt_enc(seg)
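The `ema_update` rule moves each target weight a fraction `1 - m` of the way toward the matching context weight. A plain-Python sketch of the same arithmetic on scalars, with no torch dependency, shows how slowly the target follows at the default `m = 0.996`. The helper `ema_step` and the weight values are illustrative assumptions, and the context weights are held fixed here, which is not the case during training.

```python
def ema_step(target, context, m=0.996):
    """One EMA update per scalar weight: target <- m * target + (1 - m) * context."""
    return [m * t + (1.0 - m) * c for t, c in zip(target, context)]


target = [0.0, 0.0]
context = [1.0, -2.0]  # held fixed here; in training the context encoder keeps moving

one_step = ema_step(target, context)

for _ in range(1000):
    target = ema_step(target, context)

# With a fixed context the remaining gap shrinks geometrically: |c - t0| * m**n.
gap = [abs(c - t) for t, c in zip(target, context)]
print(one_step, gap)
```

At this momentum roughly a thousand updates are needed before the target is within a few percent of a stationary context encoder, which is the intended slow-moving behaviour of a JEPA target.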
|
research/ensemble/src/ensemble/jepa_ensemble.py
ADDED
|
@@ -0,0 +1,232 @@
+"""JEPA ensemble: route -> retrieve -> generate -> JEPA-verify."""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ensemble.backends import HFBackend, make_backend
+from ensemble.bridge import Bridge
+from ensemble.jepa import JEPA
+from ensemble.memory import Embedder, Router, VectorStore
+
+torch.manual_seed(0)
+
+
+class Ensemble(nn.Module):
+    def __init__(
+        self,
+        llm: str = "tiny",
+        adapter_names=("general",),
+        d_emb: int = 64,
+        d_jepa: int = 64,
+        llm_backend: HFBackend | None = None,
+        **backend_kw,
+    ):
+        super().__init__()
+        self.llm = llm_backend if llm_backend is not None else make_backend(llm, **backend_kw)
+        V, H = self.llm.vocab_size, self.llm.hidden_size
+
+        self.emb = Embedder(V, d_emb)
+        self.jepa = JEPA(V, d_jepa)
+        self.bridge = Bridge(H, d_jepa)
+        self.store = VectorStore()
+
+        self.adapter_names = list(adapter_names)
+        for n in self.adapter_names:
+            self.llm.add_adapter(n)
+        self.llm.set_adapter(self.adapter_names[0])
+        self.router = Router(d_emb, len(self.adapter_names))
+
+    @torch.no_grad()
+    def answer_ids(
+        self,
+        query_ids,
+        n_new=32,
+        tau_consistency=0.0,
+        max_retries=2,
+        temperature: float = 0.7,
+    ):
+        q_emb = self.emb(query_ids.cpu())
+        a_idx = self.router(q_emb).item()
+        self.llm.set_adapter(self.adapter_names[a_idx])
+
+        mems = self.store.search(q_emb, k=1)
+        ctx = (
+            torch.cat([mems[0], query_ids.cpu()], dim=1)
+            if mems
+            else query_ids.cpu()
+        )
+
+        z_expected = self.jepa.predict_next_latent(ctx)
+
+        best = None
+        for attempt in range(max_retries + 1):
+            temp = temperature if attempt == 0 else max(temperature, 0.8 + 0.3 * attempt)
+            draft = self.llm.generate(
+                ctx.to(self.llm.device),
+                n_new=n_new,
+                temperature=temp,
+            )
+            new_part = draft[:, ctx.size(1) :].cpu()
+            score = F.cosine_similarity(
+                z_expected, self.jepa.encode(new_part)
+            ).item()
+            if best is None or score > best[1]:
+                best = (draft, score, attempt)
+            if score >= tau_consistency:
+                break
+        draft, score, attempt = best
+        return draft, score, self.adapter_names[a_idx], attempt
+
+    def answer_text(self, prompt: str, **kw):
+        ids = self.llm.encode_text(prompt)
+        out, score, adapter, retries = self.answer_ids(ids, **kw)
+        return self.llm.decode(out), score, adapter, retries
+
+    def generate_text(
+        self,
+        prompt: str,
+        *,
+        max_new_tokens: int = 512,
+        temperature: float = 0.0,
+    ) -> str:
+        """Greedy or sampled generation through the full ensemble stack."""
+        ids = self.llm.encode_text(prompt)
+        out, _, _, _ = self.answer_ids(
+            ids,
+            n_new=max_new_tokens,
+            tau_consistency=-1.0,
+            max_retries=0 if temperature <= 0 else 1,
+            temperature=temperature,
+        )
+        return self.llm.decode(out)
+
+    def memorize_ids(self, ids):
+        self.store.add(self.emb(ids.cpu()), ids.cpu())
+
+    def memorize_text(self, text: str):
+        self.memorize_ids(self.llm.encode_text(text))
+
+    def new_task_adapter(self, name: str):
+        self.adapter_names.append(name)
+        self.llm.add_adapter(name)
+        old = self.router
+        self.router = Router(self.emb.d_emb, len(self.adapter_names))
+        with torch.no_grad():
+            self.router.fc.weight[: old.fc.out_features] = old.fc.weight
+            self.router.fc.bias[: old.fc.out_features] = old.fc.bias
+
+    def train_step(self, seg_a, seg_b, opt, w_bridge=0.1):
+        logits, hidden = self.llm(seg_a.to(self.llm.device))
+        lm_loss = F.cross_entropy(
+            logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
+            seg_a[:, 1:].reshape(-1).to(logits.device),
+        )
+
+        jepa_loss = self.jepa.loss(seg_a.cpu(), seg_b.cpu())
+
+        z_llm = self.bridge(
+            hidden.cpu() if hidden.device.type != "cpu" else hidden
+        )
+        z_jepa = self.jepa.ctx_enc(seg_a.cpu()).detach()
+        bridge_loss = self.bridge.info_nce(z_llm, z_jepa.to(z_llm.device))
+
+        loss = lm_loss.cpu() + jepa_loss + w_bridge * bridge_loss
+        opt.zero_grad()
+        loss.backward()
+        opt.step()
+        self.jepa.ema_update()
+        return {
+            "lm": lm_loss.item(),
+            "jepa": jepa_loss.item(),
+            "bridge": bridge_loss.item(),
+        }
+
+    def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
+        return torch.optim.AdamW(
+            [
+                {"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
+                {
+                    "params": list(self.jepa.ctx_enc.parameters())
+                    + list(self.jepa.predictor.parameters()),
+                    "lr": lr_aux,
+                },
+                {
+                    "params": list(self.bridge.parameters())
+                    + list(self.emb.parameters())
+                    + list(self.router.parameters()),
+                    "lr": lr_aux,
+                },
+            ]
+        )
+
+
+def segment_pairs_from_texts(backend: HFBackend, texts, seg_len=64):
+    a_list, b_list = [], []
+    for t in texts:
+        ids = backend.tokenizer(t, return_tensors="pt").input_ids[0]
+        # "+ 1" keeps the final full (a, b) pair when it ends exactly at len(ids).
+        for i in range(0, len(ids) - 2 * seg_len + 1, seg_len):
+            a_list.append(ids[i : i + seg_len])
+            b_list.append(ids[i + seg_len : i + 2 * seg_len])
+    if not a_list:
+        raise ValueError("texts too short for the chosen seg_len")
+    return torch.stack(a_list), torch.stack(b_list)
+
+
+def demo_tiny(steps=50):
+    ens = Ensemble(llm="tiny")
+    opt = ens.make_optimizer()
+    for s in range(steps):
+        seg_a = torch.randint(0, ens.llm.vocab_size, (8, 32))
+        seg_b = torch.randint(0, ens.llm.vocab_size, (8, 32))
+        logs = ens.train_step(seg_a, seg_b, opt)
+        if s % 10 == 0:
+            print(
+                f"step {s:3d} | "
+                + " | ".join(f"{k} {v:.3f}" for k, v in logs.items())
+            )
+
+    for _ in range(5):
+        ens.memorize_ids(torch.randint(0, ens.llm.vocab_size, (1, 32)))
+    ens.new_task_adapter("medical")
+
+    q = torch.randint(0, ens.llm.vocab_size, (1, 8))
+    out, score, adapter, retries = ens.answer_ids(q, tau_consistency=-1.0)
+    print(f"\nadapter={adapter} jepa_consistency={score:.3f} retries={retries}")
+
+
+def demo_hf(model_path="Qwen/Qwen2.5-0.5B-Instruct"):
+    ens = Ensemble(llm=model_path, load_in_4bit=False)
+    opt = ens.make_optimizer()
+
+    texts = ["Replace this with your real corpus. " * 50]
+    seg_a, seg_b = segment_pairs_from_texts(ens.llm, texts, seg_len=32)
+    for s in range(10):
+        logs = ens.train_step(seg_a[:4], seg_b[:4], opt)
+        print(f"step {s} | " + " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
+
+    ens.memorize_text("The project codename is AURORA and it ships in Q3.")
+    ens.new_task_adapter("project_aurora")
+
+    text, score, adapter, retries = ens.answer_text(
+        "What is the project codename?", n_new=24, tau_consistency=-1.0
+    )
+    print(f"\n[{adapter} | jepa={score:.3f} | retries={retries}]\n{text}")
+
+
+if __name__ == "__main__":
+    import sys
+
+    from ensemble.config import load_dotenv, resolve_llm
+
+    load_dotenv()
+    arg = sys.argv[1] if len(sys.argv) > 1 else None
+    if arg is None or arg == "auto":
+        arg, preset = resolve_llm()
+        print(f"Resolved LLM: {arg} (preset {preset})")
+    if arg == "tiny":
+        demo_tiny()
+    else:
+        demo_hf(arg)
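The verify-and-retry control flow in `answer_ids` can be read in isolation from the model code. The following is a minimal, torch-free sketch of that loop under stated assumptions: `best_of_retries`, `cosine`, and the `generate` / `encode` callables are hypothetical stand-ins for `self.llm.generate` and `self.jepa.encode`, and latents are plain lists. It mirrors the logic of the file above: score each draft against the expected latent, keep the best one, raise the temperature on retries, and stop early once the score reaches the threshold.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def best_of_retries(generate, encode, z_expected, tau=0.0, max_retries=2, temperature=0.7):
    """Return (draft, score, attempt) for the best-scoring draft seen.

    generate(temp) -> draft and encode(draft) -> latent are caller-supplied.
    """
    best = None
    for attempt in range(max_retries + 1):
        # First try uses the requested temperature; retries sample hotter.
        temp = temperature if attempt == 0 else max(temperature, 0.8 + 0.3 * attempt)
        draft = generate(temp)
        score = cosine(z_expected, encode(draft))
        if best is None or score > best[1]:
            best = (draft, score, attempt)
        if score >= tau:
            break  # consistent enough, stop retrying
    return best


# Toy run: three canned drafts whose latents get progressively closer.
latents = {"a": [-1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 0.0]}
drafts = iter(["a", "b", "c"])
result = best_of_retries(lambda temp: next(drafts), latents.get, [1.0, 0.0], tau=0.5)
print(result)  # ('c', 1.0, 2)
```

Passing a threshold of `-1.0`, as the demos do, accepts the first draft because cosine similarity never falls below -1.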
research/ensemble/src/ensemble/llm_emb_jepa_ensemble_pluggable.py
ADDED
|
@@ -0,0 +1,507 @@
+"""
+LLM + Embedding + JEPA Ensemble - pluggable base-model edition
+==============================================================
+Now the LLM is a swappable BACKEND. Three ways to load it:
+
+    # 1. HuggingFace Hub id
+    ens = Ensemble(llm="Qwen/Qwen2.5-0.5B-Instruct")
+
+    # 2. Local path (e.g. downloaded Llama / converted checkpoint)
+    ens = Ensemble(llm="/models/llama-3.2-1b")
+
+    # 3. Toy fallback (no transformers needed, runs on CPU in seconds)
+    ens = Ensemble(llm="tiny")
+
+Requirements for real models:
+    pip install torch transformers peft accelerate
+    (optional 4-bit: pip install bitsandbytes -> load_in_4bit=True)
+
+Everything else (Embedder, JEPA, Bridge, VectorStore, Router, the
+JEPA-critic inference loop, continual-learning hooks) only touches
+token ids / hidden states / latents, so it works with ANY backend.
+"""
+
+from __future__ import annotations
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+torch.manual_seed(0)
+
+# ----------------------------------------------------------------------------
+# 0. Backend interface - everything the ensemble needs from "an LLM"
+# ----------------------------------------------------------------------------
+class LLMBackend(nn.Module):
+    """Contract:
+        vocab_size : int
+        hidden_size: int
+        device : torch.device
+        forward(ids) -> (logits [B,T,V], hidden [B,T,H])
+        generate(ids, n_new) -> ids [B, T+n_new]
+        add_adapter(name) / set_adapter(name)
+        trainable_parameters() -> iterable of params to optimize
+        encode_text(str) / decode(ids) (real backends only)
+    """
+    vocab_size: int
+    hidden_size: int
+
+
+# ----------------------------------------------------------------------------
+# 0a. HuggingFace backend (local path OR hub id) with PEFT LoRA adapters
+# ----------------------------------------------------------------------------
+class HFBackend(LLMBackend):
+    def __init__(self, model_path: str, *, load_in_4bit: bool = False,
+                 lora_r: int = 16, lora_alpha: int = 32,
+                 target_modules=("q_proj", "v_proj"),
+                 device: str | None = None, torch_dtype=None):
+        super().__init__()
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from peft import LoraConfig, get_peft_model
+
+        self.device_ = torch.device(
+            device or ("cuda" if torch.cuda.is_available() else "cpu"))
+
+        kwargs = {}
+        if load_in_4bit:
+            from transformers import BitsAndBytesConfig
+            kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16,
+                bnb_4bit_quant_type="nf4")
+        if torch_dtype is not None:
+            kwargs["torch_dtype"] = torch_dtype
+
+        # `model_path` may be "Qwen/Qwen2.5-0.5B-Instruct", "meta-llama/...",
+        # or a local directory like "/models/llama-3.2-1b".
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+        base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
+        if not load_in_4bit:
+            base.to(self.device_)
+
+        # Freeze the base; all learning happens in LoRA adapters.
+        for p in base.parameters():
+            p.requires_grad_(False)
+
+        self._lora_cfg = LoraConfig(
+            r=lora_r, lora_alpha=lora_alpha, lora_dropout=0.05,
+            target_modules=list(target_modules), task_type="CAUSAL_LM")
+        self.model = get_peft_model(base, self._lora_cfg, adapter_name="general")
+        self._adapters = {"general"}
+
+        self.vocab_size = self.model.config.vocab_size
+        self.hidden_size = self.model.config.hidden_size
+
+    # ---- adapters -----------------------------------------------------------
+    def add_adapter(self, name: str):
+        if name not in self._adapters:
+            self.model.add_adapter(name, self._lora_cfg)
+            self._adapters.add(name)
+
+    def set_adapter(self, name: str):
+        self.model.set_adapter(name)
+
+    def trainable_parameters(self):
+        return (p for p in self.model.parameters() if p.requires_grad)
+
+    # ---- core ops -----------------------------------------------------------
+    def forward(self, ids):
+        out = self.model(input_ids=ids.to(self.device_),
+                         output_hidden_states=True)
+        return out.logits, out.hidden_states[-1]  # last layer hidden
+
+    @torch.no_grad()
+    def generate(self, ids, n_new=64, temperature=0.8):
+        out = self.model.generate(
+            input_ids=ids.to(self.device_),
+            max_new_tokens=n_new, do_sample=True, temperature=temperature,
+            pad_token_id=self.tokenizer.pad_token_id)
+        return out
+
+    # ---- text helpers -------------------------------------------------------
+    def encode_text(self, text: str):
+        return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device_)
+
+    def decode(self, ids):
+        return self.tokenizer.decode(ids[0], skip_special_tokens=True)
+
+    @property
+    def device(self):
+        return self.device_
+
+
+# ----------------------------------------------------------------------------
+# 0b. Tiny fallback backend (no transformers; same toy model as before)
+# ----------------------------------------------------------------------------
+class TinyBackend(LLMBackend):
+    VOCAB, D_MODEL, N_LAYERS, N_HEADS, SEQ_LEN, LORA_R = 1000, 128, 2, 4, 32, 8
+
+    class _LoRALinear(nn.Module):
+        def __init__(self, d_in, d_out, r):
+            super().__init__()
+            self.base = nn.Linear(d_in, d_out)
+            self.base.weight.requires_grad_(False)
+            self.base.bias.requires_grad_(False)
+            self.adapters, self.active, self.r = nn.ModuleDict(), None, r
+
+        def add_adapter(self, name):
+            A = nn.Linear(self.base.in_features, self.r, bias=False)
+            B = nn.Linear(self.r, self.base.out_features, bias=False)
+            nn.init.zeros_(B.weight)
+            self.adapters[name] = nn.Sequential(A, B)
+
+        def forward(self, x):
+            y = self.base(x)
+            if self.active and self.active in self.adapters:
+                y = y + self.adapters[self.active](x)
+            return y
+
+    class _Block(nn.Module):
+        def __init__(self, D, H, R):
+            super().__init__()
+            L = TinyBackend._LoRALinear
+            self.ln1 = nn.LayerNorm(D)
+            self.attn = nn.MultiheadAttention(D, H, batch_first=True)
+            self.ln2 = nn.LayerNorm(D)
+            self.up, self.down = L(D, 4 * D, R), L(4 * D, D, R)
+
+        def forward(self, x, mask):
+            h = self.ln1(x)
+            a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
+            x = x + a
+            return x + self.down(F.gelu(self.up(self.ln2(x))))
+
+    def __init__(self):
+        super().__init__()
+        D, V = self.D_MODEL, self.VOCAB
+        self.tok = nn.Embedding(V, D)
+        self.pos = nn.Embedding(self.SEQ_LEN * 4, D)
+        self.blocks = nn.ModuleList(
+            [self._Block(D, self.N_HEADS, self.LORA_R) for _ in range(self.N_LAYERS)])
+        self.ln_f, self.head = nn.LayerNorm(D), nn.Linear(D, V, bias=False)
+        self.vocab_size, self.hidden_size = V, D
+        self.add_adapter("general")
+        self.set_adapter("general")
+
+    def add_adapter(self, name):
+        for b in self.blocks:
+            b.up.add_adapter(name); b.down.add_adapter(name)
+
+    def set_adapter(self, name):
+        for b in self.blocks:
+            b.up.active = name; b.down.active = name
+
+    def trainable_parameters(self):
+        return (p for p in self.parameters() if p.requires_grad)
+
+    def forward(self, ids):
+        B, T = ids.shape
+        x = self.tok(ids) + self.pos(torch.arange(T, device=ids.device))
+        mask = torch.triu(torch.full((T, T), float("-inf"), device=ids.device), 1)
+        for b in self.blocks:
+            x = b(x, mask)
+        h = self.ln_f(x)
+        return self.head(h), h
+
+    @torch.no_grad()
+    def generate(self, ids, n_new=16, temperature=1.0):
+        for _ in range(n_new):
+            logits, _ = self(ids[:, -self.SEQ_LEN:])
+            nxt = torch.multinomial(F.softmax(logits[:, -1] / temperature, -1), 1)
+            ids = torch.cat([ids, nxt], dim=1)
+        return ids
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+
+def make_backend(llm: str, **kw) -> LLMBackend:
+    """'tiny' -> toy model; anything else -> HF hub id or local path."""
+    return TinyBackend() if llm == "tiny" else HFBackend(llm, **kw)
+
+
+# ----------------------------------------------------------------------------
+# 1. Embedder - vocab-agnostic (sized from the backend's tokenizer)
+#    Swap for a real model: pass embed_fn=lambda txt: sbert.encode(...)
+# ----------------------------------------------------------------------------
+class Embedder(nn.Module):
+    def __init__(self, vocab_size: int, d_emb: int = 64):
+        super().__init__()
+        self.tok = nn.Embedding(vocab_size, d_emb)
+        self.enc = nn.GRU(d_emb, d_emb, batch_first=True, bidirectional=True)
+        self.proj = nn.Linear(2 * d_emb, d_emb)
+        self.d_emb = d_emb
+
+    def forward(self, ids):
+        h, _ = self.enc(self.tok(ids))
+        return F.normalize(self.proj(h.mean(dim=1)), dim=-1)
+
+
+# ----------------------------------------------------------------------------
+# 2. JEPA - vocab-agnostic latent predictor with EMA target encoder
+# ----------------------------------------------------------------------------
+class _JEPAEncoder(nn.Module):
+    def __init__(self, vocab_size, d):
+        super().__init__()
+        self.tok = nn.Embedding(vocab_size, d)
+        self.enc = nn.GRU(d, d, batch_first=True)
+        self.out = nn.Linear(d, d)
+
+    def forward(self, ids):
+        h, _ = self.enc(self.tok(ids))
+        return self.out(h.mean(dim=1))
+
+
+class JEPA(nn.Module):
+    def __init__(self, vocab_size: int, d_jepa: int = 64, ema_m: float = 0.996):
+        super().__init__()
+        self.ctx_enc = _JEPAEncoder(vocab_size, d_jepa)
+        self.tgt_enc = copy.deepcopy(self.ctx_enc)
+        for p in self.tgt_enc.parameters():
+            p.requires_grad_(False)
+        self.predictor = nn.Sequential(
+            nn.Linear(d_jepa, 2 * d_jepa), nn.GELU(), nn.Linear(2 * d_jepa, d_jepa))
+        self.m, self.d_jepa = ema_m, d_jepa
+
+    @torch.no_grad()
+    def ema_update(self):
+        for p_t, p_c in zip(self.tgt_enc.parameters(), self.ctx_enc.parameters()):
+            p_t.mul_(self.m).add_(p_c.detach(), alpha=1 - self.m)
+
+    def loss(self, seg_ctx, seg_tgt):
+        z_hat = self.predictor(self.ctx_enc(seg_ctx))
+        with torch.no_grad():
+            z_tgt = self.tgt_enc(seg_tgt)
+        pred = F.mse_loss(z_hat, z_tgt)
+        var_reg = F.relu(1.0 - z_hat.std(dim=0)).mean()  # anti-collapse
+        return pred + 0.5 * var_reg
+
+    @torch.no_grad()
+    def predict_next_latent(self, seg_ctx):
+        return self.predictor(self.ctx_enc(seg_ctx))
+
+    @torch.no_grad()
+    def encode(self, seg):
+        return self.tgt_enc(seg)
+
+
+# ----------------------------------------------------------------------------
+# 3. Bridge - sized from backend.hidden_size at construction
+# ----------------------------------------------------------------------------
+class Bridge(nn.Module):
+    def __init__(self, d_llm_hidden: int, d_jepa: int):
+        super().__init__()
+        self.proj = nn.Sequential(
+            nn.Linear(d_llm_hidden, d_jepa), nn.GELU(), nn.Linear(d_jepa, d_jepa))
+
+    def forward(self, llm_hidden):  # [B,T,H] -> [B,d_jepa]
+        return self.proj(llm_hidden.float().mean(dim=1))
+
+    def info_nce(self, z1, z2, tau=0.07):
+        z1, z2 = F.normalize(z1, dim=-1), F.normalize(z2, dim=-1)
+        logits = z1 @ z2.t() / tau
+        labels = torch.arange(z1.size(0), device=z1.device)
+        return 0.5 * (F.cross_entropy(logits, labels) +
+                      F.cross_entropy(logits.t(), labels))
+
+
+# ----------------------------------------------------------------------------
+# 4. Memory + Router
+# ----------------------------------------------------------------------------
+class VectorStore:
+    def __init__(self):
+        self.keys, self.values = [], []
+
+    def add(self, emb, payload):
+        self.keys.append(emb.squeeze(0).detach().cpu())
+        self.values.append(payload)
+
+    def search(self, q, k=2):
+        if not self.keys:
+            return []
+        K = torch.stack(self.keys)
+        sims = (q.detach().cpu() @ K.t()).squeeze(0)
+        top = sims.topk(min(k, len(self.keys))).indices
+        return [self.values[i] for i in top]
+
+
+class Router(nn.Module):
+    def __init__(self, d_emb, n_adapters):
+        super().__init__()
+        self.fc = nn.Linear(d_emb, n_adapters)
+
+    def forward(self, emb):
+        return self.fc(emb).argmax(dim=-1)
+
+
+# ----------------------------------------------------------------------------
+# 5. Ensemble - backend-agnostic
+# ----------------------------------------------------------------------------
+class Ensemble(nn.Module):
+    def __init__(self, llm: str = "tiny", adapter_names=("general",),
+                 d_emb: int = 64, d_jepa: int = 64, **backend_kw):
+        super().__init__()
+        self.llm = make_backend(llm, **backend_kw)
+        V, H = self.llm.vocab_size, self.llm.hidden_size
+
+        self.emb = Embedder(V, d_emb)
+        self.jepa = JEPA(V, d_jepa)
+        self.bridge = Bridge(H, d_jepa)
+        self.store = VectorStore()
+
+        self.adapter_names = list(adapter_names)
+        for n in self.adapter_names:
+            self.llm.add_adapter(n)
+        self.llm.set_adapter(self.adapter_names[0])
+        self.router = Router(d_emb, len(self.adapter_names))
+
+    # -------- inference: route -> retrieve -> generate -> JEPA-verify -------
+    @torch.no_grad()
+    def answer_ids(self, query_ids, n_new=32, tau_consistency=0.0, max_retries=2):
+        q_emb = self.emb(query_ids.cpu())
+        a_idx = self.router(q_emb).item()
+        self.llm.set_adapter(self.adapter_names[a_idx])
+
+        mems = self.store.search(q_emb, k=1)
+        ctx = (torch.cat([mems[0], query_ids.cpu()], dim=1)
+               if mems else query_ids.cpu())
+
+        z_expected = self.jepa.predict_next_latent(ctx)
+
+        best = None
+        for attempt in range(max_retries + 1):
+            draft = self.llm.generate(ctx.to(self.llm.device), n_new=n_new,
+                                      temperature=0.8 + 0.3 * attempt)
+            new_part = draft[:, ctx.size(1):].cpu()
+            score = F.cosine_similarity(
+                z_expected, self.jepa.encode(new_part)).item()
+            if best is None or score > best[1]:
+                best = (draft, score, attempt)
+            if score >= tau_consistency:
+                break
+        draft, score, attempt = best
+        return draft, score, self.adapter_names[a_idx], attempt
+
+    def answer_text(self, prompt: str, **kw):
+        """Convenience wrapper for HF backends (uses the real tokenizer)."""
+        ids = self.llm.encode_text(prompt)
+        out, score, adapter, retries = self.answer_ids(ids, **kw)
+        return self.llm.decode(out), score, adapter, retries
+
+    # -------- continual learning hooks ---------------------------------------
+    def memorize_ids(self, ids):
+        self.store.add(self.emb(ids.cpu()), ids.cpu())
+
+    def memorize_text(self, text: str):
+        self.memorize_ids(self.llm.encode_text(text))
+
+    def new_task_adapter(self, name: str):
+        self.adapter_names.append(name)
+        self.llm.add_adapter(name)
+        old = self.router
+        self.router = Router(self.emb.d_emb, len(self.adapter_names))
+        with torch.no_grad():
+            self.router.fc.weight[: old.fc.out_features] = old.fc.weight
+            self.router.fc.bias[: old.fc.out_features] = old.fc.bias
+
+    # -------- one joint training step (LM + JEPA + Bridge) -------------------
+    def train_step(self, seg_a, seg_b, opt, w_bridge=0.1):
+        """seg_a, seg_b: consecutive token-id segments [B, T] (same tokenizer
+        as the backend!). For HF backends build them with backend.tokenizer."""
+        logits, hidden = self.llm(seg_a.to(self.llm.device))
+        lm_loss = F.cross_entropy(
+            logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
+            seg_a[:, 1:].reshape(-1).to(logits.device))
+
+        jepa_loss = self.jepa.loss(seg_a.cpu(), seg_b.cpu())
+
+        z_llm = self.bridge(hidden.cpu() if hidden.device.type != "cpu" else hidden)
+        z_jepa = self.jepa.ctx_enc(seg_a.cpu()).detach()
+        bridge_loss = self.bridge.info_nce(z_llm, z_jepa.to(z_llm.device))
+
+        loss = lm_loss.cpu() + jepa_loss + w_bridge * bridge_loss
+        opt.zero_grad(); loss.backward(); opt.step()
+        self.jepa.ema_update()
+        return {"lm": lm_loss.item(), "jepa": jepa_loss.item(),
+                "bridge": bridge_loss.item()}
+
+    def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
+        return torch.optim.AdamW([
+            {"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
+            {"params": list(self.jepa.ctx_enc.parameters())
+                       + list(self.jepa.predictor.parameters()), "lr": lr_aux},
+            {"params": list(self.bridge.parameters())
+                       + list(self.emb.parameters())
+                       + list(self.router.parameters()), "lr": lr_aux},
+        ])
+
+
|
| 441 |
+
# ----------------------------------------------------------------------------
|
| 442 |
+
# 6. Helpers: turn raw text into (seg_a, seg_b) pairs with the HF tokenizer
|
| 443 |
+
# ----------------------------------------------------------------------------
|
| 444 |
+
def segment_pairs_from_texts(backend: HFBackend, texts, seg_len=64):
|
| 445 |
+
"""Yields consecutive-segment id pairs for the JEPA + LM losses."""
|
| 446 |
+
a_list, b_list = [], []
|
| 447 |
+
for t in texts:
|
| 448 |
+
ids = backend.tokenizer(t, return_tensors="pt").input_ids[0]
|
| 449 |
+
for i in range(0, len(ids) - 2 * seg_len + 1, seg_len):  # +1 keeps the final full pair
|
| 450 |
+
a_list.append(ids[i:i + seg_len])
|
| 451 |
+
b_list.append(ids[i + seg_len:i + 2 * seg_len])
|
| 452 |
+
if not a_list:
|
| 453 |
+
raise ValueError("texts too short for the chosen seg_len")
|
| 454 |
+
return torch.stack(a_list), torch.stack(b_list)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
# ----------------------------------------------------------------------------
|
| 458 |
+
# 7. Demos
|
| 459 |
+
# ----------------------------------------------------------------------------
|
| 460 |
+
def demo_tiny(steps=50):
|
| 461 |
+
"""No-dependency smoke test."""
|
| 462 |
+
ens = Ensemble(llm="tiny")
|
| 463 |
+
opt = ens.make_optimizer()
|
| 464 |
+
for s in range(steps):
|
| 465 |
+
seg_a = torch.randint(0, ens.llm.vocab_size, (8, 32))
|
| 466 |
+
seg_b = torch.randint(0, ens.llm.vocab_size, (8, 32))
|
| 467 |
+
logs = ens.train_step(seg_a, seg_b, opt)
|
| 468 |
+
if s % 10 == 0:
|
| 469 |
+
print(f"step {s:3d} | " + " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
|
| 470 |
+
|
| 471 |
+
for _ in range(5):
|
| 472 |
+
ens.memorize_ids(torch.randint(0, ens.llm.vocab_size, (1, 32)))
|
| 473 |
+
ens.new_task_adapter("medical")
|
| 474 |
+
|
| 475 |
+
q = torch.randint(0, ens.llm.vocab_size, (1, 8))
|
| 476 |
+
out, score, adapter, retries = ens.answer_ids(q, tau_consistency=-1.0)
|
| 477 |
+
print(f"\nadapter={adapter} jepa_consistency={score:.3f} retries={retries}")
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def demo_hf(model_path="Qwen/Qwen2.5-0.5B-Instruct"):
|
| 481 |
+
"""Real model from hub id OR local path, e.g. '/models/llama-3.2-1b'.
|
| 482 |
+
For gated Llama repos: huggingface-cli login first."""
|
| 483 |
+
ens = Ensemble(llm=model_path, load_in_4bit=False) # 4bit needs bitsandbytes
|
| 484 |
+
opt = ens.make_optimizer()
|
| 485 |
+
|
| 486 |
+
texts = ["Replace this with your real corpus. " * 50]
|
| 487 |
+
seg_a, seg_b = segment_pairs_from_texts(ens.llm, texts, seg_len=32)
|
| 488 |
+
for s in range(10): # tiny demo run
|
| 489 |
+
logs = ens.train_step(seg_a[:4], seg_b[:4], opt)
|
| 490 |
+
print(f"step {s} | " + " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
|
| 491 |
+
|
| 492 |
+
ens.memorize_text("The project codename is AURORA and it ships in Q3.")
|
| 493 |
+
ens.new_task_adapter("project_aurora")
|
| 494 |
+
|
| 495 |
+
text, score, adapter, retries = ens.answer_text(
|
| 496 |
+
"What is the project codename?", n_new=24, tau_consistency=-1.0)
|
| 497 |
+
print(f"\n[{adapter} | jepa={score:.3f} | retries={retries}]\n{text}")
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
if __name__ == "__main__":
|
| 501 |
+
import sys
|
| 502 |
+
arg = sys.argv[1] if len(sys.argv) > 1 else "tiny"
|
| 503 |
+
if arg == "tiny":
|
| 504 |
+
demo_tiny()
|
| 505 |
+
else:
|
| 506 |
+
demo_hf(arg) # python llm_emb_jepa_ensemble_pluggable.py /models/llama-3.2-1b
|
| 507 |
+
# python llm_emb_jepa_ensemble_pluggable.py Qwen/Qwen2.5-0.5B-Instruct
|
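The pairing rule used by `segment_pairs_from_texts` can be illustrated without a tokenizer. The sketch below is only a stand-in under stated assumptions: `segment_pairs` is a hypothetical helper that works on a plain list of token ids, and it uses an inclusive upper bound so that a sequence of exactly `2 * seg_len` tokens still yields one pair.

```python
def segment_pairs(ids, seg_len):
    """Return (a, b) pairs of consecutive, non-overlapping windows.

    Hypothetical helper for illustration; `ids` is a plain list of token ids.
    Window starts advance by seg_len, so the b of one pair is the a of the next.
    """
    pairs = []
    # The +1 makes the bound inclusive: len(ids) == 2 * seg_len gives one pair.
    for i in range(0, len(ids) - 2 * seg_len + 1, seg_len):
        a = ids[i:i + seg_len]
        b = ids[i + seg_len:i + 2 * seg_len]
        pairs.append((a, b))
    return pairs


ids = list(range(10))
pairs = segment_pairs(ids, seg_len=3)
for a, b in pairs:
    print(a, "->", b)
```

Each `(a, b)` pair feeds the LM loss (on `a`) and the JEPA loss (predict the latent of `b` from `a`); trailing tokens that do not fill a whole window are dropped.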
research/ensemble/src/ensemble/memory.py
ADDED
|
@@ -0,0 +1,46 @@
|
| 1 |
+
"""Retrieval memory: embedder, vector store, and adapter router."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Embedder(nn.Module):
|
| 11 |
+
def __init__(self, vocab_size: int, d_emb: int = 64):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.tok = nn.Embedding(vocab_size, d_emb)
|
| 14 |
+
self.enc = nn.GRU(d_emb, d_emb, batch_first=True, bidirectional=True)
|
| 15 |
+
self.proj = nn.Linear(2 * d_emb, d_emb)
|
| 16 |
+
self.d_emb = d_emb
|
| 17 |
+
|
| 18 |
+
def forward(self, ids):
|
| 19 |
+
h, _ = self.enc(self.tok(ids))
|
| 20 |
+
return F.normalize(self.proj(h.mean(dim=1)), dim=-1)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class VectorStore:
|
| 24 |
+
def __init__(self):
|
| 25 |
+
self.keys, self.values = [], []
|
| 26 |
+
|
| 27 |
+
def add(self, emb, payload):
|
| 28 |
+
self.keys.append(emb.squeeze(0).detach().cpu())
|
| 29 |
+
self.values.append(payload)
|
| 30 |
+
|
| 31 |
+
def search(self, q, k=2):
|
| 32 |
+
if not self.keys:
|
| 33 |
+
return []
|
| 34 |
+
K = torch.stack(self.keys)
|
| 35 |
+
sims = (q.detach().cpu() @ K.t()).squeeze(0)
|
| 36 |
+
top = sims.topk(min(k, len(self.keys))).indices
|
| 37 |
+
return [self.values[i] for i in top]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Router(nn.Module):
|
| 41 |
+
def __init__(self, d_emb, n_adapters):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.fc = nn.Linear(d_emb, n_adapters)
|
| 44 |
+
|
| 45 |
+
def forward(self, emb):
|
| 46 |
+
return self.fc(emb).argmax(dim=-1)
|
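`VectorStore.search` ranks stored keys by dot product with the query, which matches cosine similarity because `Embedder` L2-normalizes its output. A minimal pure-Python sketch of that ranking follows; `normalize`, `top_k`, and the toy vectors are illustrative assumptions, not part of the package.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm (zero vectors are returned unchanged)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v] if n else list(v)


def top_k(query, keys, values, k=2):
    """Return the payloads of the k keys with the highest dot product."""
    if not keys:
        return []
    q = normalize(query)
    sims = [sum(a * b for a, b in zip(q, normalize(key))) for key in keys]
    order = sorted(range(len(keys)), key=lambda i: sims[i], reverse=True)
    return [values[i] for i in order[:min(k, len(keys))]]


keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = ["east", "north", "north-east"]
result = top_k([0.9, 0.1], keys, values, k=2)
print(result)
```

As in the store above, an empty memory returns an empty list, and `k` is clamped to the number of stored items.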
research/ensemble/src/ensemble/pretrain.py
ADDED
|
@@ -0,0 +1,198 @@
|
| 1 |
+
"""Joint pretrain: LLM (LoRA) + embedder + JEPA + bridge, saved to models/ensemble/."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from ensemble.checkpoint import save_checkpoint
|
| 15 |
+
from ensemble.config import default_ensemble_out, load_dotenv, resolve_llm
|
| 16 |
+
from ensemble.jepa_ensemble import Ensemble, segment_pairs_from_texts
|
| 17 |
+
|
| 18 |
+
_REPO_ROOT = Path(__file__).resolve().parents[4]
|
| 19 |
+
_DEFAULT_DATA = _REPO_ROOT / "research/data/education-lesson-chat.jsonl"
|
| 20 |
+
_DEFAULT_KB = _REPO_ROOT / "research/data/benchmark-kb.jsonl"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _load_jsonl(path: Path) -> list[dict]:
|
| 24 |
+
rows = []
|
| 25 |
+
with open(path, encoding="utf-8") as f:
|
| 26 |
+
for line in f:
|
| 27 |
+
line = line.strip()
|
| 28 |
+
if line:
|
| 29 |
+
rows.append(json.loads(line))
|
| 30 |
+
return rows
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _chat_to_text(row: dict) -> str:
|
| 34 |
+
messages = row.get("messages", [])
|
| 35 |
+
parts = [f"{m.get('role', 'user')}: {m.get('content', '')}" for m in messages]
|
| 36 |
+
return "\n".join(parts)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _collect_texts(data_path: Path, max_samples: int | None) -> list[str]:
|
| 40 |
+
rows = _load_jsonl(data_path)
|
| 41 |
+
if max_samples is not None:
|
| 42 |
+
rows = rows[:max_samples]
|
| 43 |
+
return [t for t in map(_chat_to_text, rows) if t.strip()]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _seed_memory(ens: Ensemble, kb_path: Path | None) -> int:
|
| 47 |
+
if kb_path is None or not kb_path.is_file():
|
| 48 |
+
return 0
|
| 49 |
+
count = 0
|
| 50 |
+
for row in _load_jsonl(kb_path):
|
| 51 |
+
text = row.get("text", "").strip()
|
| 52 |
+
if text:
|
| 53 |
+
ens.memorize_text(text)
|
| 54 |
+
count += 1
|
| 55 |
+
return count
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def pretrain(args) -> Path:
|
| 59 |
+
torch.manual_seed(args.seed)
|
| 60 |
+
random.seed(args.seed)
|
| 61 |
+
|
| 62 |
+
data_path = Path(args.data).resolve()
|
| 63 |
+
out_dir = Path(args.out).resolve()
|
| 64 |
+
kb_path = Path(args.kb).resolve() if args.kb else None
|
| 65 |
+
|
| 66 |
+
print(f"Loading ensemble backend: {args.llm}")
|
| 67 |
+
ens = Ensemble(llm=args.llm, load_in_4bit=args.load_in_4bit)
|
| 68 |
+
opt = ens.make_optimizer(lr_lora=args.lr_lora, lr_aux=args.lr_aux)
|
| 69 |
+
|
| 70 |
+
texts = _collect_texts(data_path, args.max_samples)
|
| 71 |
+
if not texts and args.llm != "tiny":
|
| 72 |
+
raise SystemExit(f"No training texts found in {data_path}")
|
| 73 |
+
|
| 74 |
+
mem_count = _seed_memory(ens, kb_path)
|
| 75 |
+
print(f"Training texts: {len(texts)} | memory snippets: {mem_count}")
|
| 76 |
+
|
| 77 |
+
if args.llm == "tiny":
|
| 78 |
+
n_pairs = max(args.steps * args.batch_size, args.batch_size)
|
| 79 |
+
v = ens.llm.vocab_size
|
| 80 |
+
seg_a = torch.randint(0, v, (n_pairs, args.seg_len))
|
| 81 |
+
seg_b = torch.randint(0, v, (n_pairs, args.seg_len))
|
| 82 |
+
else:
|
| 83 |
+
seg_a, seg_b = segment_pairs_from_texts(
|
| 84 |
+
ens.llm, texts, seg_len=args.seg_len
|
| 85 |
+
)
|
| 86 |
+
n_pairs = seg_a.size(0)
|
| 87 |
+
batch = min(args.batch_size, n_pairs)
|
| 88 |
+
print(f"Segment pairs: {n_pairs} | batch={batch} | steps={args.steps}")
|
| 89 |
+
|
| 90 |
+
t0 = time.time()
|
| 91 |
+
for step in range(args.steps):
|
| 92 |
+
idx = torch.randint(0, n_pairs, (batch,))
|
| 93 |
+
logs = ens.train_step(seg_a[idx], seg_b[idx], opt, w_bridge=args.w_bridge)
|
| 94 |
+
if step % max(1, args.log_every) == 0 or step == args.steps - 1:
|
| 95 |
+
parts = " | ".join(f"{k} {v:.4f}" for k, v in logs.items())
|
| 96 |
+
print(f"step {step:4d}/{args.steps} | {parts}")
|
| 97 |
+
|
| 98 |
+
elapsed = time.time() - t0
|
| 99 |
+
meta = {
|
| 100 |
+
"steps": args.steps,
|
| 101 |
+
"batch_size": batch,
|
| 102 |
+
"seg_len": args.seg_len,
|
| 103 |
+
"data": str(data_path),
|
| 104 |
+
"kb": str(kb_path) if kb_path else None,
|
| 105 |
+
"memory_count": mem_count,
|
| 106 |
+
"text_count": len(texts),
|
| 107 |
+
"elapsed_s": round(elapsed, 1),
|
| 108 |
+
"lr_lora": args.lr_lora,
|
| 109 |
+
"lr_aux": args.lr_aux,
|
| 110 |
+
"w_bridge": args.w_bridge,
|
| 111 |
+
"seed": args.seed,
|
| 112 |
+
"preset": getattr(args, "preset", None),
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
saved = save_checkpoint(
|
| 116 |
+
ens,
|
| 117 |
+
out_dir,
|
| 118 |
+
base_llm=args.llm,
|
| 119 |
+
training_meta=meta,
|
| 120 |
+
)
|
| 121 |
+
print(f"\nSaved ensemble checkpoint -> {saved}")
|
| 122 |
+
print("Benchmark with slm-evals:")
|
| 123 |
+
print(
|
| 124 |
+
f" uv run --package slm-evals slm-benchmark "
|
| 125 |
+
f"--model {saved} --model-type ensemble "
|
| 126 |
+
f"--benchmarks bfcl --max-samples 5"
|
| 127 |
+
)
|
| 128 |
+
return saved
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def parse_args():
|
| 132 |
+
p = argparse.ArgumentParser(
|
| 133 |
+
description="Pretrain JEPA ensemble (LLM+emb+JEPA) and save to models/ensemble/"
|
| 134 |
+
)
|
| 135 |
+
p.add_argument(
|
| 136 |
+
"--llm",
|
| 137 |
+
default=None,
|
| 138 |
+
help=(
|
| 139 |
+
"HF hub id / local path, 'tiny' for CPU smoke, or omit to use "
|
| 140 |
+
"LLM_PATH / BASE / MODEL_ID / ACTIVE_MODEL from .env + models.yaml"
|
| 141 |
+
),
|
| 142 |
+
)
|
| 143 |
+
p.add_argument(
|
| 144 |
+
"--preset",
|
| 145 |
+
default=None,
|
| 146 |
+
help="models.yaml preset key (default: ENSEMBLE_PRESET or ACTIVE_MODEL)",
|
| 147 |
+
)
|
| 148 |
+
p.add_argument(
|
| 149 |
+
"--data",
|
| 150 |
+
default=str(_DEFAULT_DATA),
|
| 151 |
+
help="Chat JSONL (messages[]) for segment-pair training",
|
| 152 |
+
)
|
| 153 |
+
p.add_argument(
|
| 154 |
+
"--kb",
|
| 155 |
+
default=str(_DEFAULT_KB),
|
| 156 |
+
help="Optional KB JSONL (text field) loaded into vector store",
|
| 157 |
+
)
|
| 158 |
+
p.add_argument(
|
| 159 |
+
"--out",
|
| 160 |
+
default=None,
|
| 161 |
+
help="Output dir (default: ENSEMBLE_OUT or models/ensemble/<preset>-jepa-pretrain)",
|
| 162 |
+
)
|
| 163 |
+
p.add_argument("--steps", type=int, default=100)
|
| 164 |
+
p.add_argument("--batch-size", type=int, default=4)
|
| 165 |
+
p.add_argument("--seg-len", type=int, default=32)
|
| 166 |
+
p.add_argument("--max-samples", type=int, default=None)
|
| 167 |
+
p.add_argument("--lr-lora", type=float, default=2e-4)
|
| 168 |
+
p.add_argument("--lr-aux", type=float, default=1e-3)
|
| 169 |
+
p.add_argument("--w-bridge", type=float, default=0.1)
|
| 170 |
+
p.add_argument("--log-every", type=int, default=10)
|
| 171 |
+
p.add_argument("--seed", type=int, default=0)
|
| 172 |
+
p.add_argument("--load-in-4bit", action="store_true")
|
| 173 |
+
p.add_argument("--no-kb", action="store_true", help="Skip loading KB into memory")
|
| 174 |
+
return p.parse_args()
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def main():
|
| 178 |
+
load_dotenv()
|
| 179 |
+
args = parse_args()
|
| 180 |
+
if args.no_kb:
|
| 181 |
+
args.kb = None
|
| 182 |
+
|
| 183 |
+
preset_key = args.preset
|
| 184 |
+
if args.llm is None or args.llm == "auto":
|
| 185 |
+
args.llm, preset_key = resolve_llm(preset_arg=args.preset)
|
| 186 |
+
elif args.llm != "tiny" and not args.preset:
|
| 187 |
+
_, preset_key = resolve_llm(llm_arg=args.llm)
|
| 188 |
+
|
| 189 |
+
if not args.out:
|
| 190 |
+
args.out = os.environ.get("ENSEMBLE_OUT") or default_ensemble_out(preset_key)
|
| 191 |
+
|
| 192 |
+
args.preset = preset_key
|
| 193 |
+
print(f"Resolved LLM: {args.llm}" + (f" (preset {preset_key})" if preset_key else ""))
|
| 194 |
+
pretrain(args)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
if __name__ == "__main__":
|
| 198 |
+
main()
|
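The data path in `pretrain.py` (read JSONL, flatten each chat row to `role: content` lines, drop empty rows) can be tried in isolation. The sketch below assumes the same `messages[]` row shape; `chat_to_text`, `collect_texts`, and the sample rows are illustrative stand-ins rather than the module's own functions or data.

```python
import json


def chat_to_text(row):
    """Join a chat row's messages into 'role: content' lines."""
    messages = row.get("messages", [])
    return "\n".join(
        f"{m.get('role', 'user')}: {m.get('content', '')}" for m in messages
    )


def collect_texts(lines, max_samples=None):
    """Parse JSONL lines, skip blanks, and drop rows that flatten to nothing."""
    rows = [json.loads(line) for line in lines if line.strip()]
    if max_samples is not None:
        rows = rows[:max_samples]
    texts = (chat_to_text(r) for r in rows)
    return [t for t in texts if t.strip()]


sample = [
    '{"messages": [{"role": "user", "content": "What is 2+2?"}, '
    '{"role": "assistant", "content": "4"}]}',
    "",
    '{"messages": []}',
]
texts = collect_texts(sample)
print(texts)
```

A row with no messages flattens to an empty string and is filtered out, which is why the training script can end up with fewer texts than JSONL rows.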
research/ensemble/src/ensemble/world_ensemble.py
ADDED
|
@@ -0,0 +1,228 @@
|
| 1 |
+
"""World-model ensemble: plan -> generate -> energy-rank."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from ensemble.backends import HFLLM, load_llm
|
| 13 |
+
from ensemble.bridge import Bridge
|
| 14 |
+
from ensemble.energy import EnergyModel
|
| 15 |
+
from ensemble.jepa import JEPA
|
| 16 |
+
from ensemble.memory import Embedder, VectorStore
|
| 17 |
+
from ensemble.world_model import WorldModel
|
| 18 |
+
|
| 19 |
+
torch.manual_seed(0)
|
| 20 |
+
|
| 21 |
+
D_LAT = 96
|
| 22 |
+
D_EMB = 64
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class WorldEnsemble(nn.Module):
|
| 26 |
+
def __init__(self, llm_spec="tiny"):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.llm = load_llm(llm_spec)
|
| 29 |
+
V, H = self.llm.vocab_size, self.llm.hidden_size
|
| 30 |
+
self.emb = Embedder(V, D_EMB)
|
| 31 |
+
self.jepa = JEPA(V, D_LAT)
|
| 32 |
+
self.world = WorldModel(D_LAT)
|
| 33 |
+
self.energy = EnergyModel(D_LAT)
|
| 34 |
+
self.bridge = Bridge(H, D_LAT)
|
| 35 |
+
self.store = VectorStore()
|
| 36 |
+
|
| 37 |
+
@torch.no_grad()
|
| 38 |
+
def world_state(self, segments):
|
| 39 |
+
s = self.world.init_state(1, "cpu")
|
| 40 |
+
for seg in segments:
|
| 41 |
+
z = self.jepa.encode(seg.cpu())
|
| 42 |
+
s, _ = self.world.step(s, z)
|
| 43 |
+
return s
|
| 44 |
+
|
| 45 |
+
@torch.no_grad()
|
| 46 |
+
def answer(self, query_ids, n_new=24, n_drafts=6, horizon=3):
|
| 47 |
+
q_emb = self.emb(query_ids.cpu())
|
| 48 |
+
mems = self.store.search(q_emb, k=1)
|
| 49 |
+
segments = mems + [query_ids.cpu()]
|
| 50 |
+
ctx = torch.cat(segments, dim=1)
|
| 51 |
+
|
| 52 |
+
s = self.world_state(segments)
|
| 53 |
+
plan, _ = self.world.rollout(s, horizon)
|
| 54 |
+
|
| 55 |
+
drafts, lat = [], []
|
| 56 |
+
for _ in range(n_drafts):
|
| 57 |
+
out = self.llm.generate(
|
| 58 |
+
ctx.to(self.llm.device), n_new=n_new, temperature=0.9
|
| 59 |
+
)
|
| 60 |
+
new = out[:, ctx.size(1) :].cpu()
|
| 61 |
+
drafts.append(new)
|
| 62 |
+
lat.append(self.jepa.encode(new))
|
| 63 |
+
Z = torch.cat(lat, 0)
|
| 64 |
+
E = self.energy.rank(s, Z)
|
| 65 |
+
best = E.argmin().item()
|
| 66 |
+
return {
|
| 67 |
+
"output": drafts[best],
|
| 68 |
+
"energy": E[best].item(),
|
| 69 |
+
"all_energies": E.tolist(),
|
| 70 |
+
"plan_alignment": F.cosine_similarity(
|
| 71 |
+
plan[:, 0], Z[best : best + 1]
|
| 72 |
+
).item(),
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
def memorize(self, ids):
|
| 76 |
+
self.store.add(self.emb(ids.cpu()), ids.cpu())
|
| 77 |
+
|
| 78 |
+
def train_step(
|
| 79 |
+
self,
|
| 80 |
+
seg_seq,
|
| 81 |
+
opt,
|
| 82 |
+
w=None,
|
| 83 |
+
hard_negs=True,
|
| 84 |
+
):
|
| 85 |
+
if w is None:
|
| 86 |
+
w = dict(lm=1.0, jepa=1.0, world=1.0, ebm=1.0, bridge=0.1)
|
| 87 |
+
|
| 88 |
+
B, T, L = seg_seq.shape
|
| 89 |
+
dev = self.llm.device
|
| 90 |
+
|
| 91 |
+
flat = seg_seq[:, 0].to(dev)
|
| 92 |
+
logits, hidden = self.llm(flat)
|
| 93 |
+
lm = F.cross_entropy(
|
| 94 |
+
logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
|
| 95 |
+
flat[:, 1:].reshape(-1),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
jepa = self.jepa.loss(seg_seq[:, 0], seg_seq[:, 1])
|
| 99 |
+
|
| 100 |
+
z_seq = torch.stack(
|
| 101 |
+
[self.jepa.enc(seg_seq[:, t]) for t in range(T)], 1
|
| 102 |
+
)
|
| 103 |
+
world = self.world.sequence_loss(z_seq)
|
| 104 |
+
|
| 105 |
+
s = self.world.init_state(B, z_seq.device)
|
| 106 |
+
s, _ = self.world.step(s, z_seq[:, 0].detach())
|
| 107 |
+
z_pos = z_seq[:, 1].detach()
|
| 108 |
+
z_negs = None
|
| 109 |
+
if hard_negs:
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
gen = self.llm.generate(seg_seq[:, 0].to(dev), n_new=L)
|
| 112 |
+
gen_new = gen[:, seg_seq.size(2) :].cpu()
|
| 113 |
+
z_negs = self.jepa.encode(gen_new).unsqueeze(1)
|
| 114 |
+
ebm = self.energy.contrastive_loss(s, z_pos, z_negs)
|
| 115 |
+
|
| 116 |
+
bridge = self.bridge.info_nce(
|
| 117 |
+
self.bridge(
|
| 118 |
+
hidden.cpu() if hidden.device.type != "cpu" else hidden
|
| 119 |
+
),
|
| 120 |
+
self.jepa.enc(seg_seq[:, 0]).detach(),
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
loss = (
|
| 124 |
+
w["lm"] * lm.cpu()
|
| 125 |
+
+ w["jepa"] * jepa
|
| 126 |
+
+ w["world"] * world
|
| 127 |
+
+ w["ebm"] * ebm
|
| 128 |
+
+ w["bridge"] * bridge
|
| 129 |
+
)
|
| 130 |
+
opt.zero_grad()
|
| 131 |
+
loss.backward()
|
| 132 |
+
opt.step()
|
| 133 |
+
self.jepa.ema()
|
| 134 |
+
return dict(
|
| 135 |
+
lm=lm.item(),
|
| 136 |
+
jepa=jepa.item(),
|
| 137 |
+
world=world.item(),
|
| 138 |
+
ebm=ebm.item(),
|
| 139 |
+
bridge=bridge.item(),
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
|
| 143 |
+
return torch.optim.AdamW(
|
| 144 |
+
[
|
| 145 |
+
{"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
|
| 146 |
+
{
|
| 147 |
+
"params": list(self.jepa.enc.parameters())
|
| 148 |
+
+ list(self.jepa.pred.parameters()),
|
| 149 |
+
"lr": lr_aux,
|
| 150 |
+
},
|
| 151 |
+
{"params": list(self.world.parameters()), "lr": lr_aux},
|
| 152 |
+
{"params": list(self.energy.parameters()), "lr": lr_aux},
|
| 153 |
+
{
|
| 154 |
+
"params": list(self.bridge.parameters())
|
| 155 |
+
+ list(self.emb.parameters()),
|
| 156 |
+
"lr": lr_aux,
|
| 157 |
+
},
|
| 158 |
+
]
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def toy_segment_sequences(B=8, T=4, L=24, vocab=1000):
|
| 163 |
+
return torch.randint(0, vocab, (B, T, L))
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def hf_segment_sequences(llm: HFLLM, texts, T=4, L=64):
|
| 167 |
+
seqs = []
|
| 168 |
+
for t in texts:
|
| 169 |
+
ids = llm.tokenizer(t, return_tensors="pt").input_ids[0]
|
| 170 |
+
n = (len(ids) // (T * L)) * T * L
|
| 171 |
+
if n:
|
| 172 |
+
seqs.append(ids[:n].view(-1, T, L))
|
| 173 |
+
if not seqs:
|
| 174 |
+
raise ValueError("corpus too short for T*L window")
|
| 175 |
+
return torch.cat(seqs, 0)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def demo(spec="tiny", steps=60):
|
| 179 |
+
ens = WorldEnsemble(spec)
|
| 180 |
+
opt = ens.make_optimizer()
|
| 181 |
+
|
| 182 |
+
if spec == "tiny":
|
| 183 |
+
get_batch = lambda: toy_segment_sequences(vocab=ens.llm.vocab_size)
|
| 184 |
+
else:
|
| 185 |
+
corpus = ["Replace with your real documents. " * 200]
|
| 186 |
+
data = hf_segment_sequences(ens.llm, corpus, T=4, L=32)
|
| 187 |
+
get_batch = lambda: data[torch.randperm(len(data))[:4]]
|
| 188 |
+
steps = min(steps, 10)
|
| 189 |
+
|
| 190 |
+
t0 = time.time()
|
| 191 |
+
for s in range(steps):
|
| 192 |
+
logs = ens.train_step(
|
| 193 |
+
get_batch(), opt, hard_negs=(s > steps // 2)
|
| 194 |
+
)
|
| 195 |
+
if s % 10 == 0:
|
| 196 |
+
print(
|
| 197 |
+
f"step {s:3d} | "
|
| 198 |
+
+ " | ".join(f"{k} {v:.3f}" for k, v in logs.items())
|
| 199 |
+
)
|
| 200 |
+
print(f"trained {steps} steps in {time.time() - t0:.1f}s")
|
| 201 |
+
|
| 202 |
+
for _ in range(4):
|
| 203 |
+
if spec == "tiny":
|
| 204 |
+
ens.memorize(torch.randint(0, ens.llm.vocab_size, (1, 24)))
|
| 205 |
+
q = (
|
| 206 |
+
torch.randint(0, ens.llm.vocab_size, (1, 12))
|
| 207 |
+
if spec == "tiny"
|
| 208 |
+
else ens.llm.tokenizer(
|
| 209 |
+
"What is this document about?", return_tensors="pt"
|
| 210 |
+
).input_ids
|
| 211 |
+
)
|
| 212 |
+
res = ens.answer(q, n_drafts=6, horizon=3)
|
| 213 |
+
print(
|
| 214 |
+
f"\nselected draft energy={res['energy']:.3f} "
|
| 215 |
+
f"(all: {[f'{e:.2f}' for e in res['all_energies']]})"
|
| 216 |
+
)
|
| 217 |
+
print(f"plan->output alignment: {res['plan_alignment']:.3f}")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == "__main__":
|
| 221 |
+
from ensemble.config import load_dotenv, resolve_llm
|
| 222 |
+
|
| 223 |
+
load_dotenv()
|
| 224 |
+
spec = sys.argv[1] if len(sys.argv) > 1 else None
|
| 225 |
+
if spec is None or spec == "auto":
|
| 226 |
+
spec, preset = resolve_llm()
|
| 227 |
+
print(f"Resolved LLM: {spec} (preset {preset})")
|
| 228 |
+
demo(spec or "tiny")
|
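`WorldEnsemble.answer` follows a best-of-N pattern: sample several drafts, encode each into a latent, score every latent against the current state, and keep the draft with the lowest energy. The sketch below shows only the selection step. The energy function is a loud assumption (`1 - cosine similarity`), standing in for the learned `EnergyModel`, whose implementation is not shown here; `toy_energy` and `pick_best` are hypothetical names.

```python
import math


def cosine(u, v):
    """Cosine similarity of two vectors (0.0 if either has zero norm)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0


def toy_energy(state, z):
    """Stand-in energy (assumption): low when z points the same way as state."""
    return 1.0 - cosine(state, z)


def pick_best(state, draft_latents):
    """Return (index, energies) for the draft with the lowest energy."""
    energies = [toy_energy(state, z) for z in draft_latents]
    best = min(range(len(energies)), key=lambda i: energies[i])
    return best, energies


state = [1.0, 0.0]
draft_latents = [[0.0, 1.0], [2.0, 0.1], [-1.0, 0.0]]
best, energies = pick_best(state, draft_latents)
print(best, [round(e, 3) for e in energies])
```

Returning all energies alongside the winner mirrors the `all_energies` field above and makes it easy to see how decisive the ranking was.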
research/ensemble/src/ensemble/world_model.py
ADDED
|
@@ -0,0 +1,40 @@
|
| 1 |
+
"""Latent world model: multi-step rollout in JEPA space."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class WorldModel(nn.Module):
|
| 11 |
+
def __init__(self, d_latent: int):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.cell = nn.GRUCell(d_latent, d_latent)
|
| 14 |
+
self.head = nn.Linear(d_latent, d_latent)
|
| 15 |
+
self.s0 = nn.Parameter(torch.zeros(d_latent))
|
| 16 |
+
self.d_latent = d_latent
|
| 17 |
+
|
| 18 |
+
def init_state(self, B, device):
|
| 19 |
+
return self.s0.unsqueeze(0).expand(B, -1).contiguous().to(device)
|
| 20 |
+
|
| 21 |
+
def step(self, s, z):
|
| 22 |
+
s = self.cell(z, s)
|
| 23 |
+
return s, self.head(s)
|
| 24 |
+
|
| 25 |
+
def rollout(self, s, horizon):
|
| 26 |
+
preds = []
|
| 27 |
+
for _ in range(horizon):
|
| 28 |
+
z_hat = self.head(s)
|
| 29 |
+
preds.append(z_hat)
|
| 30 |
+
s = self.cell(z_hat, s)
|
| 31 |
+
return torch.stack(preds, 1), s
|
| 32 |
+
|
| 33 |
+
def sequence_loss(self, z_seq):
|
| 34 |
+
B, T, _ = z_seq.shape
|
| 35 |
+
s = self.init_state(B, z_seq.device)
|
| 36 |
+
loss = 0.0
|
| 37 |
+
for t in range(T - 1):
|
| 38 |
+
s, z_hat = self.step(s, z_seq[:, t])
|
| 39 |
+
loss = loss + F.mse_loss(z_hat, z_seq[:, t + 1])
|
| 40 |
+
return loss / (T - 1)
|
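The difference between `step` (teacher-forced on an observed latent) and `rollout` (open-loop, feeding each prediction back in) is easiest to see from tensor shapes. This is a minimal sketch with freshly initialized, untrained modules; the sizes `d_latent`, `batch`, and `horizon` are arbitrary choices for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_latent, batch, horizon = 8, 2, 3
cell = nn.GRUCell(d_latent, d_latent)   # state transition s' = f(z, s)
head = nn.Linear(d_latent, d_latent)    # reads a latent prediction off the state


@torch.no_grad()
def rollout(s, horizon):
    """Open-loop rollout: each predicted latent is fed back as the next input."""
    preds = []
    for _ in range(horizon):
        z_hat = head(s)
        preds.append(z_hat)
        s = cell(z_hat, s)
    return torch.stack(preds, dim=1), s


s0 = torch.zeros(batch, d_latent)
plan, s_final = rollout(s0, horizon)
print(plan.shape, s_final.shape)
```

The stacked predictions have shape `[batch, horizon, d_latent]`, so `plan[:, 0]` is the one-step-ahead latent that the ensemble compares against the selected draft.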
research/ensemble/src/ensemble/world_model_ensemble.py
ADDED
|
@@ -0,0 +1,499 @@
|
| 1 |
+
"""
|
| 2 |
+
World-Model Ensemble: EMB + EBM + JEPA + World Model + small LLM (from path)
|
| 3 |
+
=============================================================================
|
| 4 |
+
A LeCun-style modular agent built around a small language model.
|
| 5 |
+
|
| 6 |
+
ARCHITECTURE
|
| 7 |
+
------------
|
| 8 |
+
                        ┌────────────────────────────┐
|
| 9 |
+
 input tokens ──► EMB ──┤ VectorStore (retrieval/CL) ├──► context
|
| 10 |
+
        │               └────────────────────────────┘        │
|
| 11 |
+
        │                                                     │
|
| 12 |
+
        ▼                                                     ▼
|
| 13 |
+
   JEPA encoder ──► latent state s_t ──► WORLD MODEL ──► ŝ_{t+1..t+H}
|
| 14 |
+
        │                 (GRU dynamics, multi-step rollout)  │
|
| 15 |
+
        │                                                     │
|
| 16 |
+
        │            ┌────────────────────────────────────┐   │
|
| 17 |
+
        └──────────► │ ENERGY MODEL  E(s_ctx, z_candidate)│ ◄─┘
|
| 18 |
+
                     │ low energy = compatible/plausible  │
|
| 19 |
+
                     └─────────────────┬──────────────────┘
|
| 20 |
+
                                       │ scores drafts / plans
|
| 21 |
+
                                       ▼
|
| 22 |
+
   LLM (small, loaded from path, LoRA bank) ──► N drafts ──► pick argmin E
|
| 23 |
+
|
| 24 |
+
ROLES
|
| 25 |
+
-----
|
| 26 |
+
EMB perception for retrieval + routing (non-parametric memory)
|
| 27 |
+
JEPA learns the latent space: predict z(next segment) from z(context)
|
| 28 |
+
(EMA target encoder + variance reg, no token reconstruction)
|
| 29 |
+
WORLD MODEL deterministic latent dynamics s_{t+1} = f(s_t, z_t):
|
| 30 |
+
rolls the conversation/document state forward H steps in
|
| 31 |
+
LATENT space: cheap lookahead without decoding tokens
|
| 32 |
+
ENERGY E(s, z) ∈ R, trained so true continuations have LOW energy and
|
| 33 |
+
negatives (shuffled / model-generated) have HIGH energy.
|
| 34 |
+
At inference it is the critic: rank LLM drafts, reject bad plans.
|
| 35 |
+
LLM the only token-level generator. Loaded from a local path or HF id;
|
| 36 |
+
frozen base + LoRA adapters (continual learning by isolation).
|
| 37 |
+
|
| 38 |
+
WHY EBM *and* JEPA? JEPA gives a point prediction ẑ of the future latent;
|
| 39 |
+
the EBM gives a *compatibility landscape* E(s, z): it can say "both A and B
|
| 40 |
+
are plausible" where a point predictor must average them. JEPA trains the
|
| 41 |
+
representation; the EBM scores hypotheses in it. World model chains JEPA
|
| 42 |
+
one-step predictions into multi-step rollouts that the EBM can evaluate.
|
| 43 |
+
|
| 44 |
+
USAGE
|
| 45 |
+
-----
|
| 46 |
+
pip install torch # toy mode
|
| 47 |
+
pip install transformers peft accelerate # real LLM mode
|
| 48 |
+
|
| 49 |
+
python world_model_ensemble.py tiny # smoke test
|
| 50 |
+
python world_model_ensemble.py /models/llama-3.2-1b # local weights
|
| 51 |
+
python world_model_ensemble.py Qwen/Qwen2.5-0.5B-Instruct
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
from __future__ import annotations
|
| 55 |
+
import copy
|
| 56 |
+
import math
|
| 57 |
+
import sys
|
| 58 |
+
import time
|
| 59 |
+
|
| 60 |
+
import torch
|
| 61 |
+
import torch.nn as nn
|
| 62 |
+
import torch.nn.functional as F
|
| 63 |
+
|
| 64 |
+
torch.manual_seed(0)
|
| 65 |
+
|
| 66 |
+
D_LAT = 96 # shared latent dimension (JEPA / world / energy)
|
| 67 |
+
D_EMB = 64 # retrieval embedding dim
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ============================================================================
|
| 71 |
+
# 1. LLM backend: load small model from path / hub, or toy fallback
|
| 72 |
+
# (same contract as before: forward -> (logits, hidden), generate, adapters)
|
| 73 |
+
# ============================================================================
|
| 74 |
+
class TinyLLM(nn.Module):
|
| 75 |
+
VOCAB, D, L, H, T = 1000, 128, 2, 4, 32
|
| 76 |
+
|
| 77 |
+
def __init__(self):
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.tok = nn.Embedding(self.VOCAB, self.D)
|
| 80 |
+
self.pos = nn.Embedding(self.T * 4, self.D)
|
| 81 |
+
layer = nn.TransformerEncoderLayer(self.D, self.H, 4 * self.D,
|
| 82 |
+
batch_first=True, norm_first=True)
|
| 83 |
+
self.blocks = nn.TransformerEncoder(layer, self.L)
|
| 84 |
+
self.head = nn.Linear(self.D, self.VOCAB, bias=False)
|
| 85 |
+
self.vocab_size, self.hidden_size = self.VOCAB, self.D
|
| 86 |
+
|
| 87 |
+
def forward(self, ids):
|
| 88 |
+
Tn = ids.size(1)
|
| 89 |
+
x = self.tok(ids) + self.pos(torch.arange(Tn, device=ids.device))
|
| 90 |
+
mask = torch.triu(torch.full((Tn, Tn), float("-inf"),
|
| 91 |
+
device=ids.device), 1)
|
| 92 |
+
h = self.blocks(x, mask=mask)
|
| 93 |
+
return self.head(h), h
|
| 94 |
+
|
| 95 |
+
@torch.no_grad()
|
| 96 |
+
def generate(self, ids, n_new=16, temperature=1.0):
|
| 97 |
+
for _ in range(n_new):
|
| 98 |
+
logits, _ = self(ids[:, -self.T:])
|
| 99 |
+
nxt = torch.multinomial(
|
| 100 |
+
F.softmax(logits[:, -1] / temperature, -1), 1)
|
| 101 |
+
ids = torch.cat([ids, nxt], 1)
|
| 102 |
+
return ids
|
| 103 |
+
|
| 104 |
+
def trainable_parameters(self):
|
| 105 |
+
return self.parameters()
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def device(self):
|
| 109 |
+
return next(self.parameters()).device
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class HFLLM(nn.Module):
|
| 113 |
+
"""Small model from a local path or HF id, frozen base + LoRA."""
|
| 114 |
+
def __init__(self, path, lora_r=16):
|
| 115 |
+
super().__init__()
|
| 116 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 117 |
+
from peft import LoraConfig, get_peft_model
|
| 118 |
+
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
| 119 |
+
if self.tokenizer.pad_token is None:
|
| 120 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 121 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 122 |
+
path, torch_dtype=torch.bfloat16
|
| 123 |
+
if torch.cuda.is_available() else torch.float32,
|
| 124 |
+
device_map="auto" if torch.cuda.is_available() else None)
|
| 125 |
+
for p in base.parameters():
|
| 126 |
+
p.requires_grad_(False)
|
| 127 |
+
cfg = LoraConfig(r=lora_r, lora_alpha=2 * lora_r, lora_dropout=0.05,
|
| 128 |
+
target_modules=["q_proj", "v_proj"],
|
| 129 |
+
task_type="CAUSAL_LM")
|
| 130 |
+
self.model = get_peft_model(base, cfg)
|
| 131 |
+
self.vocab_size = self.model.config.vocab_size
|
| 132 |
+
self.hidden_size = self.model.config.hidden_size
|
| 133 |
+
|
| 134 |
+
def forward(self, ids):
|
| 135 |
+
out = self.model(input_ids=ids.to(self.device),
|
| 136 |
+
output_hidden_states=True)
|
| 137 |
+
return out.logits, out.hidden_states[-1]
|
| 138 |
+
|
| 139 |
+
@torch.no_grad()
|
| 140 |
+
def generate(self, ids, n_new=32, temperature=0.8):
|
| 141 |
+
return self.model.generate(
|
| 142 |
+
input_ids=ids.to(self.device), max_new_tokens=n_new,
|
| 143 |
+
do_sample=True, temperature=temperature,
|
| 144 |
+
pad_token_id=self.tokenizer.pad_token_id)
|
| 145 |
+
|
| 146 |
+
def trainable_parameters(self):
|
| 147 |
+
return (p for p in self.model.parameters() if p.requires_grad)
|
| 148 |
+
|
| 149 |
+
@property
|
| 150 |
+
def device(self):
|
| 151 |
+
return next(self.model.parameters()).device
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def load_llm(spec: str):
|
| 155 |
+
return TinyLLM() if spec == "tiny" else HFLLM(spec)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ============================================================================
|
| 159 |
+
# 2. Embedder (retrieval): vocab-agnostic
|
| 160 |
+
# ============================================================================
|
| 161 |
+
class Embedder(nn.Module):
|
| 162 |
+
def __init__(self, vocab):
|
| 163 |
+
super().__init__()
|
| 164 |
+
self.tok = nn.Embedding(vocab, D_EMB)
|
| 165 |
+
self.gru = nn.GRU(D_EMB, D_EMB, batch_first=True, bidirectional=True)
|
| 166 |
+
self.out = nn.Linear(2 * D_EMB, D_EMB)
|
| 167 |
+
|
| 168 |
+
def forward(self, ids):
|
| 169 |
+
h, _ = self.gru(self.tok(ids))
|
| 170 |
+
return F.normalize(self.out(h.mean(1)), dim=-1)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class VectorStore:
|
| 174 |
+
def __init__(self):
|
| 175 |
+
self.K, self.V = [], []
|
| 176 |
+
|
| 177 |
+
def add(self, k, v):
|
| 178 |
+
self.K.append(k.squeeze(0).detach().cpu()); self.V.append(v)
|
| 179 |
+
|
| 180 |
+
def search(self, q, k=1):
|
| 181 |
+
if not self.K:
|
| 182 |
+
return []
|
| 183 |
+
sims = (q.detach().cpu() @ torch.stack(self.K).t()).squeeze(0)
|
| 184 |
+
return [self.V[i] for i in sims.topk(min(k, len(self.K))).indices]
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ============================================================================
|
| 188 |
+
# 3. JEPA: owns the latent space (EMA target encoder, variance-regularized)
|
| 189 |
+
# ============================================================================
|
| 190 |
+
class SegEncoder(nn.Module):
|
| 191 |
+
def __init__(self, vocab):
|
| 192 |
+
super().__init__()
|
| 193 |
+
self.tok = nn.Embedding(vocab, D_LAT)
|
| 194 |
+
self.gru = nn.GRU(D_LAT, D_LAT, batch_first=True)
|
| 195 |
+
self.out = nn.Linear(D_LAT, D_LAT)
|
| 196 |
+
|
| 197 |
+
def forward(self, ids):
|
| 198 |
+
h, _ = self.gru(self.tok(ids))
|
| 199 |
+
return self.out(h.mean(1)) # [B, D_LAT]
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class JEPA(nn.Module):
|
| 203 |
+
def __init__(self, vocab, m=0.996):
|
| 204 |
+
super().__init__()
|
| 205 |
+
self.enc = SegEncoder(vocab) # context/online enc
|
| 206 |
+
self.tgt = copy.deepcopy(self.enc) # EMA target
|
| 207 |
+
for p in self.tgt.parameters():
|
| 208 |
+
p.requires_grad_(False)
|
| 209 |
+
self.pred = nn.Sequential(nn.Linear(D_LAT, 2 * D_LAT), nn.GELU(),
|
| 210 |
+
nn.Linear(2 * D_LAT, D_LAT))
|
| 211 |
+
self.m = m
|
| 212 |
+
|
| 213 |
+
@torch.no_grad()
|
| 214 |
+
def ema(self):
|
| 215 |
+
for pt, pc in zip(self.tgt.parameters(), self.enc.parameters()):
|
| 216 |
+
pt.mul_(self.m).add_(pc.detach(), alpha=1 - self.m)
|
| 217 |
+
|
| 218 |
+
def loss(self, seg_a, seg_b):
|
| 219 |
+
z_hat = self.pred(self.enc(seg_a))
|
| 220 |
+
with torch.no_grad():
|
| 221 |
+
z_tgt = self.tgt(seg_b)
|
| 222 |
+
var = F.relu(1.0 - z_hat.std(0)).mean() # anti-collapse
|
| 223 |
+
return F.mse_loss(z_hat, z_tgt) + 0.5 * var
|
| 224 |
+
|
| 225 |
+
@torch.no_grad()
|
| 226 |
+
def encode(self, seg): # target space
|
| 227 |
+
return self.tgt(seg)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ============================================================================
|
| 231 |
+
# 4. WORLD MODEL: latent dynamics s_{t+1} = f(s_t, z_t), multi-step rollout
|
| 232 |
+
# Trained on SEQUENCES of segments: predict each next latent from state.
|
| 233 |
+
# ============================================================================
|
| 234 |
+
class WorldModel(nn.Module):
|
| 235 |
+
def __init__(self):
|
| 236 |
+
super().__init__()
|
| 237 |
+
self.cell = nn.GRUCell(D_LAT, D_LAT) # state update
|
| 238 |
+
self.head = nn.Linear(D_LAT, D_LAT) # state -> ẑ_{t+1}
|
| 239 |
+
self.s0 = nn.Parameter(torch.zeros(D_LAT))
|
| 240 |
+
|
| 241 |
+
def init_state(self, B, device):
|
| 242 |
+
return self.s0.unsqueeze(0).expand(B, -1).contiguous().to(device)
|
| 243 |
+
|
| 244 |
+
def step(self, s, z):
|
| 245 |
+
"""Consume observed latent z_t, return (new state, prediction ẑ_{t+1})."""
|
| 246 |
+
s = self.cell(z, s)
|
| 247 |
+
return s, self.head(s)
|
| 248 |
+
|
| 249 |
+
def rollout(self, s, horizon):
|
| 250 |
+
"""Imagine H future latents feeding its own predictions back in."""
|
| 251 |
+
preds = []
|
| 252 |
+
for _ in range(horizon):
|
| 253 |
+
z_hat = self.head(s)
|
| 254 |
+
preds.append(z_hat)
|
| 255 |
+
s = self.cell(z_hat, s)
|
| 256 |
+
return torch.stack(preds, 1), s # [B, H, D_LAT]
|
| 257 |
+
|
| 258 |
+
def sequence_loss(self, z_seq):
|
| 259 |
+
"""z_seq: [B, T, D_LAT] observed segment latents (teacher forcing)."""
|
| 260 |
+
B, T, _ = z_seq.shape
|
| 261 |
+
s = self.init_state(B, z_seq.device)
|
| 262 |
+
loss = 0.0
|
| 263 |
+
for t in range(T - 1):
|
| 264 |
+
s, z_hat = self.step(s, z_seq[:, t])
|
| 265 |
+
loss = loss + F.mse_loss(z_hat, z_seq[:, t + 1])
|
| 266 |
+
return loss / (T - 1)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# ============================================================================
|
| 270 |
+
# 5. ENERGY MODEL: E(state, candidate latent) ∈ R, low = plausible
|
| 271 |
+
# Trained with InfoNCE-style contrastive: positives = true next latent,
|
| 272 |
+
# negatives = (a) other batch items, (b) LLM-generated drafts (optional).
|
| 273 |
+
# ============================================================================
|
| 274 |
+
class EnergyModel(nn.Module):
|
| 275 |
+
def __init__(self):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.net = nn.Sequential(
|
| 278 |
+
nn.Linear(2 * D_LAT, 2 * D_LAT), nn.GELU(),
|
| 279 |
+
nn.Linear(2 * D_LAT, D_LAT), nn.GELU(),
|
| 280 |
+
nn.Linear(D_LAT, 1))
|
| 281 |
+
|
| 282 |
+
def energy(self, s, z):
|
| 283 |
+
"""s: [B, D_LAT] context state; z: [B, D_LAT] candidate. -> [B]"""
|
| 284 |
+
return self.net(torch.cat([s, z], -1)).squeeze(-1)
|
| 285 |
+
|
| 286 |
+
def contrastive_loss(self, s, z_pos, z_negs=None, tau=0.5):
|
| 287 |
+
"""Softmax over energies: true continuation must be the argmin.
|
| 288 |
+
In-batch negatives: every other item's z_pos is a negative for s_i."""
|
| 289 |
+
B = s.size(0)
|
| 290 |
+
# pairwise energies: E(s_i, z_j) for all i, j
|
| 291 |
+
s_rep = s.unsqueeze(1).expand(B, B, D_LAT).reshape(B * B, D_LAT)
|
| 292 |
+
z_rep = z_pos.unsqueeze(0).expand(B, B, D_LAT).reshape(B * B, D_LAT)
|
| 293 |
+
E = self.energy(s_rep, z_rep).view(B, B) # [B, B]
|
| 294 |
+
if z_negs is not None: # extra hard negatives
|
| 295 |
+
En = self.energy(
|
| 296 |
+
s.repeat_interleave(z_negs.size(1), 0),
|
| 297 |
+
z_negs.reshape(-1, D_LAT)).view(B, -1)
|
| 298 |
+
E = torch.cat([E, En], dim=1)
|
| 299 |
+
labels = torch.arange(B, device=s.device)
|
| 300 |
+
return F.cross_entropy(-E / tau, labels) # low E -> high logit
|
| 301 |
+
|
| 302 |
+
@torch.no_grad()
|
| 303 |
+
def rank(self, s, candidates):
|
| 304 |
+
"""candidates: [N, D_LAT]; returns energies [N] (lower = better)."""
|
| 305 |
+
return self.energy(s.expand(candidates.size(0), -1), candidates)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# ============================================================================
|
| 309 |
+
# 6. Bridge: LLM hidden states -> shared latent space (alignment)
|
| 310 |
+
# ============================================================================
|
| 311 |
+
class Bridge(nn.Module):
|
| 312 |
+
def __init__(self, d_hidden):
|
| 313 |
+
super().__init__()
|
| 314 |
+
self.proj = nn.Sequential(nn.Linear(d_hidden, D_LAT), nn.GELU(),
|
| 315 |
+
nn.Linear(D_LAT, D_LAT))
|
| 316 |
+
|
| 317 |
+
def forward(self, h): # [B,T,H] -> [B,D_LAT]
|
| 318 |
+
return self.proj(h.float().mean(1))
|
| 319 |
+
|
| 320 |
+
def info_nce(self, a, b, tau=0.07):
|
| 321 |
+
a, b = F.normalize(a, dim=-1), F.normalize(b, dim=-1)  # second positional arg of F.normalize is p, not dim
|
| 322 |
+
logits = a @ b.t() / tau
|
| 323 |
+
y = torch.arange(a.size(0), device=a.device)
|
| 324 |
+
return 0.5 * (F.cross_entropy(logits, y) +
|
| 325 |
+
F.cross_entropy(logits.t(), y))
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# ============================================================================
|
| 329 |
+
# 7. THE ENSEMBLE: wiring + inference (plan -> generate -> energy-rank)
|
| 330 |
+
# ============================================================================
|
| 331 |
+
class WorldEnsemble(nn.Module):
|
| 332 |
+
def __init__(self, llm_spec="tiny"):
|
| 333 |
+
super().__init__()
|
| 334 |
+
self.llm = load_llm(llm_spec)
|
| 335 |
+
V, H = self.llm.vocab_size, self.llm.hidden_size
|
| 336 |
+
self.emb = Embedder(V)
|
| 337 |
+
self.jepa = JEPA(V)
|
| 338 |
+
self.world = WorldModel()
|
| 339 |
+
self.energy = EnergyModel()
|
| 340 |
+
self.bridge = Bridge(H)
|
| 341 |
+
self.store = VectorStore()
|
| 342 |
+
|
| 343 |
+
# ------------------------- inference ---------------------------------
|
| 344 |
+
@torch.no_grad()
|
| 345 |
+
def world_state(self, segments):
|
| 346 |
+
"""Fold a list of [1,T] segment tensors into a latent state."""
|
| 347 |
+
s = self.world.init_state(1, "cpu")
|
| 348 |
+
for seg in segments:
|
| 349 |
+
z = self.jepa.encode(seg.cpu())
|
| 350 |
+
s, _ = self.world.step(s, z)
|
| 351 |
+
return s
|
| 352 |
+
|
| 353 |
+
@torch.no_grad()
|
| 354 |
+
def answer(self, query_ids, n_new=24, n_drafts=6, horizon=3):
|
| 355 |
+
"""retrieve -> build world state -> imagine -> generate N -> argmin E."""
|
| 356 |
+
q_emb = self.emb(query_ids.cpu())
|
| 357 |
+
mems = self.store.search(q_emb, k=1)
|
| 358 |
+
segments = (mems + [query_ids.cpu()]) if mems else [query_ids.cpu()]
|
| 359 |
+
ctx = torch.cat(segments, dim=1)
|
| 360 |
+
|
| 361 |
+
s = self.world_state(segments) # latent context state
|
| 362 |
+
plan, _ = self.world.rollout(s, horizon) # imagined future
|
| 363 |
+
# (plan is available for planning losses / steering; logged here)
|
| 364 |
+
|
| 365 |
+
drafts, lat = [], []
|
| 366 |
+
for _ in range(n_drafts):
|
| 367 |
+
out = self.llm.generate(ctx.to(self.llm.device), n_new=n_new,
|
| 368 |
+
temperature=0.9)
|
| 369 |
+
new = out[:, ctx.size(1):].cpu()
|
| 370 |
+
drafts.append(new)
|
| 371 |
+
lat.append(self.jepa.encode(new))
|
| 372 |
+
Z = torch.cat(lat, 0) # [N, D_LAT]
|
| 373 |
+
E = self.energy.rank(s, Z) # lower = better
|
| 374 |
+
best = E.argmin().item()
|
| 375 |
+
return {"output": drafts[best], "energy": E[best].item(),
|
| 376 |
+
"all_energies": E.tolist(),
|
| 377 |
+
"plan_alignment": F.cosine_similarity(
|
| 378 |
+
plan[:, 0], Z[best:best + 1]).item()}
|
| 379 |
+
|
| 380 |
+
def memorize(self, ids):
|
| 381 |
+
self.store.add(self.emb(ids.cpu()), ids.cpu())
|
| 382 |
+
|
| 383 |
+
# ------------------------- training ----------------------------------
|
| 384 |
+
def train_step(self, seg_seq, opt, w=dict(lm=1.0, jepa=1.0, world=1.0,
|
| 385 |
+
ebm=1.0, bridge=0.1),
|
| 386 |
+
hard_negs=True):
|
| 387 |
+
"""seg_seq: [B, T_seg, L], i.e. B documents, each split into T_seg
|
| 388 |
+
consecutive segments of length L (same tokenizer as the LLM)."""
|
| 389 |
+
B, T, L = seg_seq.shape
|
| 390 |
+
dev = self.llm.device
|
| 391 |
+
|
| 392 |
+
# (1) LM loss on the first segment (or all, batched, if budget allows)
|
| 393 |
+
flat = seg_seq[:, 0].to(dev)
|
| 394 |
+
logits, hidden = self.llm(flat)
|
| 395 |
+
lm = F.cross_entropy(
|
| 396 |
+
logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
|
| 397 |
+
flat[:, 1:].reshape(-1))
|
| 398 |
+
|
| 399 |
+
# (2) JEPA: adjacent segment pairs
|
| 400 |
+
jepa = self.jepa.loss(seg_seq[:, 0], seg_seq[:, 1])
|
| 401 |
+
|
| 402 |
+
# (3) World model: sequence of latents (online encoder, grads flow)
|
| 403 |
+
z_seq = torch.stack([self.jepa.enc(seg_seq[:, t])
|
| 404 |
+
for t in range(T)], 1) # [B, T, D_LAT]
|
| 405 |
+
world = self.world.sequence_loss(z_seq)
|
| 406 |
+
|
| 407 |
+
# (4) Energy: state after t=0 must give low E to true z_1,
|
| 408 |
+
# high E to in-batch + (optionally) LLM-generated negatives
|
| 409 |
+
s = self.world.init_state(B, z_seq.device)
|
| 410 |
+
s, _ = self.world.step(s, z_seq[:, 0].detach())
|
| 411 |
+
z_pos = z_seq[:, 1].detach()
|
| 412 |
+
z_negs = None
|
| 413 |
+
if hard_negs:
|
| 414 |
+
with torch.no_grad(): # model drafts as negs
|
| 415 |
+
gen = self.llm.generate(seg_seq[:, 0].to(dev), n_new=L)
|
| 416 |
+
gen_new = gen[:, seg_seq.size(2):].cpu()
|
| 417 |
+
z_negs = self.jepa.encode(gen_new).unsqueeze(1) # [B,1,D]
|
| 418 |
+
ebm = self.energy.contrastive_loss(s, z_pos, z_negs)
|
| 419 |
+
|
| 420 |
+
# (5) Bridge: align LLM hidden(seg0) with JEPA latent(seg0)
|
| 421 |
+
bridge = self.bridge.info_nce(
|
| 422 |
+
self.bridge(hidden.cpu()),  # .cpu() is a no-op for tensors already on the CPU
|
| 423 |
+
self.jepa.enc(seg_seq[:, 0]).detach())
|
| 424 |
+
|
| 425 |
+
loss = (w["lm"] * lm.cpu() + w["jepa"] * jepa + w["world"] * world
|
| 426 |
+
+ w["ebm"] * ebm + w["bridge"] * bridge)
|
| 427 |
+
opt.zero_grad(); loss.backward(); opt.step()
|
| 428 |
+
self.jepa.ema()
|
| 429 |
+
return dict(lm=lm.item(), jepa=jepa.item(), world=world.item(),
|
| 430 |
+
ebm=ebm.item(), bridge=bridge.item())
|
| 431 |
+
|
| 432 |
+
def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
|
| 433 |
+
return torch.optim.AdamW([
|
| 434 |
+
{"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
|
| 435 |
+
{"params": list(self.jepa.enc.parameters())
|
| 436 |
+
+ list(self.jepa.pred.parameters()), "lr": lr_aux},
|
| 437 |
+
{"params": list(self.world.parameters()), "lr": lr_aux},
|
| 438 |
+
{"params": list(self.energy.parameters()), "lr": lr_aux},
|
| 439 |
+
{"params": list(self.bridge.parameters())
|
| 440 |
+
+ list(self.emb.parameters()), "lr": lr_aux}])
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
# ============================================================================
|
| 444 |
+
# 8. Data helpers + demo
|
| 445 |
+
# ============================================================================
|
| 446 |
+
def toy_segment_sequences(B=8, T=4, L=24, vocab=1000):
|
| 447 |
+
"""Random docs split into T consecutive segments. Replace with real
|
| 448 |
+
corpus: tokenize each document, reshape into [T, L] windows."""
|
| 449 |
+
return torch.randint(0, vocab, (B, T, L))
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def hf_segment_sequences(llm: HFLLM, texts, T=4, L=64):
|
| 453 |
+
seqs = []
|
| 454 |
+
for t in texts:
|
| 455 |
+
ids = llm.tokenizer(t, return_tensors="pt").input_ids[0]
|
| 456 |
+
n = (len(ids) // (T * L)) * T * L
|
| 457 |
+
if n:
|
| 458 |
+
seqs.append(ids[:n].view(-1, T, L))
|
| 459 |
+
if not seqs:
|
| 460 |
+
raise ValueError("corpus too short for T*L window")
|
| 461 |
+
return torch.cat(seqs, 0)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def demo(spec="tiny", steps=60):
|
| 465 |
+
ens = WorldEnsemble(spec)
|
| 466 |
+
opt = ens.make_optimizer()
|
| 467 |
+
|
| 468 |
+
if spec == "tiny":
|
| 469 |
+
get_batch = lambda: toy_segment_sequences(vocab=ens.llm.vocab_size)
|
| 470 |
+
else:
|
| 471 |
+
corpus = ["Replace with your real documents. " * 200]
|
| 472 |
+
data = hf_segment_sequences(ens.llm, corpus, T=4, L=32)
|
| 473 |
+
get_batch = lambda: data[torch.randperm(len(data))[:4]]
|
| 474 |
+
steps = min(steps, 10)
|
| 475 |
+
|
| 476 |
+
t0 = time.time()
|
| 477 |
+
for s in range(steps):
|
| 478 |
+
logs = ens.train_step(get_batch(), opt,
|
| 479 |
+
hard_negs=(s > steps // 2)) # warmup w/o negs
|
| 480 |
+
if s % 10 == 0:
|
| 481 |
+
print(f"step {s:3d} | " +
|
| 482 |
+
" | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
|
| 483 |
+
print(f"trained {steps} steps in {time.time()-t0:.1f}s")
|
| 484 |
+
|
| 485 |
+
# memory + inference
|
| 486 |
+
for _ in range(4):
|
| 487 |
+
if spec == "tiny":
|
| 488 |
+
ens.memorize(torch.randint(0, ens.llm.vocab_size, (1, 24)))
|
| 489 |
+
q = (torch.randint(0, ens.llm.vocab_size, (1, 12)) if spec == "tiny"
|
| 490 |
+
else ens.llm.tokenizer("What is this document about?",
|
| 491 |
+
return_tensors="pt").input_ids)
|
| 492 |
+
res = ens.answer(q, n_drafts=6, horizon=3)
|
| 493 |
+
print(f"\nselected draft energy={res['energy']:.3f} "
|
| 494 |
+
f"(all: {[f'{e:.2f}' for e in res['all_energies']]})")
|
| 495 |
+
print(f"plan->output alignment: {res['plan_alignment']:.3f}")
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
if __name__ == "__main__":
|
| 499 |
+
demo(sys.argv[1] if len(sys.argv) > 1 else "tiny")
|
research/eval_harness.py
ADDED
|
@@ -0,0 +1,6 @@
|
| 1 |
+
"""Deprecated shim: use `ensemble.eval.jepa_harness` instead."""
|
| 2 |
+
|
| 3 |
+
from ensemble.eval.jepa_harness import run, parse_args
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
run(parse_args())
|
research/evals/README.md
ADDED
|
@@ -0,0 +1,98 @@
|
| 1 |
+
# SLM Agentic Benchmark Suite
|
| 2 |
+
|
| 3 |
+
A uv workspace package to evaluate **local HuggingFace models** against agentic and academic benchmarks.
|
| 4 |
+
|
| 5 |
+
**Docs:** [USAGE.md](USAGE.md) (commands and workflows) · [docs/benchmarks.md](docs/benchmarks.md) (per-benchmark reference) · [../USAGE.md](../USAGE.md) (full research tree)
|
| 6 |
+
|
| 7 |
+
| Suite | CLI | What it measures |
|
| 8 |
+
|---|---|---|
|
| 9 |
+
| **Agentic** | `slm-benchmark` | BFCL, τ-bench, GAIA, SWE-bench |
|
| 10 |
+
| **Academic** | `slm-lm-eval` | ARC, HellaSwag, GSM8K, … (lm-evaluation-harness) |
|
| 11 |
+
|
| 12 |
+
## Install
|
| 13 |
+
|
| 14 |
+
From the repo root:
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
uv sync --group evals
|
| 18 |
+
uv sync --group lm-eval # optional: slm-lm-eval academic benchmarks
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Quickstart
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
# From repo root (recommended)
|
| 25 |
+
uv run --package slm-evals slm-benchmark \
|
| 26 |
+
--model openbmb/MiniCPM5-1B \
|
| 27 |
+
--benchmarks bfcl tau_bench \
|
| 28 |
+
--max-samples 20
|
| 29 |
+
|
| 30 |
+
# Or as a module
|
| 31 |
+
uv run --package slm-evals python -m slm_evals.run_benchmark \
|
| 32 |
+
--model openbmb/MiniCPM5-1B \
|
| 33 |
+
--benchmarks bfcl tau_bench \
|
| 34 |
+
--max-samples 20
|
| 35 |
+
|
| 36 |
+
# YAML config
|
| 37 |
+
uv run --package slm-evals slm-benchmark \
|
| 38 |
+
--config research/evals/configs/experiment_001.yaml
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Project structure
|
| 42 |
+
|
| 43 |
+
```
|
| 44 |
+
research/evals/
|
| 45 |
+
├── pyproject.toml
|
| 46 |
+
├── configs/
|
| 47 |
+
│   └── experiment_001.yaml
|
| 48 |
+
├── src/slm_evals/
|
| 49 |
+
│   ├── run_benchmark.py
|
| 50 |
+
│   ├── benchmarks/
|
| 51 |
+
│   │   ├── base.py
|
| 52 |
+
│   │   ├── bfcl.py
|
| 53 |
+
│   │   ├── tau_bench.py
|
| 54 |
+
│   │   ├── gaia.py
|
| 55 |
+
│   │   └── swe_bench.py
|
| 56 |
+
│   └── utils/
|
| 57 |
+
│       ├── model_loader.py
|
| 58 |
+
│       ├── reporter.py
|
| 59 |
+
│       └── config_loader.py
|
| 60 |
+
└── results/ # created at runtime (relative to cwd)
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## CLI reference
|
| 64 |
+
|
| 65 |
+
```
|
| 66 |
+
--model Path to local HF model dir (or Hub ID)
|
| 67 |
+
--benchmarks Space-separated: bfcl tau_bench gaia swe_bench all
|
| 68 |
+
--config YAML config file (overrides CLI flags)
|
| 69 |
+
--max-samples Cap samples per benchmark
|
| 70 |
+
--output-dir Results directory (default: ./results)
|
| 71 |
+
--experiment-name Tag for this run
|
| 72 |
+
--device auto | cpu | cuda | cuda:0
|
| 73 |
+
--dtype float32 | float16 | bfloat16 | int8 | int4
|
| 74 |
+
--max-new-tokens Max tokens per generation (default: 512)
|
| 75 |
+
--temperature Sampling temp (default: 0.0 = greedy)
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## Adding a custom benchmark
|
| 79 |
+
|
| 80 |
+
1. Create `src/slm_evals/benchmarks/my_bench.py` and subclass `BaseBenchmark`.
|
| 81 |
+
2. Register it in `src/slm_evals/run_benchmark.py` → `BENCHMARK_REGISTRY`.
|
| 82 |
+
3. Run: `uv run --package slm-evals slm-benchmark --model ./my-model --benchmarks my_bench`
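
The three steps above might look roughly like the sketch below. The real `BaseBenchmark` interface lives in `benchmarks/base.py` and is not shown in this README, so the stand-in base class, its `load_samples` / `score` / `run` methods, and the name-to-class shape of `BENCHMARK_REGISTRY` are all assumptions made for illustration; check `base.py` for the actual contract before copying this.

```python
from abc import ABC, abstractmethod


class BaseBenchmark(ABC):
    """Hypothetical stand-in for slm_evals.benchmarks.base.BaseBenchmark."""

    name = "base"

    @abstractmethod
    def load_samples(self, max_samples=None):
        """Return a list of sample dicts."""

    @abstractmethod
    def score(self, sample, prediction):
        """Return a score in [0, 1] for one prediction."""

    def run(self, generate, max_samples=None):
        # `generate` is any callable mapping a prompt string to a completion.
        samples = self.load_samples(max_samples)
        scores = [self.score(s, generate(s["prompt"])) for s in samples]
        return {"benchmark": self.name, "n": len(scores),
                "accuracy": sum(scores) / max(len(scores), 1)}


class MyBench(BaseBenchmark):
    """Toy benchmark: exact match on two arithmetic prompts."""

    name = "my_bench"

    def load_samples(self, max_samples=None):
        data = [{"prompt": "2+2=", "answer": "4"},
                {"prompt": "3+3=", "answer": "6"}]
        return data[:max_samples]  # a slice end of None keeps everything

    def score(self, sample, prediction):
        return float(prediction.strip() == sample["answer"])


# Assumed registry shape: benchmark key -> benchmark class.
BENCHMARK_REGISTRY = {"my_bench": MyBench}

if __name__ == "__main__":
    bench = BENCHMARK_REGISTRY["my_bench"]()
    print(bench.run(lambda prompt: "4"))
```

Keeping scoring per sample and aggregation in the base class is one way to let new benchmarks stay small; the actual suite may split these responsibilities differently.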
|
| 83 |
+
|
| 84 |
+
## Output formats
|
| 85 |
+
|
| 86 |
+
Results are written under `<output-dir>/<experiment_name>/`:
|
| 87 |
+
|
| 88 |
+
- `results.json`: full structured dump
|
| 89 |
+
- `results.csv`: one row per benchmark
|
| 90 |
+
- `report.md`: human-readable summary
|
| 91 |
+
|
| 92 |
+
## Notes
|
| 93 |
+
|
| 94 |
+
**τ-bench user simulator**: Default is a lightweight rule-based simulator. Set `use_llm_user: true` in config for the GPT-4o user agent (API cost).
|
| 95 |
+
|
| 96 |
+
**SWE-bench full eval**: Set `full_eval: true` to run the official Docker harness (`pip install swebench docker`).
|
| 97 |
+
|
| 98 |
+
**GAIA tools**: Offline by default (`tool_mode: describe`). Wire real tools in `gaia.py` for live eval.
|
research/evals/USAGE.md
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evals usage
|
| 2 |
+
|
| 3 |
+
Run the **SLM Agentic Benchmark Suite** (`slm-evals`) against a local HuggingFace model directory or Hub id.
|
| 4 |
+
|
| 5 |
+
Benchmark details: [docs/benchmarks.md](docs/benchmarks.md). Package overview: [README.md](README.md).
|
| 6 |
+
|
| 7 |
+
## Install
|
| 8 |
+
|
| 9 |
+
From the repo root:
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
uv sync --group evals
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
For academic benchmarks (lm-evaluation-harness):
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
uv sync --group lm-eval
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
This installs the `slm-evals` workspace package and registers the `slm-benchmark` and `slm-lm-eval` console scripts.
|
| 22 |
+
|
| 23 |
+
## Quick start
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
# Two benchmarks, capped samples (good first run)
|
| 27 |
+
uv run --package slm-evals slm-benchmark \
|
| 28 |
+
--model openbmb/MiniCPM5-1B \
|
| 29 |
+
--benchmarks bfcl tau_bench \
|
| 30 |
+
--max-samples 20
|
| 31 |
+
|
| 32 |
+
# All four benchmarks
|
| 33 |
+
uv run --package slm-evals slm-benchmark \
|
| 34 |
+
--model ./models/finetuned/minicpm5-1b-lora \
|
| 35 |
+
--benchmarks all \
|
| 36 |
+
--max-samples 50
|
| 37 |
+
|
| 38 |
+
# Equivalent module invocation
|
| 39 |
+
uv run --package slm-evals python -m slm_evals.run_benchmark \
|
| 40 |
+
--model openbmb/MiniCPM5-1B \
|
| 41 |
+
--benchmarks bfcl \
|
| 42 |
+
--max-samples 10
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## Config-driven runs
|
| 46 |
+
|
| 47 |
+
Copy and edit the template, then pass `--config`:
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
cp research/evals/configs/experiment_001.yaml research/evals/configs/my_run.yaml
|
| 51 |
+
# edit model_path, benchmarks, max_samples, overrides
|
| 52 |
+
|
| 53 |
+
uv run --package slm-evals slm-benchmark \
|
| 54 |
+
--config research/evals/configs/my_run.yaml
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
When `--config` is set, **YAML values override CLI flags**. Use configs for reproducible experiment names and per-benchmark settings.
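
The precedence rule above can be sketched in a few lines of Python. This is only an illustration of "YAML wins over CLI"; `merge_settings` is a hypothetical helper, not the suite's actual config loader, and the already-parsed YAML is passed in as a plain dict so the example has no dependencies.

```python
from typing import Optional


def merge_settings(cli_args: dict, yaml_cfg: Optional[dict]) -> dict:
    """Return effective settings: CLI flags first, YAML keys override them."""
    merged = dict(cli_args)      # copy so the caller's dict is untouched
    if yaml_cfg:
        merged.update(yaml_cfg)  # any key present in the YAML wins
    return merged


if __name__ == "__main__":
    cli = {"model_path": "./my-model", "max_samples": 20, "temperature": 0.0}
    cfg = {"max_samples": 50, "experiment_name": "my_run"}
    print(merge_settings(cli, cfg))
```

In this sketch a flag such as `max_samples` passed on the command line is silently replaced when the YAML also sets it, which is why keeping everything in the config file is the safer habit for reproducible runs.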
|
| 58 |
+
|
| 59 |
+
### Template fields
|
| 60 |
+
|
| 61 |
+
| Key | Description |
|
| 62 |
+
| --- | ----------- |
|
| 63 |
+
| `model_path` | Local directory or HF Hub id |
|
| 64 |
+
| `device` | `auto`, `cpu`, `cuda`, `cuda:0`, … |
|
| 65 |
+
| `dtype` | `float32`, `float16`, `bfloat16`, `int8`, `int4` |
|
| 66 |
+
| `max_new_tokens` | Cap per generation (default 512) |
|
| 67 |
+
| `temperature` | `0.0` = greedy (recommended for evals) |
|
| 68 |
+
| `experiment_name` | Folder name under `output_dir` |
|
| 69 |
+
| `output_dir` | Root for results (default `results`) |
|
| 70 |
+
| `benchmarks` | List: `bfcl`, `tau_bench`, `gaia`, `swe_bench` |
|
| 71 |
+
| `max_samples` | Cap per benchmark; omit or `null` for full split |
|
| 72 |
+
| `benchmark_overrides` | Per-benchmark dict (see [docs/benchmarks.md](docs/benchmarks.md)) |
|
| 73 |
+
|
| 74 |
+
---
|
| 75 |
+
|
| 76 |
+
## CLI reference
|
| 77 |
+
|
| 78 |
+
```
|
| 79 |
+
slm-benchmark [OPTIONS]
|
| 80 |
+
|
| 81 |
+
--list-benchmarks Show agentic benchmark keys and preset suites
|
| 82 |
+
--model PATH Local HF dir or Hub id (required unless --config)
|
| 83 |
+
--benchmarks NAMES bfcl tau_bench gaia swe_bench all (default: all)
|
| 84 |
+
--config PATH YAML config (overrides other flags)
|
| 85 |
+
--max-samples N Cap samples per benchmark
|
| 86 |
+
--output-dir DIR Results root (default: ./results)
|
| 87 |
+
--experiment-name TAG Run folder name (auto timestamp if omitted)
|
| 88 |
+
--device MAP auto | cpu | cuda | cuda:0
|
| 89 |
+
--dtype TYPE float32 | float16 | bfloat16 | int8 | int4
|
| 90 |
+
--max-new-tokens N Default 512
|
| 91 |
+
--temperature T Default 0.0
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
---
|
| 95 |
+
|
| 96 |
+
## Results
|
| 97 |
+
|
| 98 |
+
Each run writes to `<output_dir>/<experiment_name>/`:
|
| 99 |
+
|
| 100 |
+
| File | Contents |
|
| 101 |
+
| ---- | -------- |
|
| 102 |
+
| `results.json` | Full structured payload (per-sample + aggregates) |
|
| 103 |
+
| `results.csv` | One row per benchmark |
|
| 104 |
+
| `report.md` | Human-readable summary |
|
| 105 |
+
|
| 106 |
+
Example layout:
|
| 107 |
+
|
| 108 |
+
```text
|
| 109 |
+
results/
|
| 110 |
+
└── minicpm5-1b__bfcl-tau__v1/
|
| 111 |
+
    ├── results.json
|
| 112 |
+
    ├── results.csv
|
| 113 |
+
    └── report.md
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
`output_dir` is resolved relative to the **current working directory**. Run from the repo root so paths stay predictable, or set an absolute `output_dir` in YAML.
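To make the path rule concrete, here is a small sketch of how a run directory could be resolved. It assumes the layout `<output_dir>/<experiment_name>/` described above; the helper name `resolve_run_dir` and its `cwd` parameter are hypothetical and exist only for this example.

```python
from pathlib import Path
from typing import Optional


def resolve_run_dir(output_dir: str, experiment_name: str, cwd: Optional[str] = None) -> Path:
    """Resolve <output_dir>/<experiment_name>, anchoring relative paths at cwd."""
    base = Path(output_dir)
    if not base.is_absolute():
        # Relative output_dir: anchor it at the working directory.
        base = (Path(cwd) if cwd is not None else Path.cwd()) / base
    return base / experiment_name


# Same relative output_dir, different working directories, different results.
print(resolve_run_dir("results", "minicpm5-1b__bfcl-tau__v1", cwd="/repo"))
print(resolve_run_dir("results", "minicpm5-1b__bfcl-tau__v1", cwd="/repo/research"))
# An absolute output_dir ignores the working directory.
print(resolve_run_dir("/data/results", "minicpm5-1b__bfcl-tau__v1", cwd="/repo"))
```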
|
| 117 |
+
|
| 118 |
+
---
|
| 119 |
+
|
| 120 |
+
## Per-benchmark tips
|
| 121 |
+
|
| 122 |
+
### BFCL (function calling)
|
| 123 |
+
|
| 124 |
+
- Default: downloads from `gorilla-llm/Berkeley-Function-Calling-Leaderboard`
|
| 125 |
+
- `strict: false` in YAML → fuzzy argument matching (better for small models)
|
| 126 |
+
- Local JSONL: set `benchmark_overrides.bfcl.data_path`
|
| 127 |
+
|
| 128 |
+
### τ-bench (multi-turn tools)
|
| 129 |
+
|
| 130 |
+
- Domains: `retail`, `airline`, or `both`
|
| 131 |
+
- `use_llm_user: false` → free rule-based user simulator (default)
|
| 132 |
+
- `use_llm_user: true` → GPT-4o user agent (**API cost**)
|
| 133 |
+
|
| 134 |
+
### GAIA
|
| 135 |
+
|
| 136 |
+
- Default split: `validation` (public)
|
| 137 |
+
- `tool_mode: describe` → offline tool descriptions (no live web)
|
| 138 |
+
- Level filter: `levels: [1, 2]` or `[1, 2, 3]`
|
| 139 |
+
|
| 140 |
+
### SWE-bench Verified
|
| 141 |
+
|
| 142 |
+
- Default: lightweight patch-generation scoring (no Docker)
|
| 143 |
+
- `full_eval: true` → official harness (`pip install swebench docker`)
|
| 144 |
+
|
| 145 |
+
See [docs/benchmarks.md](docs/benchmarks.md) for scoring semantics.
|
| 146 |
+
|
| 147 |
+
---
|
| 148 |
+
|
| 149 |
+
## lm-evaluation-harness (`slm-lm-eval`)
|
| 150 |
+
|
| 151 |
+
Run standard academic benchmarks (ARC, HellaSwag, PIQA, BoolQ, GSM8K) via [EleutherAI lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness).
|
| 152 |
+
|
| 153 |
+
Install: `uv sync --group lm-eval`
|
| 154 |
+
|
| 155 |
+
Full profile guide: [docs/eval_profiles.md](docs/eval_profiles.md)
|
| 156 |
+
|
| 157 |
+
### Discover profiles and tasks
|
| 158 |
+
|
| 159 |
+
```bash
|
| 160 |
+
# Claim-matched lm-eval profiles (reasoning, code, smoke, …)
|
| 161 |
+
uv run --package slm-evals slm-lm-eval --list-profiles
|
| 162 |
+
|
| 163 |
+
# Also show agentic suites + external benchmark notes
|
| 164 |
+
uv run --package slm-evals slm-lm-eval --list-profiles-all
|
| 165 |
+
|
| 166 |
+
# lm-eval task names
|
| 167 |
+
uv run --package slm-evals slm-lm-eval --list-tasks
|
| 168 |
+
|
| 169 |
+
# Agentic benchmarks (BFCL, τ-bench, GAIA, SWE)
|
| 170 |
+
uv run --package slm-evals slm-benchmark --list-benchmarks
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
### Quick start
|
| 174 |
+
|
| 175 |
+
```bash
|
| 176 |
+
# By profile name (recommended)
|
| 177 |
+
uv run --package slm-evals slm-lm-eval \
|
| 178 |
+
--profile reasoning \
|
| 179 |
+
--preset minicpm5-1b \
|
| 180 |
+
--experiment-name minicpm5-1b__reasoning-baseline
|
| 181 |
+
|
| 182 |
+
# Smoke profile (25 samples)
|
| 183 |
+
uv run --package slm-evals slm-lm-eval \
|
| 184 |
+
--profile smoke \
|
| 185 |
+
--preset minicpm5-1b \
|
| 186 |
+
--experiment-name minicpm5-1b__smoke
|
| 187 |
+
|
| 188 |
+
# LoRA adapter via preset (base + peft resolved automatically)
|
| 189 |
+
uv run --package slm-evals slm-lm-eval \
|
| 190 |
+
--config research/evals/configs/lm_eval_minicpm5.yaml \
|
| 191 |
+
--preset minicpm5-1b-lesson-lora \
|
| 192 |
+
--experiment-name minicpm5-1b-lora__v1
|
| 193 |
+
|
| 194 |
+
# Explicit base + adapter
|
| 195 |
+
uv run --package slm-evals slm-lm-eval \
|
| 196 |
+
--config research/evals/configs/lm_eval_smoke.yaml \
|
| 197 |
+
--model openbmb/MiniCPM5-1B \
|
| 198 |
+
--adapter ./models/finetuned/minicpm5-1b-lora \
|
| 199 |
+
--experiment-name minicpm5-1b-lora__manual
|
| 200 |
+
|
| 201 |
+
# Ensemble checkpoint (manifest.json auto-detected)
|
| 202 |
+
uv run --package slm-evals slm-lm-eval \
|
| 203 |
+
--config research/evals/configs/lm_eval_smoke.yaml \
|
| 204 |
+
--model ./models/ensemble/jepa-lesson-pretrain \
|
| 205 |
+
--experiment-name ensemble-jepa__lm-eval
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
### Compare baseline vs candidate
|
| 209 |
+
|
| 210 |
+
Use the **same config** for both runs; only change `--preset` / `--experiment-name`:
|
| 211 |
+
|
| 212 |
+
```bash
|
| 213 |
+
uv run --package slm-evals slm-lm-eval \
|
| 214 |
+
--config research/evals/configs/lm_eval_compare_study.yaml \
|
| 215 |
+
--preset minicpm5-1b \
|
| 216 |
+
--experiment-name minicpm5-1b__baseline
|
| 217 |
+
|
| 218 |
+
uv run --package slm-evals slm-lm-eval \
|
| 219 |
+
--config research/evals/configs/lm_eval_compare_study.yaml \
|
| 220 |
+
--preset minicpm5-1b-lesson-lora \
|
| 221 |
+
--experiment-name minicpm5-1b-lora__v1 \
|
| 222 |
+
--compare-to results/lm_eval/minicpm5-1b__baseline/results.json
|
| 223 |
+
```
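As a rough picture of what a baseline-vs-candidate delta table contains, the sketch below diffs two result payloads. It assumes the lm-evaluation-harness layout in which `results.json` holds a top-level `results` mapping of task name to metric values (keys such as `acc,none`); the function `metric_deltas` and the sample numbers are illustrative, not the `--compare-to` implementation.

```python
from typing import Any, Dict, List, Tuple

Row = Tuple[str, str, float, float, float]


def metric_deltas(baseline: Dict[str, Any], candidate: Dict[str, Any]) -> List[Row]:
    """Return (task, metric, baseline, candidate, delta) for metrics present in both runs."""
    rows: List[Row] = []
    base_results = baseline.get("results", {})
    for task, cand_metrics in candidate.get("results", {}).items():
        base_metrics = base_results.get(task, {})
        for name, cand_value in cand_metrics.items():
            base_value = base_metrics.get(name)
            # Skip standard-error entries and non-numeric fields such as "alias".
            if "stderr" in name:
                continue
            if not isinstance(cand_value, (int, float)) or not isinstance(base_value, (int, float)):
                continue
            rows.append((task, name, base_value, cand_value, cand_value - base_value))
    return rows


# Made-up payloads for illustration; real files would be read with json.load().
baseline = {"results": {"arc_easy": {"alias": "arc_easy", "acc,none": 0.5, "acc_stderr,none": 0.01}}}
candidate = {"results": {"arc_easy": {"alias": "arc_easy", "acc,none": 0.75, "acc_stderr,none": 0.01}}}
for task, metric, base, cand, delta in metric_deltas(baseline, candidate):
    print(f"{task:12s} {metric:10s} {base:.3f} -> {cand:.3f} ({delta:+.3f})")
```

A delta is only meaningful when both runs share the same tasks, `num_fewshot`, `limit`, and `seed`, which is the reason for reusing one config file.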
|
| 224 |
+
|
| 225 |
+
### Config templates
|
| 226 |
+
|
| 227 |
+
Catalog: `configs/eval_profiles.yaml`, which maps **claim → profile → tasks**.
|
| 228 |
+
|
| 229 |
+
| Profile (`--profile`) | Config file | Purpose |
|
| 230 |
+
| --------------------- | ----------- | ------- |
|
| 231 |
+
| `smoke` | `lm_eval_smoke.yaml` | Fast validation (`limit: 25`, 2 tasks) |
|
| 232 |
+
| `reasoning` | `lm_eval_reasoning.yaml` | Math + commonsense (GSM8K, ARC, HellaSwag) |
|
| 233 |
+
| `understanding` | `lm_eval_understanding.yaml` | NLU (BoolQ, PIQA, COPA, RTE) |
|
| 234 |
+
| `code` | `lm_eval_code.yaml` | HumanEval + MBPP |
|
| 235 |
+
| `instructions` | `lm_eval_instructions.yaml` | IFEval instruction following |
|
| 236 |
+
| `general_slm` | `lm_eval_minicpm5.yaml` | Full ~1B SLM profile (6 tasks) |
|
| 237 |
+
| `compare_study` | `lm_eval_compare_study.yaml` | Baseline vs finetune comparison defaults |
|
| 238 |
+
|
| 239 |
+
| Key | Description |
|
| 240 |
+
| --- | ----------- |
|
| 241 |
+
| `tasks` | lm-eval task names (e.g. `arc_easy`, `gsm8k`) |
|
| 242 |
+
| `num_fewshot` | Few-shot count (gsm8k may use task default 8) |
|
| 243 |
+
| `limit` | Max samples per task; `null` = full split |
|
| 244 |
+
| `seed` | Random seed (applied to all lm-eval RNGs) |
|
| 245 |
+
| `batch_size` | `auto` or integer |
|
| 246 |
+
| `device` | `auto`, `cpu`, `cuda`, … |
|
| 247 |
+
| `dtype` | `bfloat16`, `float16`, `int4`, … |
|
| 248 |
+
| `trust_remote_code` | Required for MiniCPM / Gemma presets |
|
| 249 |
+
| `output_dir` | Root for runs (default `results/lm_eval`) |
|
| 250 |
+
|
| 251 |
+
### CLI reference
|
| 252 |
+
|
| 253 |
+
```
|
| 254 |
+
slm-lm-eval [OPTIONS]
|
| 255 |
+
|
| 256 |
+
--list-profiles Show claim-matched profiles and example commands
|
| 257 |
+
--list-profiles-all Include agentic suites and external benchmark notes
|
| 258 |
+
--list-tasks List lm-eval task names (catalog fallback if not installed)
|
| 259 |
+
--list-tasks-all Full lm-eval task list
|
| 260 |
+
--profile NAME Shorthand for --config (reasoning, code, smoke, …)
|
| 261 |
+
--config PATH YAML config (tasks, seed, limit, …)
|
| 262 |
+
--preset KEY models.yaml preset (base, LoRA, merged, ensemble)
|
| 263 |
+
--model PATH HF Hub id, merged dir, or ensemble checkpoint
|
| 264 |
+
--adapter PATH LoRA adapter (alternative to preset adapter_path)
|
| 265 |
+
--tasks NAMES Override task list
|
| 266 |
+
--num-fewshot N
|
| 267 |
+
--limit N Cap samples per task
|
| 268 |
+
--seed N
|
| 269 |
+
--batch-size VALUE
|
| 270 |
+
--device MAP
|
| 271 |
+
--dtype TYPE
|
| 272 |
+
--output-dir DIR Default: results/lm_eval
|
| 273 |
+
--experiment-name TAG Run folder name
|
| 274 |
+
--compare-to PATH Baseline results.json for delta table
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
### Results
|
| 278 |
+
|
| 279 |
+
Each run writes to `<output_dir>/<experiment_name>/`:
|
| 280 |
+
|
| 281 |
+
| File | Contents |
|
| 282 |
+
| ---- | -------- |
|
| 283 |
+
| `results.json` | lm-eval native payload + `run_meta` |
|
| 284 |
+
| `summary.md` | Task → metric table |
|
| 285 |
+
| `run_meta.json` | Preset, base model, adapter, tasks, seed |
|
| 286 |
+
| `comparison.md` | Delta table (when `--compare-to` set) |
|
| 287 |
+
|
| 288 |
+
### Ensemble backend notes
|
| 289 |
+
|
| 290 |
+
- **`ensemble-lm`** loads JEPA checkpoints via `manifest.json`.
|
| 291 |
+
- **`generate_until`** tasks (e.g. `gsm8k`) use the full ensemble stack (`generate_text`).
|
| 292 |
+
- **`loglikelihood`** tasks (e.g. `arc_easy`, `hellaswag`) score the underlying HF LLM head (adapter 0), not the JEPA selector. Use [`jepa_harness`](../ensemble/README.md) to measure selector value on domain QA.
|
| 293 |
+
|
| 294 |
+
### PEFT / LoRA
|
| 295 |
+
|
| 296 |
+
lm-eval expects `pretrained=<base>,peft=<adapter>`. The preset resolver handles this for keys like `minicpm5-1b-lesson-lora`. Merged checkpoints use `--preset minicpm5-1b-lesson-merged` or `--model ./models/finetuned/...-merged`.
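The `pretrained=<base>,peft=<adapter>` convention is a comma-separated `key=value` string. The sketch below shows one way such a string could be assembled; `build_model_args` is a hypothetical helper, not the preset resolver from this repo, and the extra `dtype` and `trust_remote_code` keys are included on the assumption that they are forwarded to the Hugging Face backend.

```python
from typing import Optional


def build_model_args(
    base: str,
    adapter: Optional[str] = None,
    dtype: Optional[str] = None,
    trust_remote_code: bool = False,
) -> str:
    """Build a comma-separated model_args string for the lm-eval HF backend."""
    parts = [f"pretrained={base}"]
    if adapter:
        # LoRA adapter directory, loaded on top of the base model via PEFT.
        parts.append(f"peft={adapter}")
    if dtype:
        parts.append(f"dtype={dtype}")
    if trust_remote_code:
        parts.append("trust_remote_code=True")
    return ",".join(parts)


print(build_model_args("openbmb/MiniCPM5-1B", adapter="./models/finetuned/minicpm5-1b-lora"))
print(build_model_args("openbmb/MiniCPM5-1B", dtype="bfloat16", trust_remote_code=True))
```

A merged checkpoint needs no `peft=` entry: it is passed as the `pretrained` value on its own.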
|
| 297 |
+
|
| 298 |
+
---
|
| 299 |
+
|
| 300 |
+
## Adding a custom benchmark
|
| 301 |
+
|
| 302 |
+
1. Create `src/slm_evals/benchmarks/my_bench.py` subclassing `BaseBenchmark`:
|
| 303 |
+
   - `load_dataset()` → list of sample dicts
|
| 304 |
+
   - `build_prompt(sample)` → prompt string
|
| 305 |
+
   - `evaluate_sample(sample, prediction)` → `{passed, score, note}`
|
| 306 |
+
|
| 307 |
+
2. Register in `src/slm_evals/run_benchmark.py` → `BENCHMARK_REGISTRY`.
|
| 308 |
+
|
| 309 |
+
3. Run:
|
| 310 |
+
```bash
|
| 311 |
+
uv run --package slm-evals slm-benchmark \
|
| 312 |
+
--model ./my-model --benchmarks my_bench --max-samples 10
|
| 313 |
+
```
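The three methods listed in step 1 could look roughly like the sketch below. The real `BaseBenchmark` lives in `slm_evals` and its exact signatures are not shown in this guide, so the example defines a stand-in abstract class with the same three method names. Treat the class body, the toy dataset, and the exact-match scoring as assumptions for illustration.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseBenchmark(ABC):
    """Stand-in for slm_evals' BaseBenchmark (assumed interface, not the real class)."""

    @abstractmethod
    def load_dataset(self) -> List[Dict[str, Any]]: ...

    @abstractmethod
    def build_prompt(self, sample: Dict[str, Any]) -> str: ...

    @abstractmethod
    def evaluate_sample(self, sample: Dict[str, Any], prediction: str) -> Dict[str, Any]: ...


class MyBench(BaseBenchmark):
    """Toy arithmetic benchmark scored by exact match."""

    def load_dataset(self) -> List[Dict[str, Any]]:
        # A real benchmark would read a JSONL file or a Hub dataset here.
        return [
            {"question": "2 + 2 = ?", "answer": "4"},
            {"question": "3 * 5 = ?", "answer": "15"},
        ]

    def build_prompt(self, sample: Dict[str, Any]) -> str:
        return f"Answer with a single number.\nQuestion: {sample['question']}\nAnswer:"

    def evaluate_sample(self, sample: Dict[str, Any], prediction: str) -> Dict[str, Any]:
        passed = prediction.strip() == sample["answer"]
        note = "" if passed else f"expected {sample['answer']!r}"
        return {"passed": passed, "score": 1.0 if passed else 0.0, "note": note}


bench = MyBench()
first = bench.load_dataset()[0]
print(bench.build_prompt(first))
print(bench.evaluate_sample(first, " 4 "))
```

Small models often wrap the answer in extra text, so a production scorer would likely normalize or parse the prediction rather than rely on strict equality.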
|
| 314 |
+
|
| 315 |
+
---
|
| 316 |
+
|
| 317 |
+
## Suggested workflows
|
| 318 |
+
|
| 319 |
+
### Smoke (CPU/GPU, ~5 min)
|
| 320 |
+
|
| 321 |
+
```bash
|
| 322 |
+
uv run --package slm-evals slm-benchmark \
|
| 323 |
+
--model openbmb/MiniCPM5-1B \
|
| 324 |
+
--benchmarks bfcl \
|
| 325 |
+
--max-samples 5 \
|
| 326 |
+
--device cpu
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
### Before / after fine-tune
|
| 330 |
+
|
| 331 |
+
```bash
|
| 332 |
+
BASE=openbmb/MiniCPM5-1B
|
| 333 |
+
ADAPTER=./models/finetuned/minicpm5-1b-lora
|
| 334 |
+
|
| 335 |
+
for M in "$BASE" "$ADAPTER"; do
|
| 336 |
+
uv run --package slm-evals slm-benchmark \
|
| 337 |
+
--model "$M" \
|
| 338 |
+
--benchmarks bfcl tau_bench \
|
| 339 |
+
--max-samples 100 \
|
| 340 |
+
--experiment-name "$(basename "$M")__bfcl-tau"
|
| 341 |
+
done
|
| 342 |
+
```
|
| 343 |
+
|
| 344 |
+
### Full experiment (YAML)
|
| 345 |
+
|
| 346 |
+
Edit `configs/experiment_001.yaml` with your `model_path` and `experiment_name`, then:
|
| 347 |
+
|
| 348 |
+
```bash
|
| 349 |
+
uv run --package slm-evals slm-benchmark \
|
| 350 |
+
--config research/evals/configs/experiment_001.yaml
|
| 351 |
+
```
|
| 352 |
+
|
| 353 |
+
---
|
| 354 |
+
|
| 355 |
+
## Troubleshooting
|
| 356 |
+
|
| 357 |
+
| Symptom | Likely cause | Fix |
|
| 358 |
+
| ------- | ------------ | --- |
|
| 359 |
+
| `error: --model is required` | No `--config` and no `--model` | Pass one of them |
|
| 360 |
+
| CUDA OOM | Model too large for VRAM | `--dtype int4` or `--device cpu` |
|
| 361 |
+
| HF dataset 401 on GAIA test | Gated split | Use `split: validation` |
|
| 362 |
+
| τ-bench hangs / costs | LLM user enabled | Set `use_llm_user: false` |
|
| 363 |
+
| Empty `results/` | Wrong cwd | Run from repo root or use absolute `output_dir` |
|
| 364 |
+
| Import errors | Evals group not synced | `uv sync --group evals` |
|
| 365 |
+
|
| 366 |
+
---
|
| 367 |
+
|
| 368 |
+
## Entry points
|
| 369 |
+
|
| 370 |
+
| Path | Role |
|
| 371 |
+
| ---- | ---- |
|
| 372 |
+
| `slm-benchmark` | Agentic benchmarks (BFCL, τ-bench, GAIA, SWE) |
|
| 373 |
+
| `slm-lm-eval` | Academic benchmarks via lm-evaluation-harness |
|
| 374 |
+
| `python -m slm_evals.run_benchmark` | Same as `slm-benchmark` |
|
| 375 |
+
| `python -m slm_evals.run_lm_eval` | Same as `slm-lm-eval` |
|
| 376 |
+
| `research/evals/run_benchmark.py` | Thin shim for backward compatibility |
|
research/evals/configs/ensemble_jepa_lesson.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
| 1 |
+
# JEPA ensemble checkpoint (models/ensemble/jepa-lesson-pretrain)
|
| 2 |
+
# Pretrain: uv run --package ensemble ensemble-pretrain --llm Qwen/Qwen2.5-0.5B-Instruct
|
| 3 |
+
# Compare baseline: copy this file, set model_path to the base Hub id and model_type: hf
|
| 4 |
+
|
| 5 |
+
model_path: "./models/ensemble/jepa-lesson-pretrain"
|
| 6 |
+
model_type: "ensemble"
|
| 7 |
+
device: "auto"
|
| 8 |
+
dtype: "bfloat16"
|
| 9 |
+
|
| 10 |
+
max_new_tokens: 512
|
| 11 |
+
temperature: 0.0
|
| 12 |
+
|
| 13 |
+
experiment_name: "jepa-ensemble-lesson__bfcl-tau__v1"
|
| 14 |
+
output_dir: "results"
|
| 15 |
+
|
| 16 |
+
benchmarks:
|
| 17 |
+
- bfcl
|
| 18 |
+
- tau_bench
|
| 19 |
+
|
| 20 |
+
max_samples: 20
|
| 21 |
+
|
| 22 |
+
benchmark_overrides:
|
| 23 |
+
tau_bench:
|
| 24 |
+
use_llm_user: false
|
research/evals/configs/eval_profiles.yaml
ADDED
|
@@ -0,0 +1,160 @@
|
| 1 |
+
# Eval profile catalog: match your model claim to a profile, then run with:
|
| 2 |
+
# slm-lm-eval --list-profiles
|
| 3 |
+
# slm-lm-eval --profile reasoning --preset minicpm5-1b --experiment-name my-run
|
| 4 |
+
#
|
| 5 |
+
# See research/evals/docs/eval_profiles.md for full guide.
|
| 6 |
+
|
| 7 |
+
profiles:
|
| 8 |
+
smoke:
|
| 9 |
+
tool: slm-lm-eval
|
| 10 |
+
claim: Quick sanity check
|
| 11 |
+
description: Fast validation before a long run (2 tasks, capped samples).
|
| 12 |
+
config: lm_eval_smoke.yaml
|
| 13 |
+
tasks:
|
| 14 |
+
- arc_easy
|
| 15 |
+
- hellaswag
|
| 16 |
+
|
| 17 |
+
reasoning:
|
| 18 |
+
tool: slm-lm-eval
|
| 19 |
+
claim: Better reasoning
|
| 20 |
+
description: Math + commonsense + science QA (generation + multiple-choice).
|
| 21 |
+
config: lm_eval_reasoning.yaml
|
| 22 |
+
tasks:
|
| 23 |
+
- gsm8k
|
| 24 |
+
- arc_easy
|
| 25 |
+
- arc_challenge
|
| 26 |
+
- hellaswag
|
| 27 |
+
|
| 28 |
+
understanding:
|
| 29 |
+
tool: slm-lm-eval
|
| 30 |
+
claim: Better language understanding
|
| 31 |
+
description: NLU / reading comprehension (SuperGLUE-style multiple-choice).
|
| 32 |
+
config: lm_eval_understanding.yaml
|
| 33 |
+
tasks:
|
| 34 |
+
- boolq
|
| 35 |
+
- piqa
|
| 36 |
+
- copa
|
| 37 |
+
- rte
|
| 38 |
+
|
| 39 |
+
code:
|
| 40 |
+
tool: slm-lm-eval
|
| 41 |
+
claim: Better code generation
|
| 42 |
+
description: Python function synthesis (HumanEval + MBPP via lm-eval).
|
| 43 |
+
config: lm_eval_code.yaml
|
| 44 |
+
tasks:
|
| 45 |
+
- humaneval
|
| 46 |
+
- mbpp
|
| 47 |
+
|
| 48 |
+
instructions:
|
| 49 |
+
tool: slm-lm-eval
|
| 50 |
+
claim: Better instruction following
|
| 51 |
+
description: Verifiable instruction constraints (IFEval).
|
| 52 |
+
config: lm_eval_instructions.yaml
|
| 53 |
+
tasks:
|
| 54 |
+
- ifeval
|
| 55 |
+
|
| 56 |
+
general_slm:
|
| 57 |
+
tool: slm-lm-eval
|
| 58 |
+
claim: General ~1B SLM baseline
|
| 59 |
+
description: Balanced academic mix for before/after fine-tune on chat data.
|
| 60 |
+
config: lm_eval_minicpm5.yaml
|
| 61 |
+
tasks:
|
| 62 |
+
- arc_easy
|
| 63 |
+
- arc_challenge
|
| 64 |
+
- hellaswag
|
| 65 |
+
- piqa
|
| 66 |
+
- boolq
|
| 67 |
+
- gsm8k
|
| 68 |
+
|
| 69 |
+
compare_study:
|
| 70 |
+
tool: slm-lm-eval
|
| 71 |
+
claim: Baseline vs finetune comparison
|
| 72 |
+
description: Same tasks as general_slm with limit 100 for paired studies.
|
| 73 |
+
config: lm_eval_compare_study.yaml
|
| 74 |
+
tasks:
|
| 75 |
+
- arc_easy
|
| 76 |
+
- arc_challenge
|
| 77 |
+
- hellaswag
|
| 78 |
+
- piqa
|
| 79 |
+
- boolq
|
| 80 |
+
- gsm8k
|
| 81 |
+
|
| 82 |
+
suites:
|
| 83 |
+
agentic_tool_use:
|
| 84 |
+
tool: slm-benchmark
|
| 85 |
+
claim: Tool use and function calling
|
| 86 |
+
description: BFCL + tau-bench for agents that call tools.
|
| 87 |
+
command: >-
|
| 88 |
+
uv run --package slm-evals slm-benchmark
|
| 89 |
+
--model <model> --benchmarks bfcl tau_bench --max-samples 50
|
| 90 |
+
benchmarks:
|
| 91 |
+
- bfcl
|
| 92 |
+
- tau_bench
|
| 93 |
+
|
| 94 |
+
agentic_gaia:
|
| 95 |
+
tool: slm-benchmark
|
| 96 |
+
claim: End-to-end assistant tasks
|
| 97 |
+
description: GAIA - multi-step reasoning with optional tools.
|
| 98 |
+
command: >-
|
| 99 |
+
uv run --package slm-evals slm-benchmark
|
| 100 |
+
--model <model> --benchmarks gaia --max-samples 20
|
| 101 |
+
benchmarks:
|
| 102 |
+
- gaia
|
| 103 |
+
|
| 104 |
+
agentic_code:
|
| 105 |
+
tool: slm-benchmark
|
| 106 |
+
claim: Real-world code repair
|
| 107 |
+
description: SWE-bench Verified patch generation (lightweight mode by default).
|
| 108 |
+
command: >-
|
| 109 |
+
uv run --package slm-evals slm-benchmark
|
| 110 |
+
--model <model> --benchmarks swe_bench --max-samples 10
|
| 111 |
+
benchmarks:
|
| 112 |
+
- swe_bench
|
| 113 |
+
|
| 114 |
+
agentic_all:
|
| 115 |
+
tool: slm-benchmark
|
| 116 |
+
claim: Full agentic suite
|
| 117 |
+
description: All four slm-benchmark benchmarks.
|
| 118 |
+
command: >-
|
| 119 |
+
uv run --package slm-evals slm-benchmark
|
| 120 |
+
--model <model> --benchmarks all --max-samples 50
|
| 121 |
+
benchmarks:
|
| 122 |
+
- bfcl
|
| 123 |
+
- tau_bench
|
| 124 |
+
- gaia
|
| 125 |
+
- swe_bench
|
| 126 |
+
|
| 127 |
+
jepa_selector:
|
| 128 |
+
tool: jepa_harness
|
| 129 |
+
claim: JEPA / draft selector quality
|
| 130 |
+
description: >-
|
| 131 |
+
Domain QA on benchmark-qa.jsonl; compares first vs random vs JEPA vs oracle
|
| 132 |
+
draft selection. Use for ensemble checkpoints, not base LLM alone.
|
| 133 |
+
command: >-
|
| 134 |
+
uv run --package ensemble python -m ensemble.eval.jepa_harness
|
| 135 |
+
--checkpoint ./models/ensemble/<name> --data research/data/benchmark-qa.jsonl
|
| 136 |
+
|
| 137 |
+
world_model:
|
| 138 |
+
tool: world_harness
|
| 139 |
+
claim: World model / energy-based planning
|
| 140 |
+
description: Energy-ranked draft selection on domain QA (world-model track).
|
| 141 |
+
command: >-
|
| 142 |
+
uv run --package ensemble python -m ensemble.eval.world_harness
|
| 143 |
+
--checkpoint ./models/ensemble/<name> --data research/data/benchmark-qa.jsonl
|
| 144 |
+
|
| 145 |
+
external:
|
| 146 |
+
embeddings_mteb:
|
| 147 |
+
tool: mteb
|
| 148 |
+
claim: Better embeddings
|
| 149 |
+
description: MTEB, BEIR, STS - not wired in this repo; use embeddings-benchmark/mteb.
|
| 150 |
+
tasks_note: 56+ embedding tasks; requires a dedicated encoder model.
|
| 151 |
+
|
| 152 |
+
chat_judge:
|
| 153 |
+
tool: mt_bench_alpacaeval
|
| 154 |
+
claim: Chat / instruction quality (judge-based)
|
| 155 |
+
description: MT-Bench, AlpacaEval - require LLM-as-judge APIs; not in slm-lm-eval.
|
| 156 |
+
|
| 157 |
+
rl_envs:
|
| 158 |
+
tool: babyai_minigrid
|
| 159 |
+
claim: Embodied planning
|
| 160 |
+
description: BabyAI, MiniGrid - separate RL env evals; not integrated here.
|
research/evals/configs/experiment_001.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
| 1 |
+
# configs/experiment_001.yaml
|
| 2 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 3 |
+
# Template config for a local SLM agentic benchmark run.
|
| 4 |
+
# Copy and edit this file for each new experiment.
|
| 5 |
+
# Run with: uv run --package slm-evals slm-benchmark --config research/evals/configs/experiment_001.yaml
|
| 6 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 7 |
+
|
| 8 |
+
# ── Model ────────────────────────────────────────────────────────────────────
|
| 9 |
+
model_path: "./models/my-small-model-7b" # local directory or HF Hub ID
|
| 10 |
+
device: "auto" # auto | cpu | cuda | cuda:0
|
| 11 |
+
dtype: "bfloat16" # float32 | float16 | bfloat16 | int8 | int4
|
| 12 |
+
|
| 13 |
+
# ── Inference ────────────────────────────────────────────────────────────────
|
| 14 |
+
max_new_tokens: 512
|
| 15 |
+
temperature: 0.0 # 0 = greedy (recommended for deterministic evals)
|
| 16 |
+
|
| 17 |
+
# ── Experiment ───────────────────────────────────────────────────────────────
|
| 18 |
+
experiment_name: "my-model-7b__bfcl-tau-gaia-swe__v1"
|
| 19 |
+
output_dir: "results"
|
| 20 |
+
|
| 21 |
+
# ── Benchmarks ───────────────────────────────────────────────────────────────
|
| 22 |
+
benchmarks:
|
| 23 |
+
- bfcl
|
| 24 |
+
- tau_bench
|
| 25 |
+
- gaia
|
| 26 |
+
- swe_bench
|
| 27 |
+
|
| 28 |
+
max_samples: 100 # null = run all; set a number for quick smoke tests
|
| 29 |
+
|
| 30 |
+
# ── Per-benchmark overrides ──────────────────────────────────────────────────
|
| 31 |
+
benchmark_overrides:
|
| 32 |
+
|
| 33 |
+
bfcl:
|
| 34 |
+
# data_path: "./data/bfcl_v4_local.jsonl" # uncomment to use local copy
|
| 35 |
+
strict: false # false = fuzzy arg matching (recommended for SLMs)
|
| 36 |
+
categories: [] # [] = all categories
|
| 37 |
+
|
| 38 |
+
tau_bench:
|
| 39 |
+
# data_path: "./data/tau_bench_retail.jsonl"
|
| 40 |
+
domain: "retail" # retail | airline | both
|
| 41 |
+
max_turns: 15
|
| 42 |
+
use_llm_user: false # true = use GPT-4o as user simulator (costs $)
|
| 43 |
+
|
| 44 |
+
gaia:
|
| 45 |
+
# data_path: "./data/gaia_validation.jsonl"
|
| 46 |
+
split: "validation" # validation (public) | test (needs HF auth)
|
| 47 |
+
levels: [1, 2] # [1] easy only, [1,2,3] all levels
|
| 48 |
+
tool_mode: "describe" # describe (offline) | none
|
| 49 |
+
|
| 50 |
+
swe_bench:
|
| 51 |
+
# data_path: "./data/swe_bench_verified.jsonl"
|
| 52 |
+
full_eval: false # true = run Docker harness (requires swebench pkg)
|
| 53 |
+
context_lines: 80
|
research/evals/configs/lm_eval_code.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
| 1 |
+
# Code generation profile: HumanEval + MBPP
|
| 2 |
+
# Run: slm-lm-eval --profile code --preset minicpm5-1b-lesson-lora
|
| 3 |
+
# Note: small models often score low; use --limit 25 for smoke tests.
|
| 4 |
+
|
| 5 |
+
profile: code
|
| 6 |
+
claim: Better code generation
|
| 7 |
+
|
| 8 |
+
tasks:
|
| 9 |
+
- humaneval
|
| 10 |
+
- mbpp
|
| 11 |
+
|
| 12 |
+
num_fewshot: 0
|
| 13 |
+
limit: 50
|
| 14 |
+
seed: 42
|
| 15 |
+
batch_size: auto
|
| 16 |
+
device: auto
|
| 17 |
+
dtype: bfloat16
|
| 18 |
+
trust_remote_code: true
|
| 19 |
+
output_dir: results/lm_eval
|
research/evals/configs/lm_eval_compare_study.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
| 1 |
+
# configs/lm_eval_compare_study.yaml
|
| 2 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 3 |
+
# Shared settings for baseline vs finetuned comparisons.
|
| 4 |
+
# Use the SAME config for both runs; only change --preset / --experiment-name.
|
| 5 |
+
#
|
| 6 |
+
# Baseline:
|
| 7 |
+
# uv run --package slm-evals slm-lm-eval \
|
| 8 |
+
# --config research/evals/configs/lm_eval_compare_study.yaml \
|
| 9 |
+
# --preset minicpm5-1b \
|
| 10 |
+
# --experiment-name minicpm5-1b__baseline
|
| 11 |
+
#
|
| 12 |
+
# Candidate (after finetune):
|
| 13 |
+
# uv run --package slm-evals slm-lm-eval \
|
| 14 |
+
# --config research/evals/configs/lm_eval_compare_study.yaml \
|
| 15 |
+
# --preset minicpm5-1b-lesson-lora \
|
| 16 |
+
# --experiment-name minicpm5-1b-lora__v1 \
|
| 17 |
+
# --compare-to results/lm_eval/minicpm5-1b__baseline/results.json
|
| 18 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 19 |
+
|
| 20 |
+
study:
|
| 21 |
+
baseline_preset: minicpm5-1b
|
| 22 |
+
candidate_preset: minicpm5-1b-lesson-lora
|
| 23 |
+
notes: >
|
| 24 |
+
Keep tasks, num_fewshot, limit, and seed identical across runs.
|
| 25 |
+
Do not compare training_results.json result_score to lm-eval accuracy.
|
| 26 |
+
|
| 27 |
+
tasks:
|
| 28 |
+
- arc_easy
|
| 29 |
+
- arc_challenge
|
| 30 |
+
- hellaswag
|
| 31 |
+
- piqa
|
| 32 |
+
- boolq
|
| 33 |
+
- gsm8k
|
| 34 |
+
|
| 35 |
+
num_fewshot: 5
|
| 36 |
+
limit: 100
|
| 37 |
+
seed: 42
|
| 38 |
+
batch_size: auto
|
| 39 |
+
device: auto
|
| 40 |
+
dtype: bfloat16
|
| 41 |
+
trust_remote_code: true
|
| 42 |
+
output_dir: results/lm_eval
|
research/evals/configs/lm_eval_instructions.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
| 1 |
+
# Instruction following profile: IFEval (verifiable constraints)
|
| 2 |
+
# Run: slm-lm-eval --profile instructions --preset minicpm5-1b
|
| 3 |
+
# Requires lm-eval extras; install with: uv sync --group lm-eval
|
| 4 |
+
|
| 5 |
+
profile: instructions
|
| 6 |
+
claim: Better instruction following
|
| 7 |
+
|
| 8 |
+
tasks:
|
| 9 |
+
- ifeval
|
| 10 |
+
|
| 11 |
+
num_fewshot: 0
|
| 12 |
+
limit: 100
|
| 13 |
+
seed: 42
|
| 14 |
+
batch_size: auto
|
| 15 |
+
device: auto
|
| 16 |
+
dtype: bfloat16
|
| 17 |
+
trust_remote_code: true
|
| 18 |
+
output_dir: results/lm_eval
|
research/evals/configs/lm_eval_minicpm5.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
| 1 |
+
# configs/lm_eval_minicpm5.yaml
|
| 2 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 3 |
+
# Full academic benchmark profile for ~1B SLMs (MiniCPM5, Gemma E2B, etc.).
|
| 4 |
+
# Run with:
|
| 5 |
+
# uv run --package slm-evals slm-lm-eval \
|
| 6 |
+
# --config research/evals/configs/lm_eval_minicpm5.yaml \
|
| 7 |
+
# --preset minicpm5-1b \
|
| 8 |
+
# --experiment-name minicpm5-1b__baseline
|
| 9 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 10 |
+
|
| 11 |
+
tasks:
|
| 12 |
+
- arc_easy
|
| 13 |
+
- arc_challenge
|
| 14 |
+
- hellaswag
|
| 15 |
+
- piqa
|
| 16 |
+
- boolq
|
| 17 |
+
- gsm8k
|
| 18 |
+
|
| 19 |
+
num_fewshot: 5
|
| 20 |
+
limit: null
|
| 21 |
+
seed: 42
|
| 22 |
+
batch_size: auto
|
| 23 |
+
device: auto
|
| 24 |
+
dtype: bfloat16
|
| 25 |
+
trust_remote_code: true
|
| 26 |
+
output_dir: results/lm_eval
|
research/evals/configs/lm_eval_reasoning.yaml
ADDED
|
@@ -0,0 +1,20 @@
| 1 |
+
# Reasoning profile: math + commonsense + science QA
|
| 2 |
+
# Run: slm-lm-eval --profile reasoning --preset minicpm5-1b --experiment-name reasoning-baseline
|
| 3 |
+
|
| 4 |
+
profile: reasoning
|
| 5 |
+
claim: Better reasoning
|
| 6 |
+
|
| 7 |
+
tasks:
|
| 8 |
+
- gsm8k
|
| 9 |
+
- arc_easy
|
| 10 |
+
- arc_challenge
|
| 11 |
+
- hellaswag
|
| 12 |
+
|
| 13 |
+
num_fewshot: 5
|
| 14 |
+
limit: 100
|
| 15 |
+
seed: 42
|
| 16 |
+
batch_size: auto
|
| 17 |
+
device: auto
|
| 18 |
+
dtype: bfloat16
|
| 19 |
+
trust_remote_code: true
|
| 20 |
+
output_dir: results/lm_eval
|