etwk commited on
Commit ·
33d0796
1
Parent(s): ea9333f
model: force cuDNN determinism and fail fast on incomplete artifact
Browse files- Set torch.backends.cudnn.deterministic = True so any convolution kernel
selected at load time is the deterministic variant, reinforcing the
determinism guarantee already provided by eval-mode hard thresholding.
- After loading, verify every routing width in CELL_WIDTHS is served by a
weight-set and raise FileNotFoundError otherwise. A missing intermediate
file would previously leave a routing gap, silently sending that width's
primes to a wider, differently-trained cell rather than failing loudly.
Public-benchmark score unchanged: overall_accuracy 0.997, highest_tier_above_90 = 10.
model.py
CHANGED
|
@@ -220,6 +220,7 @@ class HornerRNN(ModularMultiplicationModel):
|
|
| 220 |
# not affect the determinism check. Inference is no_grad, so no backward-only
|
| 221 |
# nondeterministic kernels are involved.
|
| 222 |
torch.backends.cudnn.benchmark = False
|
|
|
|
| 223 |
torch.backends.cuda.matmul.allow_tf32 = False
|
| 224 |
torch.backends.cudnn.allow_tf32 = True
|
| 225 |
|
|
@@ -262,6 +263,16 @@ class HornerRNN(ModularMultiplicationModel):
|
|
| 262 |
if not self.cells:
|
| 263 |
raise FileNotFoundError(f"no weights*.pt found in {model_dir}")
|
| 264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
def preprocess_a(self, a):
|
| 266 |
return a
|
| 267 |
|
|
|
|
| 220 |
# not affect the determinism check. Inference is no_grad, so no backward-only
|
| 221 |
# nondeterministic kernels are involved.
|
| 222 |
torch.backends.cudnn.benchmark = False
|
| 223 |
+
torch.backends.cudnn.deterministic = True
|
| 224 |
torch.backends.cuda.matmul.allow_tf32 = False
|
| 225 |
torch.backends.cudnn.allow_tf32 = True
|
| 226 |
|
|
|
|
| 263 |
if not self.cells:
|
| 264 |
raise FileNotFoundError(f"no weights*.pt found in {model_dir}")
|
| 265 |
|
| 266 |
+
# Fail fast on an incomplete artifact: a missing intermediate weight file would
|
| 267 |
+
# otherwise leave a routing gap, silently sending that width's primes to a wider,
|
| 268 |
+
# differently-trained cell instead of raising. Every routing width must be covered.
|
| 269 |
+
missing = [w for w in CELL_WIDTHS if w not in self.cells]
|
| 270 |
+
if missing:
|
| 271 |
+
raise FileNotFoundError(
|
| 272 |
+
f"incomplete model: no trained cell for width(s) {missing} in {model_dir}; "
|
| 273 |
+
f"each width in CELL_WIDTHS must be served by a weights_shared_*.pt or weights<W>.pt file"
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
def preprocess_a(self, a):
|
| 277 |
return a
|
| 278 |
|