etwk commited on
Commit
33d0796
·
1 Parent(s): ea9333f

model: force cuDNN determinism and fail fast on incomplete artifact

Browse files

- Set torch.backends.cudnn.deterministic = True so any convolution kernel
selected at load time is the deterministic variant, reinforcing the
determinism guarantee already provided by eval-mode hard thresholding.
- After loading, verify every routing width in CELL_WIDTHS is served by a
weight-set and raise FileNotFoundError otherwise. A missing intermediate
file would previously leave a routing gap, silently sending that width's
primes to a wider, differently-trained cell rather than failing loudly.

Public-benchmark score unchanged: overall_accuracy 0.997, highest_tier_above_90 = 10.

Files changed (1) hide show
  1. model.py +11 -0
model.py CHANGED
@@ -220,6 +220,7 @@ class HornerRNN(ModularMultiplicationModel):
220
  # not affect the determinism check. Inference is no_grad, so no backward-only
221
  # nondeterministic kernels are involved.
222
  torch.backends.cudnn.benchmark = False
 
223
  torch.backends.cuda.matmul.allow_tf32 = False
224
  torch.backends.cudnn.allow_tf32 = True
225
 
@@ -262,6 +263,16 @@ class HornerRNN(ModularMultiplicationModel):
262
  if not self.cells:
263
  raise FileNotFoundError(f"no weights*.pt found in {model_dir}")
264
 
 
 
 
 
 
 
 
 
 
 
265
  def preprocess_a(self, a):
266
  return a
267
 
 
220
  # not affect the determinism check. Inference is no_grad, so no backward-only
221
  # nondeterministic kernels are involved.
222
  torch.backends.cudnn.benchmark = False
223
+ torch.backends.cudnn.deterministic = True
224
  torch.backends.cuda.matmul.allow_tf32 = False
225
  torch.backends.cudnn.allow_tf32 = True
226
 
 
263
  if not self.cells:
264
  raise FileNotFoundError(f"no weights*.pt found in {model_dir}")
265
 
266
+ # Fail fast on an incomplete artifact: a missing intermediate weight file would
267
+ # otherwise leave a routing gap, silently sending that width's primes to a wider,
268
+ # differently-trained cell instead of raising. Every routing width must be covered.
269
+ missing = [w for w in CELL_WIDTHS if w not in self.cells]
270
+ if missing:
271
+ raise FileNotFoundError(
272
+ f"incomplete model: no trained cell for width(s) {missing} in {model_dir}; "
273
+ f"each width in CELL_WIDTHS must be served by a weights_shared_*.pt or weights<W>.pt file"
274
+ )
275
+
276
  def preprocess_a(self, a):
277
  return a
278