HadiZayer commited on
Commit
674697c
·
1 Parent(s): 8ff325d

remove torch.hub DINOv2 download: the encoder's output was always zeros, so stub it out entirely; load the checkpoint with strict=False to skip the now-absent DINO keys

Browse files
Files changed (2) hide show
  1. ldm/modules/encoders/modules.py +13 -12
  2. run_magicfu.py +1 -1
ldm/modules/encoders/modules.py CHANGED
@@ -263,12 +263,13 @@ class DINOEmbedder(AbstractEncoder):
263
  # 'huge': 1536
264
  # }
265
 
266
- # embedding_size = embedding_sizes[dino_version]
267
- letter = letter_map[dino_version]
268
- # self.transformer = CLIPVisionModel.from_pretrained(version)
269
- self.dino_model = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{letter}14_reg', pretrained=False).cuda()
270
-
271
-
 
272
  self.freeze()
273
 
274
  def freeze(self):
@@ -276,12 +277,12 @@ class DINOEmbedder(AbstractEncoder):
276
  param.requires_grad = False
277
 
278
  def forward(self, image):
279
- with torch.no_grad():
280
- outputs = self.dino_model.forward_features(image)
281
- patch_tokens = outputs['x_norm_patchtokens']
282
- global_token = outputs['x_norm_clstoken'].unsqueeze(1)
283
- features = torch.concat([patch_tokens, global_token], dim=1)
284
- return torch.zeros_like(features)
285
 
286
  def encode(self, image):
287
  return self(image)
 
263
  # 'huge': 1536
264
  # }
265
 
266
+ embedding_sizes = {
267
+ 'small': 384,
268
+ 'big': 768,
269
+ 'large': 1024,
270
+ 'huge': 1536
271
+ }
272
+ self.embedding_dim = embedding_sizes[dino_version]
273
  self.freeze()
274
 
275
  def freeze(self):
 
277
  param.requires_grad = False
278
 
279
  def forward(self, image):
280
+ # The DINO features were always discarded (forward returned zeros_like), so build the zero tensor's shape directly from the input instead of loading the model
281
+ B = image.shape[0]
282
+ patch_size = 14
283
+ h, w = image.shape[-2], image.shape[-1]
284
+ num_patches = (h // patch_size) * (w // patch_size)
285
+ return torch.zeros(B, num_patches + 1, self.embedding_dim, device=image.device, dtype=image.dtype)
286
 
287
  def encode(self, image):
288
  return self(image)
run_magicfu.py CHANGED
@@ -115,7 +115,7 @@ def get_model(config_path, ckpt_path):
115
  model = load_model_from_config(config, None)
116
  pl_sd = torch.load(ckpt_path, map_location="cpu")
117
 
118
- m, u = model.load_state_dict(pl_sd, strict=True)
119
  if len(m) > 0:
120
  print("WARNING: missing keys:")
121
  print(m)
 
115
  model = load_model_from_config(config, None)
116
  pl_sd = torch.load(ckpt_path, map_location="cpu")
117
 
118
+ m, u = model.load_state_dict(pl_sd, strict=False)
119
  if len(m) > 0:
120
  print("WARNING: missing keys:")
121
  print(m)