tier5 route [2^32,2^64) to base-16 modmul decoder (htop90=5)
Browse files- manifest.json +2 -2
- model.py +157 -11
- weights.pt +2 -2
manifest.json
CHANGED
|
@@ -2,6 +2,6 @@
|
|
| 2 |
"entry_class": "model.EBMModMul",
|
| 3 |
"output_base": 10,
|
| 4 |
"framework": "pytorch",
|
| 5 |
-
"model_description": "Two trained
|
| 6 |
-
"training_description": "Trained from random init on synthetic examples with x,y in [0,p). Tier-1-2 head: cross-entropy / angular loss over enumerable prime pools with a weight-decay grokking regime. Tier
|
| 7 |
}
|
|
|
|
| 2 |
"entry_class": "model.EBMModMul",
|
| 3 |
"output_base": 10,
|
| 4 |
"framework": "pytorch",
|
| 5 |
+
"model_description": "Two trained network families behind one interface, routed by prime size. Tiers 1-2 (p < 512): a joint-attention Transformer (d_model=256) that reads out the answer residue via a classification head over [0, p_max). Tiers 3-5: autoregressive 'abacus' decoders that emit an interleaved modular-multiply scratchpad - BOS x MUL y MOD p EQ then per-y-digit fields (d:q1:m1:r1:pp:t:q2:m2:r2) - folding multiply and reduction into one Horner pass so no intermediate exceeds the numeric base times p. Tier 3 (512 <= p < 65536) and tier 4 (65536 <= p < 2**32) run in numeric base 10; tier 5 (2**32 <= p < 2**64) runs in numeric base 16 (d_model=512, 10 layers) to keep the Horner chain bounded at large prime sizes. Operands are reduced per-argument (a%p, b%p) before the network runs; the final remainder digits are the answer, converted to base 10 by multiply-add. Emits [0] for p >= 2**64 (tiers 6+, out of the trained range).",
|
| 6 |
+
"training_description": "Trained from random init on synthetic examples with x,y in [0,p). Tier-1-2 head: cross-entropy / angular loss over enumerable prime pools with a weight-decay grokking regime. Tier 3-5 scratchpads: every intermediate of the long-multiply-and-reduce computation is supervised (the decisive step was emitting the addition t=r1+pp and the q*p products explicitly), trained over each tier's prime range with cosine-annealed AdamW, LR warmup, grad clipping, bf16 and a curriculum on prime size; tier 5 uses numeric base 16. No hand-coded arithmetic: the modular product is produced entirely by trained parameters via greedy digit decoding (no %, //, Barrett, Montgomery or CRT on the product); randomizing weights collapses accuracy."
|
| 7 |
}
|
model.py
CHANGED
|
@@ -9,11 +9,12 @@ Compliance contract (see rules/evaluation.md):
|
|
| 9 |
``[0, p)``) materially determines the answer.
|
| 10 |
- We emit the residue as base-10 digits (``output_base = 10``); the harness decodes.
|
| 11 |
|
| 12 |
-
Routing by prime size: tiers 1-2 (p < 512) use the classification head;
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
| 17 |
|
| 18 |
The architecture (encoder + classification/angular head) is loaded from the
|
| 19 |
checkpoint's ``arch`` field, so the same wrapper serves either trained head.
|
|
@@ -230,11 +231,15 @@ def _digits_msb(n: int) -> list[int]:
|
|
| 230 |
|
| 231 |
class AbacusDecoder(nn.Module):
|
| 232 |
"""Decoder-only transformer with abacus (place-within-number) embeddings.
|
| 233 |
-
Architecture identical to training/modmul_probe.py
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
-
def __init__(self, max_len, abacus_max, d_model=384, nhead=8, num_layers=8,
|
|
|
|
| 236 |
super().__init__()
|
| 237 |
-
self.tok_emb = nn.Embedding(
|
| 238 |
self.pos_emb = nn.Embedding(max_len, d_model)
|
| 239 |
self.abacus_emb = nn.Embedding(abacus_max, d_model)
|
| 240 |
layer = nn.TransformerEncoderLayer(
|
|
@@ -243,7 +248,7 @@ class AbacusDecoder(nn.Module):
|
|
| 243 |
)
|
| 244 |
self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
|
| 245 |
self.ln = nn.LayerNorm(d_model)
|
| 246 |
-
self.head = nn.Linear(d_model,
|
| 247 |
self.max_len = max_len
|
| 248 |
self.register_buffer("pos_ids", torch.arange(max_len), persistent=False)
|
| 249 |
|
|
@@ -326,6 +331,111 @@ def _modmul_decode(model, cfg, xyp, device, chunk=128):
|
|
| 326 |
return [o if o is not None else [0] for o in out]
|
| 327 |
|
| 328 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
# ---------------------------------------------------------------------------
|
| 330 |
# Submission entry class
|
| 331 |
# ---------------------------------------------------------------------------
|
|
@@ -339,6 +449,8 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 339 |
self.mm_cfg = None
|
| 340 |
self.mm4 = None # tier-4 modmul scratchpad
|
| 341 |
self.mm4_cfg = None
|
|
|
|
|
|
|
| 342 |
|
| 343 |
def load(self, model_dir: str) -> None:
|
| 344 |
if torch.cuda.is_available():
|
|
@@ -375,6 +487,17 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 375 |
).to(self.device)
|
| 376 |
self.mm4.load_state_dict(ckpt["tier4"]["state_dict"])
|
| 377 |
self.mm4.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
# Per-argument identity preprocessing (each hook sees only its own argument).
|
| 380 |
def preprocess_a(self, a): return a
|
|
@@ -392,6 +515,7 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 392 |
TIER3_LO = 512
|
| 393 |
TIER3_HI = 65536
|
| 394 |
TIER4_HI = 2 ** 32
|
|
|
|
| 395 |
|
| 396 |
@torch.no_grad()
|
| 397 |
def predict_digits_batch(self, inputs):
|
|
@@ -399,16 +523,22 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 399 |
x_rows, y_rows, p_rows, p_ints, idx = [], [], [], [], [] # tiers 1-2
|
| 400 |
mm_items, mm_idx = [], [] # tier 3
|
| 401 |
mm4_items, mm4_idx = [], [] # tier 4
|
|
|
|
| 402 |
|
| 403 |
for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
|
| 404 |
p = int(p_enc)
|
| 405 |
# Out of regime (residues don't fit the trained range): honest 0.
|
| 406 |
-
if p >= self.
|
| 407 |
out[i] = [0]
|
| 408 |
continue
|
| 409 |
a_red = int(a_enc) % p # per-operand reduction (allowed)
|
| 410 |
b_red = int(b_enc) % p
|
| 411 |
-
if p >= self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
if self.mm4 is not None:
|
| 413 |
mm4_items.append((a_red, b_red, p)); mm4_idx.append(i)
|
| 414 |
else:
|
|
@@ -445,6 +575,22 @@ class EBMModMul(ModularMultiplicationModel):
|
|
| 445 |
for j, i in enumerate(mm4_idx):
|
| 446 |
out[i] = res[j]
|
| 447 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
return [o if o is not None else [0] for o in out]
|
| 449 |
|
| 450 |
def max_batch_size(self) -> int:
|
|
|
|
| 9 |
``[0, p)``) materially determines the answer.
|
| 10 |
- We emit the residue as base-10 digits (``output_base = 10``); the harness decodes.
|
| 11 |
|
| 12 |
+
Routing by prime size: tiers 1-2 (p < 512) use the classification head; tiers 3-5
|
| 13 |
+
use the interleaved modular-multiply scratchpad decoder (same architecture,
|
| 14 |
+
separately trained weights) — tier 3 (512 <= p < 65536) and tier 4
|
| 15 |
+
(65536 <= p < 2**32) in numeric base 10, tier 5 (2**32 <= p < 2**64) in numeric
|
| 16 |
+
base 16 (shorter Horner chain at large prime sizes). p >= 2**64 (tiers 6+) is out
|
| 17 |
+
of regime, so we emit ``[0]`` — an honest fallback, not a guess.
|
| 18 |
|
| 19 |
The architecture (encoder + classification/angular head) is loaded from the
|
| 20 |
checkpoint's ``arch`` field, so the same wrapper serves either trained head.
|
|
|
|
| 231 |
|
| 232 |
class AbacusDecoder(nn.Module):
|
| 233 |
"""Decoder-only transformer with abacus (place-within-number) embeddings.
|
| 234 |
+
Architecture identical to training/modmul_probe.py / tier5_modmul.py for
|
| 235 |
+
state_dict match. ``vocab`` defaults to the base-10 scratchpad vocab (18) used
|
| 236 |
+
by tiers 3-4; the tier-5 base-16 scratchpad passes vocab=24 (16 digits + 8
|
| 237 |
+
specials)."""
|
| 238 |
|
| 239 |
+
def __init__(self, max_len, abacus_max, d_model=384, nhead=8, num_layers=8,
|
| 240 |
+
dim_ff=1536, vocab=MM_VOCAB):
|
| 241 |
super().__init__()
|
| 242 |
+
self.tok_emb = nn.Embedding(vocab, d_model)
|
| 243 |
self.pos_emb = nn.Embedding(max_len, d_model)
|
| 244 |
self.abacus_emb = nn.Embedding(abacus_max, d_model)
|
| 245 |
layer = nn.TransformerEncoderLayer(
|
|
|
|
| 248 |
)
|
| 249 |
self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
|
| 250 |
self.ln = nn.LayerNorm(d_model)
|
| 251 |
+
self.head = nn.Linear(d_model, vocab, bias=False)
|
| 252 |
self.max_len = max_len
|
| 253 |
self.register_buffer("pos_ids", torch.arange(max_len), persistent=False)
|
| 254 |
|
|
|
|
| 331 |
return [o if o is not None else [0] for o in out]
|
| 332 |
|
| 333 |
|
| 334 |
+
# ---------------------------------------------------------------------------
|
| 335 |
+
# Tier-5 base-16 modular-multiply scratchpad (autoregressive).
|
| 336 |
+
#
|
| 337 |
+
# Same AbacusDecoder architecture and 9-field Horner scratchpad as tiers 3-4, but
|
| 338 |
+
# trained in numeric BASE 16 (so the per-step partial products / quotient digits
|
| 339 |
+
# stay easy while the chain length is bounded). tier-5 primes are 33-64 bit, so the
|
| 340 |
+
# chain is ~16 base-16 Horner blocks (~1853 tokens). Vocab: digits 0..15 then
|
| 341 |
+
# PAD,BOS,MUL,MOD,EQ,COLON,STEP,EOS = base..base+7 (see tier5_modmul.make_vocab).
|
| 342 |
+
# The decoded answer is a BASE-16 residue; we convert it to base-10 digits with
|
| 343 |
+
# multiply-add only (no %, //, Barrett/Montgomery/CRT on the product) so it matches
|
| 344 |
+
# the global output_base=10. Compliance is unchanged: the only modular reduction in
|
| 345 |
+
# shipped code is the per-operand int(a)%p / int(b)%p done before the network runs.
|
| 346 |
+
# ---------------------------------------------------------------------------
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def _make_vocab_base(base: int) -> dict:
|
| 350 |
+
"""Base-B scratchpad vocab, matching training/tier5_modmul.make_vocab."""
|
| 351 |
+
PAD, BOS, MUL, MOD, EQ, COLON, STEP, EOS = (
|
| 352 |
+
base, base + 1, base + 2, base + 3, base + 4, base + 5, base + 6, base + 7)
|
| 353 |
+
return dict(PAD=PAD, BOS=BOS, MUL=MUL, MOD=MOD, EQ=EQ, COLON=COLON, STEP=STEP,
|
| 354 |
+
EOS=EOS, VOCAB=base + 8,
|
| 355 |
+
SPECIALS={PAD, BOS, MUL, MOD, EQ, COLON, STEP, EOS})
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def _digits_base_msb(n: int, base: int) -> list[int]:
|
| 359 |
+
if n == 0:
|
| 360 |
+
return [0]
|
| 361 |
+
s = []
|
| 362 |
+
while n > 0:
|
| 363 |
+
s.append(n % base)
|
| 364 |
+
n //= base
|
| 365 |
+
return s[::-1]
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def _base_to_int(ds: list[int], base: int) -> int:
|
| 369 |
+
v = 0
|
| 370 |
+
for d in ds:
|
| 371 |
+
v = v * base + d # multiply-add only; no %/// on the product
|
| 372 |
+
return v
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
@torch.no_grad()
|
| 376 |
+
def _modmul_decode_base(model, cfg, xyp, device, base, chunk=64):
|
| 377 |
+
"""Greedy-decode (x*y) mod p in numeric base ``base`` for each (x, y, p) with
|
| 378 |
+
x, y already in [0, p). Returns base-10 digit-lists (MSB-first), or [0] if
|
| 379 |
+
unparseable. Mirrors _modmul_decode but base-parametrized; the final base-``base``
|
| 380 |
+
residue is re-expressed in base 10 via multiply-add (compliant)."""
|
| 381 |
+
V = _make_vocab_base(base)
|
| 382 |
+
PAD, EOS, COLON = V["PAD"], V["EOS"], V["COLON"]
|
| 383 |
+
max_len, abmax = cfg["max_len"], cfg["abacus_max"]
|
| 384 |
+
specials = torch.tensor(sorted(V["SPECIALS"]), device=device)
|
| 385 |
+
out: list[list[int] | None] = [None] * len(xyp)
|
| 386 |
+
|
| 387 |
+
groups = defaultdict(list)
|
| 388 |
+
prompts = []
|
| 389 |
+
for i, (x, y, p) in enumerate(xyp):
|
| 390 |
+
xd, yd, pd = (_digits_base_msb(x, base), _digits_base_msb(y, base),
|
| 391 |
+
_digits_base_msb(p, base))
|
| 392 |
+
toks = [V["BOS"]] + xd + [V["MUL"]] + yd + [V["MOD"]] + pd + [V["EQ"]]
|
| 393 |
+
abac = ([0] + list(range(len(xd))) + [0] + list(range(len(yd)))
|
| 394 |
+
+ [0] + list(range(len(pd))) + [0])
|
| 395 |
+
groups[len(toks)].append(i)
|
| 396 |
+
prompts.append((toks, abac))
|
| 397 |
+
|
| 398 |
+
for L, idxs in groups.items():
|
| 399 |
+
for s in range(0, len(idxs), chunk):
|
| 400 |
+
sub = idxs[s:s + chunk]
|
| 401 |
+
g = len(sub)
|
| 402 |
+
toks = torch.tensor([prompts[i][0] for i in sub], dtype=torch.long, device=device)
|
| 403 |
+
abac = torch.tensor([prompts[i][1] for i in sub], dtype=torch.long, device=device)
|
| 404 |
+
seg = torch.zeros(g, dtype=torch.long, device=device)
|
| 405 |
+
done = torch.zeros(g, dtype=torch.bool, device=device)
|
| 406 |
+
gen = [[] for _ in range(g)]
|
| 407 |
+
while toks.shape[1] < max_len and not bool(done.all()):
|
| 408 |
+
nxt = model(toks, abac)[:, -1].argmax(-1)
|
| 409 |
+
nxt = torch.where(done, torch.full_like(nxt, PAD), nxt)
|
| 410 |
+
is_sp = (nxt.unsqueeze(1) == specials).any(1)
|
| 411 |
+
new_abac = torch.where(is_sp, torch.zeros_like(seg),
|
| 412 |
+
torch.clamp(seg, max=abmax - 1))
|
| 413 |
+
seg = torch.where(is_sp, torch.zeros_like(seg), seg + 1)
|
| 414 |
+
nc, dc = nxt.tolist(), done.tolist()
|
| 415 |
+
for j in range(g):
|
| 416 |
+
if not dc[j] and nc[j] != EOS and nc[j] != PAD:
|
| 417 |
+
gen[j].append(nc[j])
|
| 418 |
+
toks = torch.cat([toks, nxt.unsqueeze(1)], dim=1)
|
| 419 |
+
abac = torch.cat([abac, new_abac.unsqueeze(1)], dim=1)
|
| 420 |
+
done = done | (nxt == EOS)
|
| 421 |
+
for j, i in enumerate(sub):
|
| 422 |
+
gj = gen[j]
|
| 423 |
+
if COLON in gj:
|
| 424 |
+
k = len(gj) - 1 - gj[::-1].index(COLON)
|
| 425 |
+
ans = [d for d in gj[k + 1:] if d < base]
|
| 426 |
+
out[i] = int_to_decimal_digits(_base_to_int(ans, base)) if ans else [0]
|
| 427 |
+
else:
|
| 428 |
+
out[i] = [0]
|
| 429 |
+
del toks, abac, seg, done, gen
|
| 430 |
+
# Only MPS needs the cache drop (it never frees mid-run and OOMs on
|
| 431 |
+
# long chains). On CUDA, empty_cache() synchronizes + forces the
|
| 432 |
+
# allocator to re-acquire buffers every chunk -- a ~5x slowdown over
|
| 433 |
+
# the ~1853-step base-16 decode -- and tier-5 peak is <3 GB anyway.
|
| 434 |
+
if device.type == "mps":
|
| 435 |
+
torch.mps.empty_cache()
|
| 436 |
+
return [o if o is not None else [0] for o in out]
|
| 437 |
+
|
| 438 |
+
|
| 439 |
# ---------------------------------------------------------------------------
|
| 440 |
# Submission entry class
|
| 441 |
# ---------------------------------------------------------------------------
|
|
|
|
| 449 |
self.mm_cfg = None
|
| 450 |
self.mm4 = None # tier-4 modmul scratchpad
|
| 451 |
self.mm4_cfg = None
|
| 452 |
+
self.mm5 = None # tier-5 base-16 modmul scratchpad
|
| 453 |
+
self.mm5_cfg = None
|
| 454 |
|
| 455 |
def load(self, model_dir: str) -> None:
|
| 456 |
if torch.cuda.is_available():
|
|
|
|
| 487 |
).to(self.device)
|
| 488 |
self.mm4.load_state_dict(ckpt["tier4"]["state_dict"])
|
| 489 |
self.mm4.eval()
|
| 490 |
+
# Tier 5: base-16 scratchpad, trained on primes in [2**33, 2**64).
|
| 491 |
+
if "tier5" in ckpt:
|
| 492 |
+
c5 = ckpt["tier5"]["config"]
|
| 493 |
+
self.mm5_cfg = c5
|
| 494 |
+
self.mm5 = AbacusDecoder(
|
| 495 |
+
max_len=c5["max_len"], abacus_max=c5["abacus_max"], d_model=c5["d_model"],
|
| 496 |
+
nhead=c5["nhead"], num_layers=c5["layers"], dim_ff=c5["dim_ff"],
|
| 497 |
+
vocab=c5["base"] + 8,
|
| 498 |
+
).to(self.device)
|
| 499 |
+
self.mm5.load_state_dict(ckpt["tier5"]["state_dict"])
|
| 500 |
+
self.mm5.eval()
|
| 501 |
|
| 502 |
# Per-argument identity preprocessing (each hook sees only its own argument).
|
| 503 |
def preprocess_a(self, a): return a
|
|
|
|
| 515 |
TIER3_LO = 512
|
| 516 |
TIER3_HI = 65536
|
| 517 |
TIER4_HI = 2 ** 32
|
| 518 |
+
TIER5_HI = 2 ** 64
|
| 519 |
|
| 520 |
@torch.no_grad()
|
| 521 |
def predict_digits_batch(self, inputs):
|
|
|
|
| 523 |
x_rows, y_rows, p_rows, p_ints, idx = [], [], [], [], [] # tiers 1-2
|
| 524 |
mm_items, mm_idx = [], [] # tier 3
|
| 525 |
mm4_items, mm4_idx = [], [] # tier 4
|
| 526 |
+
mm5_items, mm5_idx = [], [] # tier 5
|
| 527 |
|
| 528 |
for i, (a_enc, b_enc, p_enc) in enumerate(inputs):
|
| 529 |
p = int(p_enc)
|
| 530 |
# Out of regime (residues don't fit the trained range): honest 0.
|
| 531 |
+
if p >= self.TIER5_HI:
|
| 532 |
out[i] = [0]
|
| 533 |
continue
|
| 534 |
a_red = int(a_enc) % p # per-operand reduction (allowed)
|
| 535 |
b_red = int(b_enc) % p
|
| 536 |
+
if p >= self.TIER4_HI:
|
| 537 |
+
if self.mm5 is not None:
|
| 538 |
+
mm5_items.append((a_red, b_red, p)); mm5_idx.append(i)
|
| 539 |
+
else:
|
| 540 |
+
out[i] = [0]
|
| 541 |
+
elif p >= self.TIER3_HI:
|
| 542 |
if self.mm4 is not None:
|
| 543 |
mm4_items.append((a_red, b_red, p)); mm4_idx.append(i)
|
| 544 |
else:
|
|
|
|
| 575 |
for j, i in enumerate(mm4_idx):
|
| 576 |
out[i] = res[j]
|
| 577 |
|
| 578 |
+
if mm5_items:
|
| 579 |
+
# Tier-5 base-16 chains are ~1853 tokens. The model was trained under
|
| 580 |
+
# bf16 autocast, and decoding in bf16 (not fp32) is what keeps the long
|
| 581 |
+
# attention both fast (~125s/100 vs ~470-660s in fp32) and within memory
|
| 582 |
+
# (<3 GB vs a fp32 OOM at 1853-length attention). Match training precision.
|
| 583 |
+
if self.device.type == "cuda":
|
| 584 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 585 |
+
res = _modmul_decode_base(self.mm5, self.mm5_cfg, mm5_items,
|
| 586 |
+
self.device, base=self.mm5_cfg["base"],
|
| 587 |
+
chunk=64)
|
| 588 |
+
else:
|
| 589 |
+
res = _modmul_decode_base(self.mm5, self.mm5_cfg, mm5_items,
|
| 590 |
+
self.device, base=self.mm5_cfg["base"], chunk=64)
|
| 591 |
+
for j, i in enumerate(mm5_idx):
|
| 592 |
+
out[i] = res[j]
|
| 593 |
+
|
| 594 |
return [o if o is not None else [0] for o in out]
|
| 595 |
|
| 596 |
def max_batch_size(self) -> int:
|
weights.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:165d028f8af0a9efb45b755a20e855701b3e6594bcbb69b63cdc356544db8303
|
| 3 |
+
size 271364937
|