Stefano Demartis commited on
Commit
4bf24bd
·
1 Parent(s): 89f99fe

fix: use d_att instead of hidden_dim in enable_corr=False fallback for LocalGatedPropagation

Browse files
Files changed (1) hide show
  1. model/attention.py +2 -1
model/attention.py CHANGED
@@ -829,8 +829,9 @@ class LocalGatedPropagation(nn.Module):
829
  qk = self.correlation_sampler(q, k).view(
830
  n, self.num_head, self.window_size * self.window_size, h * w)
831
  else:
 
832
  unfolded_k = self.pad_and_unfold(k).view(
833
- n * self.num_head, hidden_dim,
834
  self.window_size * self.window_size, h, w)
835
  qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view(
836
  n, self.num_head, self.window_size * self.window_size, h * w)
 
829
  qk = self.correlation_sampler(q, k).view(
830
  n, self.num_head, self.window_size * self.window_size, h * w)
831
  else:
832
+ # Use self.d_att (key dim) not hidden_dim (value dim) for unfolded keys
833
  unfolded_k = self.pad_and_unfold(k).view(
834
+ n * self.num_head, self.d_att,
835
  self.window_size * self.window_size, h, w)
836
  qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view(
837
  n, self.num_head, self.window_size * self.window_size, h * w)