Upload scripts/patch_unirig_cloud_io.py with huggingface_hub
Browse files
scripts/patch_unirig_cloud_io.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Force weights_only=False everywhere in lightning cloud_io.py
|
| 2 |
+
# PyTorch 2.6+ treats weights_only=None as True, breaking UniRig checkpoints
|
| 3 |
+
|
| 4 |
+
import glob, re
|
| 5 |
+
|
| 6 |
+
for pattern in [
|
| 7 |
+
'/root/miniconda/envs/unirig/lib/python*/site-packages/lightning/fabric/utilities/cloud_io.py',
|
| 8 |
+
'/root/miniconda/envs/unirig/lib/python*/site-packages/lightning/fabric/plugins/io/torch_io.py',
|
| 9 |
+
]:
|
| 10 |
+
for path in glob.glob(pattern):
|
| 11 |
+
with open(path) as f:
|
| 12 |
+
src = f.read()
|
| 13 |
+
|
| 14 |
+
# Replace any torch.load call that passes weights_only as a variable
|
| 15 |
+
# with a hardcoded False
|
| 16 |
+
new_src = re.sub(
|
| 17 |
+
r'weights_only\s*=\s*weights_only',
|
| 18 |
+
'weights_only=False',
|
| 19 |
+
src
|
| 20 |
+
)
|
| 21 |
+
# Also catch weights_only=None
|
| 22 |
+
new_src = re.sub(
|
| 23 |
+
r'weights_only\s*=\s*None(?!\s*\))',
|
| 24 |
+
'weights_only=False',
|
| 25 |
+
new_src
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
if new_src != src:
|
| 29 |
+
with open(path, 'w') as f:
|
| 30 |
+
f.write(new_src)
|
| 31 |
+
print(f'Patched {path}')
|
| 32 |
+
else:
|
| 33 |
+
print(f'No change needed: {path}')
|