ModuleMind / online.py
Quazim0t0's picture
Add files using upload-large-folder tool
45e7dfb verified
Raw
History Blame Contribute Delete
4.77 kB
"""
online.py -- persistent finetuning of the boss from real player fights.
Flow: the browser logs every boss decision (state, action, HP) during a fight and,
on fight end, POSTs the trajectory + who-died to /learn. We buffer trajectories and,
every MM_UPDATE_EVERY fights, run one REINFORCE step (mm_grad.OnlineLearner) that
nudges the HARD brain toward what actually worked against humans. The adapted
weights feed straight back into the live boss.
ONLY HARD-tier fights are used, so the data stays on-policy (Easy/Normal are
deliberately handicapped checkpoints; learning from them would be off-policy).
Persistence (optional, set as Space secrets):
HF_TOKEN - a write token
MM_DATASET_REPO - e.g. "your-name/boss-fight-online"
If set, adapted weights are pushed to / pulled from that dataset so learning
survives Space restarts. Without them it still adapts live, just in-memory.
Safety: a frozen copy of the sim-trained weights is kept as an anchor (the learner
pulls gently back toward it), so a weird run can't brick the boss.
"""
from __future__ import annotations
import os
import threading
import numpy as np
from mm_grad import OnlineLearner
HERE = os.path.dirname(os.path.abspath(__file__))
# Weight files stored next to this module.
BASE_WEIGHTS, LIVE_WEIGHTS = (
    os.path.join(HERE, fname)
    for fname in (
        "mm_weights.npz",       # sim-trained HARD brain (also the frozen anchor)
        "mm_weights_live.npz",  # latest player-adapted snapshot
    )
)
# Runtime knobs, read once at import time from the environment.
ENABLED = os.environ.get("MM_ONLINE", "1") == "1"  # any value other than "1" disables learning
UPDATE_EVERY = int(os.environ.get("MM_UPDATE_EVERY", "3"))  # buffered fights per update step
ADAPT_TIER = "hard"  # the only difficulty whose fights are learned from
# Optional persistence; push/pull only happens when both are set.
DATASET_REPO, HF_TOKEN = (os.environ.get(var) for var in ("MM_DATASET_REPO", "HF_TOKEN"))
WEIGHTS_IN_REPO = "mm_weights_live.npz"  # file name inside the dataset repo
# Mutable module state. record_fight() mutates _BUFFER/_FIGHTS while holding _LOCK.
_LOCK = threading.Lock()
_LEARNER = None  # built lazily by get_learner()
_BUFFER = []     # accepted HARD-tier trajectories awaiting the next update
_FIGHTS = 0      # accepted HARD-tier fights since process start
def _pull_from_dataset():
    """Try to download the latest adapted weights from the HF dataset.

    Returns:
        A ``{name: ndarray}`` dict of the persisted weights, or ``None`` when
        persistence is not configured (``DATASET_REPO`` or ``HF_TOKEN`` unset)
        or when the download/load raises. Failures are only printed, so the
        caller can fall back to local weights.
    """
    if not (DATASET_REPO and HF_TOKEN):
        return None
    try:
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id=DATASET_REPO, filename=WEIGHTS_IN_REPO,
                               repo_type="dataset", token=HF_TOKEN)
        # np.load on an .npz returns a lazily-read NpzFile that keeps the file
        # open; materialise every array and close the handle deterministically.
        with np.load(path) as npz:
            return {k: npz[k] for k in npz.files}
    except Exception as e:
        print(f"[online] no persisted weights pulled ({e}); starting from sim weights")
        return None
def _push_to_dataset():
    """Best-effort upload of the local ``LIVE_WEIGHTS`` file to the HF dataset.

    Nothing happens unless both ``DATASET_REPO`` and ``HF_TOKEN`` are set. Any
    exception is printed and swallowed, including a missing huggingface_hub
    install, so a failed push never propagates to the caller.
    """
    persistence_on = bool(DATASET_REPO) and bool(HF_TOKEN)
    if persistence_on:
        try:
            from huggingface_hub import upload_file
            upload_file(
                path_or_fileobj=LIVE_WEIGHTS,
                path_in_repo=WEIGHTS_IN_REPO,
                repo_id=DATASET_REPO,
                repo_type="dataset",
                token=HF_TOKEN,
            )
        except Exception as err:
            print(f"[online] dataset push failed ({err})")
