Stefano Demartis commited on
Commit ·
4bf24bd
1
Parent(s): 89f99fe
fix: use d_att instead of hidden_dim in enable_corr=False fallback for LocalGatedPropagation
Browse files- model/attention.py +2 -1
model/attention.py
CHANGED
|
@@ -829,8 +829,9 @@ class LocalGatedPropagation(nn.Module):
|
|
| 829 |
qk = self.correlation_sampler(q, k).view(
|
| 830 |
n, self.num_head, self.window_size * self.window_size, h * w)
|
| 831 |
else:
|
|
|
|
| 832 |
unfolded_k = self.pad_and_unfold(k).view(
|
| 833 |
-
n * self.num_head,
|
| 834 |
self.window_size * self.window_size, h, w)
|
| 835 |
qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view(
|
| 836 |
n, self.num_head, self.window_size * self.window_size, h * w)
|
|
|
|
| 829 |
qk = self.correlation_sampler(q, k).view(
|
| 830 |
n, self.num_head, self.window_size * self.window_size, h * w)
|
| 831 |
else:
|
| 832 |
+
# Use self.d_att (key dim) not hidden_dim (value dim) for unfolded keys
|
| 833 |
unfolded_k = self.pad_and_unfold(k).view(
|
| 834 |
+
n * self.num_head, self.d_att,
|
| 835 |
self.window_size * self.window_size, h, w)
|
| 836 |
qk = (q.unsqueeze(2) * unfolded_k).sum(dim=1).view(
|
| 837 |
n, self.num_head, self.window_size * self.window_size, h * w)
|