Spaces:
Running on Zero
Running on Zero
HadiZayer committed on
Commit ·
674697c
1
Parent(s): 8ff325d
remove torch.hub DINOv2 download: stub out encoder (output was always zeros), strict=False to skip DINO checkpoint keys
Browse files- ldm/modules/encoders/modules.py +13 -12
- run_magicfu.py +1 -1
ldm/modules/encoders/modules.py
CHANGED
|
@@ -263,12 +263,13 @@ class DINOEmbedder(AbstractEncoder):
|
|
| 263 |
# 'huge': 1536
|
| 264 |
# }
|
| 265 |
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
|
|
|
| 272 |
self.freeze()
|
| 273 |
|
| 274 |
def freeze(self):
|
|
@@ -276,12 +277,12 @@ class DINOEmbedder(AbstractEncoder):
|
|
| 276 |
param.requires_grad = False
|
| 277 |
|
| 278 |
def forward(self, image):
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
return torch.
|
| 285 |
|
| 286 |
def encode(self, image):
|
| 287 |
return self(image)
|
|
|
|
| 263 |
# 'huge': 1536
|
| 264 |
# }
|
| 265 |
|
| 266 |
+
embedding_sizes = {
|
| 267 |
+
'small': 384,
|
| 268 |
+
'big': 768,
|
| 269 |
+
'large': 1024,
|
| 270 |
+
'huge': 1536
|
| 271 |
+
}
|
| 272 |
+
self.embedding_dim = embedding_sizes[dino_version]
|
| 273 |
self.freeze()
|
| 274 |
|
| 275 |
def freeze(self):
|
|
|
|
| 277 |
param.requires_grad = False
|
| 278 |
|
| 279 |
def forward(self, image):
|
| 280 |
+
# DINO output is unused (returns zeros); compute shape from input without loading the model
|
| 281 |
+
B = image.shape[0]
|
| 282 |
+
patch_size = 14
|
| 283 |
+
h, w = image.shape[-2], image.shape[-1]
|
| 284 |
+
num_patches = (h // patch_size) * (w // patch_size)
|
| 285 |
+
return torch.zeros(B, num_patches + 1, self.embedding_dim, device=image.device, dtype=image.dtype)
|
| 286 |
|
| 287 |
def encode(self, image):
|
| 288 |
return self(image)
|
run_magicfu.py
CHANGED
|
@@ -115,7 +115,7 @@ def get_model(config_path, ckpt_path):
|
|
| 115 |
model = load_model_from_config(config, None)
|
| 116 |
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
| 117 |
|
| 118 |
-
m, u = model.load_state_dict(pl_sd, strict=True)
|
| 119 |
if len(m) > 0:
|
| 120 |
print("WARNING: missing keys:")
|
| 121 |
print(m)
|
|
|
|
| 115 |
model = load_model_from_config(config, None)
|
| 116 |
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
| 117 |
|
| 118 |
+
m, u = model.load_state_dict(pl_sd, strict=False)
|
| 119 |
if len(m) > 0:
|
| 120 |
print("WARNING: missing keys:")
|
| 121 |
print(m)
|