_INIT_LOCK = threading.Lock()  # serialises first-time construction in get_learner()
def get_learner():
    """Return the process-wide OnlineLearner, building it on first use.

    Starting weights are taken from the first available source:
      1. the HF dataset (``_pull_from_dataset``), if persistence is configured;
      2. the local ``LIVE_WEIGHTS`` snapshot, if it exists and is readable;
      3. the sim-trained ``BASE_WEIGHTS``.
    The learner's ``anchor`` is always the pristine ``BASE_WEIGHTS`` (as
    float64), whichever source supplied the starting weights.
    """
    global _LEARNER
    if _LEARNER is not None:
        return _LEARNER
    # A dedicated lock, not _LOCK: record_fight() already holds _LOCK when it
    # calls us, and threading.Lock is not re-entrant.
    with _INIT_LOCK:
        if _LEARNER is None:
            with np.load(BASE_WEIGHTS) as npz:
                sim = {k: npz[k] for k in npz.files}
            # anchor is ALWAYS the pristine sim-trained weights; copy them before
            # the learner sees any array so the two can never alias.
            anchor = {k: v.astype(np.float64).copy() for k, v in sim.items()}
            base = _pull_from_dataset()
            if base is None and os.path.exists(LIVE_WEIGHTS):
                try:
                    with np.load(LIVE_WEIGHTS) as npz:
                        base = {k: npz[k] for k in npz.files}
                except Exception as e:
                    # A truncated/corrupt snapshot must not brick the boss.
                    print(f"[online] unreadable {LIVE_WEIGHTS} ({e}); starting from sim weights")
                    base = None
            if base is None:
                base = sim
            learner = OnlineLearner(base)
            learner.anchor = anchor
            _LEARNER = learner  # publish only once fully initialised
    return _LEARNER
def live_weights():
    """Return the weight mapping the live HARD boss should use right now.

    This is the learner's own ``W`` object (not a copy), so it reflects any
    player-driven updates applied so far.
    """
    learner = get_learner()
    return learner.W
def _save_live():
    """Write the learner's current weights (as float32) to ``LIVE_WEIGHTS`` and push them.

    The snapshot is first written to a temporary file, then atomically swapped
    into place with ``os.replace``. A crash mid-write can therefore never leave
    a truncated .npz for ``get_learner`` to load at the next start-up. The
    dataset push afterwards is best-effort (see ``_push_to_dataset``).
    """
    weights = {k: v.astype(np.float32) for k, v in get_learner().W.items()}
    # Keep the .npz suffix so np.savez does not append another one.
    tmp_path = LIVE_WEIGHTS + ".tmp.npz"
    np.savez(tmp_path, **weights)
    os.replace(tmp_path, LIVE_WEIGHTS)  # same directory -> atomic rename
    _push_to_dataset()
def record_fight(trajectory: dict) -> dict:
    """Buffer one finished fight and run a learner update every UPDATE_EVERY fights.

    Called by /learn. ``trajectory`` is expected to look like
    ``{difficulty, steps: [{state, action, bossHP, playerHP}],
    result: {bossDied, playerDied}}``. Only ``difficulty`` and ``steps`` are
    inspected here; the rest is passed through to ``OnlineLearner.update``.

    Returns:
        A status dict. It contains ``{"enabled": False}`` when learning is off,
        a ``"skipped"`` reason for non-HARD or too-short fights, the learner's
        update stats once a batch is processed, or buffer progress otherwise.

    Raises:
        Whatever ``OnlineLearner.update`` raises. The offending batch has
        already been removed from the buffer by then, so it is not retried.
    """
    global _FIGHTS
    if not ENABLED:
        return {"enabled": False}
    if trajectory.get("difficulty") != ADAPT_TIER:
        return {"enabled": True, "skipped": "only HARD-tier fights train the brain"}
    # `or []` also covers an explicit "steps": null in the POSTed JSON.
    if len(trajectory.get("steps") or []) < 2:
        return {"enabled": True, "skipped": "too short"}
    with _LOCK:
        _BUFFER.append(trajectory)
        _FIGHTS += 1
        if len(_BUFFER) >= UPDATE_EVERY:
            # Drain the buffer BEFORE updating. If update() raises on a bad
            # trajectory, the batch is dropped rather than retried (and failing
            # again) on every subsequent fight.
            batch = list(_BUFFER)
            _BUFFER.clear()
            stats = get_learner().update(batch)
            if stats.get("updated"):
                _save_live()
            return {"enabled": True, "fights": _FIGHTS, **stats}
        return {"enabled": True, "fights": _FIGHTS, "buffered": len(_BUFFER),
                "update_in": UPDATE_EVERY - len(_BUFFER)}
def status() -> dict:
    """Report the online learner's configuration and progress counters.

    ``persistent`` is True only when both ``DATASET_REPO`` and ``HF_TOKEN``
    are set, i.e. when weights are pushed to / pulled from the HF dataset.
    """
    persistent = bool(DATASET_REPO and HF_TOKEN)
    return dict(
        enabled=ENABLED,
        fights_seen=_FIGHTS,
        buffered=len(_BUFFER),
        persistent=persistent,
        adapt_tier=ADAPT_TIER,
    )