# MOTHER_CORE_V2 / inference.py
# Uploaded by MediaStreamAI: "Upload inference.py (chunk 450 W2.7)"
# Commit 8c2d15a (verified)
#!/usr/bin/env python3
"""
MOTHER CORE V2 — chunk 450 (W2.7) — Reference Inference
========================================================
Sovereign UK AI by MediaStream AI Limited (MSAI).
This script loads chunk 450 from HuggingFace and runs the LOCKED inference
rules used during training. Deviation from these rules produces incorrect
or degenerate output.
Usage:
python inference.py "What is the capital of Scotland?"
python inference.py # enters interactive mode
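Requirements:
pip install torch safetensors sentencepiece huggingface_hub
The mother_core source package must also be importable: it is looked up in
~/mother-core-reasoning, or it can be placed on PYTHONPATH.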
"""
from __future__ import annotations

import json
import sys
from collections import Counter
from pathlib import Path

import sentencepiece as spm
import torch
from safetensors.torch import load_file
# ════════════════════════════════════════════════════════════════════
# LOCKED INFERENCE RULES (DO NOT CHANGE)
# ════════════════════════════════════════════════════════════════════
# Special token ids. BOS is prepended to every prompt and EOS ends
# generation; assumed to match tokenizer.model (not checked here).
PAD_ID, BOS_ID, EOS_ID = 0, 1, 2

# Template the user's question is wrapped in before encoding.
PROMPT_FORMAT = "Question:\n\n{q}\n\nAnswer:"

# Decoding knobs — greedy argmax, no temperature, no sampling.
REP_PEN = 1.3        # base of the penalty applied to repeated tokens
NO_REPEAT_NGRAM = 4  # size of n-grams that may not repeat in the answer
MAX_NEW = 200        # cap on generated tokens per answer

# Precision and placement of the model: the current CUDA device when
# available, otherwise the CPU.
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def _read_json(path: Path) -> dict:
    """Parse the UTF-8 JSON file at *path*."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _import_model_classes():
    """Return ``(ModelConfig, MotherCoreModel)`` from the mother_core package.

    ``~/mother-core-reasoning`` is prepended to ``sys.path`` first. If the
    package still cannot be imported, print instructions and exit with
    status 1.
    """
    sys.path.insert(0, str(Path.home() / "mother-core-reasoning"))
    try:
        from mother_core.config import ModelConfig
        from mother_core.model import MotherCoreModel
    except ImportError:
        print("ERROR: mother_core package not found.")
        print("This script requires the mother_core source code to be available.")
        print("Either clone the MSAI sovereign training repo, or copy "
              "mother_core/ into your PYTHONPATH.")
        sys.exit(1)
    return ModelConfig, MotherCoreModel


def _report_state_dict_mismatch(result) -> None:
    """Print a warning for every kind of key ``load_state_dict`` skipped.

    With ``strict=False`` PyTorch silently ignores missing and unexpected
    keys. A checkpoint whose key names don't match the model would therefore
    load "successfully", leave layers at their constructor-initialised values
    and produce garbage with no visible error.
    """
    missing, unexpected = result.missing_keys, result.unexpected_keys
    if missing:
        print(f"WARNING: {len(missing)} model tensor(s) not found in the "
              f"checkpoint, e.g. {missing[:3]}")
    if unexpected:
        print(f"WARNING: {len(unexpected)} checkpoint tensor(s) not used by "
              f"the model, e.g. {unexpected[:3]}")


def load_model_and_tokenizer(repo_dir: str):
    """Load MOTHER CORE from a local directory (downloaded HF snapshot).

    Reads ``config.json`` and ``tokenizer.model``. Weights come either from
    the shards listed in ``model.safetensors.index.json`` or, if that index
    is absent, from a single ``model.safetensors``.

    Args:
        repo_dir: path to the local snapshot directory.

    Returns:
        ``(model, tokenizer)``:
        - ``model`` is a MotherCoreModel cast to ``DTYPE``, moved to
          ``DEVICE`` and put in eval mode.
        - ``tokenizer`` is a SentencePieceProcessor.

    Exits the process with status 1 if the mother_core package is
    unavailable.
    """
    repo = Path(repo_dir)

    # Load config
    cfg = _read_json(repo / "config.json")
    print(f"Loaded config: {cfg['n_layers']} layers, dim={cfg['dim']}, "
          f"params~{cfg.get('_msai_total_params_b', '?')}B")

    # Load tokenizer (SentencePiece)
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(str(repo / "tokenizer.model"))
    print(f"Loaded tokenizer: vocab_size={tokenizer.vocab_size()}")

    # Build model — requires mother_core package available
    ModelConfig, MotherCoreModel = _import_model_classes()
    config = ModelConfig(
        vocab_size=cfg["vocab_size"],
        dim=cfg["dim"],
        n_layers=cfg["n_layers"],
        n_heads=cfg["n_heads"],
        n_kv_heads=cfg["n_kv_heads"],
        ff_mult=cfg["ff_mult"],
        max_seq_len=cfg["max_seq_len"],
        rope_theta=cfg["rope_theta"],
        rms_norm_eps=cfg["rms_norm_eps"],
    )
    model = MotherCoreModel(config)

    # Load sharded safetensors
    index_path = repo / "model.safetensors.index.json"
    if index_path.exists():
        index = _read_json(index_path)
        shard_files = sorted(set(index["weight_map"].values()))
        print(f"Loading {len(shard_files)} shards...")
        state_dict = {}
        for sf in shard_files:
            print(f" - {sf}")
            state_dict.update(load_file(str(repo / sf)))
    else:
        # Single-file fallback
        state_dict = load_file(str(repo / "model.safetensors"))

    # strict=False is kept from the reference loader. Mismatched keys are
    # reported rather than silently ignored.
    result = model.load_state_dict(state_dict, strict=False)
    # Release the loaded checkpoint tensors before the dtype/device copy so
    # they don't add to peak memory.
    del state_dict
    _report_state_dict_mismatch(result)

    model = model.to(DTYPE).to(DEVICE).eval()
    print(f"Model on {DEVICE} in {DTYPE}")
    return model, tokenizer
def _apply_repetition_penalty(logits: torch.Tensor, tokens: list[int],
                              rep_pen: float) -> None:
    """Lower, in place, the row-0 logits of every token seen >= 2 times.

    A token seen ``c`` times is scaled by ``rep_pen ** (c - 1)``:
    - positive logits are divided;
    - non-positive logits are multiplied.

    For ``rep_pen > 1`` the score therefore always drops. Dividing a negative
    logit unconditionally would move it toward zero and *reward* repetition.
    Token ids outside the vocabulary are skipped.
    """
    vocab = logits.shape[-1]
    for tok, count in Counter(tokens).items():
        if count < 2 or not 0 <= tok < vocab:
            continue
        scale = rep_pen ** (count - 1)
        if logits[0, tok] > 0:
            logits[0, tok] /= scale
        else:
            logits[0, tok] *= scale


def _banned_ngram_tokens(tokens: list[int], n: int) -> set[int]:
    """Return the tokens that would complete an n-gram already in *tokens*.

    The last ``n - 1`` tokens are compared with every earlier window of the
    same length, and whatever token followed a matching window is banned.
    For ``n == 1`` the compared suffix is empty, so every token generated so
    far is banned. Callers ensure ``len(tokens) >= n``.
    """
    k = n - 1
    suffix = tuple(tokens[len(tokens) - k:])
    return {
        tokens[j + k]
        for j in range(len(tokens) - k)
        if tuple(tokens[j:j + k]) == suffix
    }


@torch.no_grad()
def generate_greedy(model, tokenizer, question: str,
                    max_new: int = MAX_NEW,
                    rep_pen: float = REP_PEN,
                    no_repeat_ngram: int = NO_REPEAT_NGRAM) -> str:
    """
    LOCKED inference path. Greedy argmax with n-gram blocking and
    frequency-scaled repetition penalty.

    Args:
        model: called as ``model(ids)``. It must return a mapping whose
            ``"logits"`` entry is indexed ``[batch, position, vocab]``
            (presumed layout — TODO confirm against mother_core).
        tokenizer: SentencePieceProcessor used to encode the prompt and
            decode the answer.
        question: user question, substituted into ``PROMPT_FORMAT``.
        max_new: maximum number of tokens to generate.
        rep_pen: base of the repetition penalty. A token generated ``c >= 2``
            times has its logit lowered by a factor ``rep_pen ** (c - 1)``.
            The penalty is only active once 3+ tokens have been generated.
        no_repeat_ngram: size of n-grams that may not repeat within the
            generated tokens; 0 disables blocking.

    Returns:
        The decoded generated text, with EOS excluded and surrounding
        whitespace stripped.
    """
    prompt = PROMPT_FORMAT.format(q=question)
    ids = [BOS_ID] + tokenizer.EncodeAsIds(prompt)
    inp = torch.tensor([ids], device=DEVICE)
    gen_out: list[int] = []

    for i in range(max_new):
        # NOTE(review): after the first step only the newest token is fed and
        # no cache object is passed. This relies on the model keeping its own
        # KV cache across calls (and resetting it when a full prompt arrives).
        # Confirm in mother_core.model; otherwise each step sees no context.
        x = inp if i == 0 else torch.tensor([[gen_out[-1]]], device=DEVICE)
        out = model(x)
        logits = out["logits"][:, -1, :].float()
        vocab = logits.shape[-1]

        # Block BOS in generated output, allow EOS only after at least 1 token
        if not gen_out:
            logits[0, EOS_ID] = -1e9
        logits[0, BOS_ID] = -1e9

        # Frequency-scaled repetition penalty (only tokens seen >= 2 times)
        if len(gen_out) >= 3:
            _apply_repetition_penalty(logits, gen_out, rep_pen)

        # n-gram blocking (applied after the penalty so bans always win)
        if no_repeat_ngram > 0 and len(gen_out) >= no_repeat_ngram:
            for t in _banned_ngram_tokens(gen_out, no_repeat_ngram):
                if 0 <= t < vocab:
                    logits[0, t] = -1e9

        # Greedy argmax (no temperature, no sampling)
        nxt = logits.argmax(-1).item()
        if nxt == EOS_ID:
            break
        gen_out.append(nxt)

        # Cycle-break: 4 identical tokens in a row
        if len(gen_out) >= 4 and len(set(gen_out[-4:])) == 1:
            break

    return tokenizer.DecodeIds(gen_out).strip()
def main():
    """Fetch the checkpoint, then answer one CLI question or run a REPL.

    If command-line arguments are given, they are joined with spaces into a
    single question, answered once, and the function returns. Otherwise
    questions are read from stdin until 'quit'/'exit' (any case), EOF or
    Ctrl-C at the prompt. Blank lines are ignored.
    """
    # Download from HF if needed
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        print("ERROR: pip install huggingface_hub")
        sys.exit(1)

    print("Downloading MediaStreamAI/MOTHER_CORE_V2 ...")
    local_dir = snapshot_download(repo_id="MediaStreamAI/MOTHER_CORE_V2")
    print(f"Local snapshot: {local_dir}")
    model, tokenizer = load_model_and_tokenizer(local_dir)

    cli_words = sys.argv[1:]
    if cli_words:
        one_shot = " ".join(cli_words)
        print(f"\nQ: {one_shot}")
        reply = generate_greedy(model, tokenizer, one_shot)
        print(f"A: {reply}")
        return

    print("\nInteractive mode. Type 'quit' to exit.\n")
    while True:
        try:
            line = input("Q: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if line.lower() in ("quit", "exit"):
            break
        if line:
            reply = generate_greedy(model, tokenizer, line)
            print(f"A: {reply}\n")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()