err805 committed on
Commit
2876fc8
·
verified ·
1 Parent(s): fcf7123

Fix runtime buffers after load

Browse files
Files changed (1) hide show
  1. moondream.py +12 -6
moondream.py CHANGED
@@ -132,7 +132,15 @@ class MoondreamModel(nn.Module):
132
  torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
133
  )
134
 
135
- self.register_buffer("attn_mask", self._build_attn_mask(), persistent=False)
 
 
 
 
 
 
 
 
136
 
137
  self.use_flex_decoding = True
138
  self._causal_block_mask = None
@@ -164,7 +172,7 @@ class MoondreamModel(nn.Module):
164
  )
165
  return self._point_gen_indices
166
 
167
- def _build_attn_mask(self):
168
  attn_mask = torch.tril(
169
  torch.ones(
170
  1,
@@ -172,15 +180,13 @@ class MoondreamModel(nn.Module):
172
  self.config.text.max_context,
173
  self.config.text.max_context,
174
  dtype=torch.bool,
 
175
  )
176
  )
177
  patch_w = self.config.vision.crop_size // self.config.vision.enc_patch_size
178
  prefix_attn_len = 1 + patch_w**2
179
  attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
180
- return attn_mask
181
-
182
- def _refresh_runtime_buffers(self):
183
- self.attn_mask = self._build_attn_mask().to(device=self.device)
184
  self.text.freqs_cis = precompute_freqs_cis(
185
  self.config.text.dim // (2 * self.config.text.n_heads),
186
  self.config.text.max_context,
 
132
  torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
133
  )
134
 
135
+ attn_mask = torch.tril(
136
+ torch.ones(
137
+ 1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
138
+ )
139
+ )
140
+ patch_w = config.vision.crop_size // config.vision.enc_patch_size
141
+ prefix_attn_len = 1 + patch_w**2
142
+ attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
143
+ self.register_buffer("attn_mask", attn_mask, persistent=False)
144
 
145
  self.use_flex_decoding = True
146
  self._causal_block_mask = None
 
172
  )
173
  return self._point_gen_indices
174
 
175
+ def _refresh_runtime_buffers(self):
176
  attn_mask = torch.tril(
177
  torch.ones(
178
  1,
 
180
  self.config.text.max_context,
181
  self.config.text.max_context,
182
  dtype=torch.bool,
183
+ device=self.device,
184
  )
185
  )
186
  patch_w = self.config.vision.crop_size // self.config.vision.enc_patch_size
187
  prefix_attn_len = 1 + patch_w**2
188
  attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
189
+ self.attn_mask = attn_mask
 
 
 
190
  self.text.freqs_cis = precompute_freqs_cis(
191
  self.config.text.dim // (2 * self.config.text.n_heads),
192
  self.config.text.max_context,