Spaces:
Running on Zero
Running on Zero
| """ | |
| online.py -- persistent finetuning of the boss from real player fights. | |
| Flow: the browser logs every boss decision (state, action, HP) during a fight and, | |
| on fight end, POSTs the trajectory + who-died to /learn. We buffer trajectories and, | |
| every MM_UPDATE_EVERY fights, run one REINFORCE step (mm_grad.OnlineLearner) that | |
| nudges the HARD brain toward what actually worked against humans. The adapted | |
| weights feed straight back into the live boss. | |
| ONLY HARD-tier fights are used, so the data stays on-policy (Easy/Normal are | |
| deliberately handicapped checkpoints; learning from them would be off-policy). | |
| Persistence (optional, set as Space secrets): | |
| HF_TOKEN - a write token | |
| MM_DATASET_REPO - e.g. "your-name/boss-fight-online" | |
| If set, adapted weights are pushed to / pulled from that dataset so learning | |
| survives Space restarts. Without them it still adapts live, just in-memory. | |
| Safety: a frozen copy of the sim-trained weights is kept as an anchor (the learner | |
| pulls gently back toward it), so a weird run can't brick the boss. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import threading | |
| import numpy as np | |
| from mm_grad import OnlineLearner | |
# Directory holding this module; both weight files live next to it.
HERE = os.path.dirname(os.path.abspath(__file__))
BASE_WEIGHTS = os.path.join(HERE, "mm_weights.npz")       # sim-trained HARD brain
LIVE_WEIGHTS = os.path.join(HERE, "mm_weights_live.npz")  # player-adapted snapshot

# Online learning is on unless MM_ONLINE is set to something other than "1".
ENABLED = os.environ.get("MM_ONLINE", "1") == "1"

# Number of buffered fights that triggers one learner update.
# A malformed value must not crash the module at import time, so fall back to
# the default. Values below 1 are clamped to 1 (update after every fight),
# which is how the `len(_BUFFER) >= UPDATE_EVERY` check already treats them.
try:
    UPDATE_EVERY = max(1, int(os.environ.get("MM_UPDATE_EVERY", "3")))
except ValueError:
    print("[online] MM_UPDATE_EVERY is not an integer; using 3")
    UPDATE_EVERY = 3

# Only fights played on this difficulty tier are used for learning.
ADAPT_TIER = "hard"

# Optional persistence settings; both must be set for dataset push/pull.
DATASET_REPO = os.environ.get("MM_DATASET_REPO")
HF_TOKEN = os.environ.get("HF_TOKEN")
WEIGHTS_IN_REPO = "mm_weights_live.npz"  # filename used inside the dataset repo

# Shared mutable state. record_fight() mutates _BUFFER and _FIGHTS under _LOCK.
_LOCK = threading.Lock()
_LEARNER = None   # lazily created by get_learner()
_BUFFER = []      # trajectories waiting for the next update
_FIGHTS = 0       # count of accepted (buffered) fights since process start
def _pull_from_dataset():
    """Try to download the persisted adapted weights from the HF dataset.

    Returns:
        A dict mapping array name -> numpy array read from the downloaded
        ``.npz`` file, or ``None`` when persistence is not configured
        (``DATASET_REPO`` / ``HF_TOKEN`` unset) or the download/load fails.
    """
    if not (DATASET_REPO and HF_TOKEN):
        return None
    try:
        # Imported lazily so the module still works without huggingface_hub
        # when persistence is not configured.
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id=DATASET_REPO, filename=WEIGHTS_IN_REPO,
                               repo_type="dataset", token=HF_TOKEN)
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # read every array inside the context manager so it is closed.
        with np.load(path) as archive:
            return {name: archive[name] for name in archive.files}
    except Exception as e:
        # Best-effort: any failure means "start from local/sim weights".
        print(f"[online] no persisted weights pulled ({e}); starting from sim weights")
        return None
def _push_to_dataset():
    """Best-effort upload of the on-disk live weights file to the HF dataset.

    Does nothing when either ``DATASET_REPO`` or ``HF_TOKEN`` is unset. Any
    failure during import or upload is printed and swallowed.
    """
    if not DATASET_REPO or not HF_TOKEN:
        return
    try:
        # Lazy import: huggingface_hub is only needed when persistence is on.
        from huggingface_hub import upload_file
        upload_file(
            path_or_fileobj=LIVE_WEIGHTS,
            path_in_repo=WEIGHTS_IN_REPO,
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except Exception as e:
        print(f"[online] dataset push failed ({e})")
def _load_npz(path):
    """Read every array of an ``.npz`` archive into a dict and close the file."""
    with np.load(path) as archive:
        return {name: archive[name] for name in archive.files}


def get_learner():
    """Return the process-wide OnlineLearner, creating it on first use.

    Starting weights are chosen in this order: weights pulled from the HF
    dataset, else the local live snapshot (``LIVE_WEIGHTS``) if it exists,
    else the sim-trained ``BASE_WEIGHTS``. The learner's ``anchor`` is always
    set to a float64 copy of ``BASE_WEIGHTS``, regardless of the start point.

    Returns:
        The shared ``OnlineLearner`` instance.
    """
    global _LEARNER
    if _LEARNER is None:
        base = _pull_from_dataset()
        if base is None and os.path.exists(LIVE_WEIGHTS):
            base = _load_npz(LIVE_WEIGHTS)
        if base is None:
            base = _load_npz(BASE_WEIGHTS)
        learner = OnlineLearner(base)
        # anchor is ALWAYS the pristine sim-trained weights
        learner.anchor = {k: v.astype(np.float64).copy()
                          for k, v in _load_npz(BASE_WEIGHTS).items()}
        # Publish only once fully configured, so a concurrent reader never
        # sees a learner whose anchor has not been installed yet.
        _LEARNER = learner
    return _LEARNER
def live_weights():
    """Return the shared learner's current weights (its ``W`` attribute).

    These are the HARD brain's weights, possibly already adapted by player
    fights. The object is returned as-is, not copied.
    """
    learner = get_learner()
    return learner.W
def _save_live():
    """Persist the learner's current weights locally, then push them upstream.

    The weights are cast to float32 and written to ``LIVE_WEIGHTS``. The write
    goes to a temporary file in the same directory and is swapped in with
    ``os.replace``, so an interrupted write cannot leave a truncated snapshot
    for ``get_learner`` to load on the next start. Errors from the local write
    propagate to the caller; the dataset push is best-effort.
    """
    weights = {k: v.astype(np.float32) for k, v in get_learner().W.items()}
    # Must end in ".npz": np.savez appends that suffix to any other name.
    tmp_path = LIVE_WEIGHTS + ".tmp.npz"
    try:
        np.savez(tmp_path, **weights)
        os.replace(tmp_path, LIVE_WEIGHTS)
    except BaseException:
        # Do not leave a partial temp file behind; keep the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    _push_to_dataset()
def record_fight(trajectory: dict) -> dict:
    """Buffer one finished fight and run a learner update every UPDATE_EVERY fights.

    Called by /learn.

    Args:
        trajectory: ``{difficulty, steps: [{state, action, bossHP, playerHP}],
            result: {bossDied, playerDied}}`` as posted by the browser.

    Returns:
        A status dict. ``{"enabled": False}`` when online learning is off;
        ``{"enabled": True, "skipped": ...}`` when the fight is not on the
        adapt tier or has fewer than two steps; otherwise either the buffer
        progress (``fights``, ``buffered``, ``update_in``) or, when an update
        ran, ``fights`` merged with the stats returned by the learner.
    """
    global _FIGHTS
    if not ENABLED:
        return {"enabled": False}
    if trajectory.get("difficulty") != ADAPT_TIER:
        return {"enabled": True, "skipped": "only HARD-tier fights train the brain"}
    # The payload comes from the browser: a missing, null or non-list "steps"
    # is treated like a too-short fight instead of crashing in len().
    steps = trajectory.get("steps")
    if not isinstance(steps, (list, tuple)) or len(steps) < 2:
        return {"enabled": True, "skipped": "too short"}
    with _LOCK:
        _BUFFER.append(trajectory)
        _FIGHTS += 1
        if len(_BUFFER) >= UPDATE_EVERY:
            # Drain the buffer BEFORE updating: if update() raises on a bad
            # trajectory, that batch must not stay buffered and make every
            # later call fail on the same data.
            batch = list(_BUFFER)
            _BUFFER.clear()
            stats = get_learner().update(batch)
            if stats.get("updated"):
                _save_live()
            return {"enabled": True, "fights": _FIGHTS, **stats}
        return {"enabled": True, "fights": _FIGHTS, "buffered": len(_BUFFER),
                "update_in": UPDATE_EVERY - len(_BUFFER)}
def status() -> dict:
    """Report a snapshot of the online-learning state.

    Returns:
        A dict with ``enabled``, ``fights_seen``, ``buffered`` (trajectories
        currently waiting for an update), ``persistent`` (True when both the
        dataset repo and token are configured) and ``adapt_tier``.
    """
    has_persistence = bool(DATASET_REPO and HF_TOKEN)
    return {
        "enabled": ENABLED,
        "fights_seen": _FIGHTS,
        "buffered": len(_BUFFER),
        "persistent": has_persistence,
        "adapt_tier": ADAPT_TIER,
    }