Nekochu commited on
Commit
2aca4a6
·
1 Parent(s): 09dfa8a

fix: .weight.dtype crashes on INT8 quantized Linear, use .float()

Browse files
Files changed (1) hide show
  1. modules/layers.py +1 -1
modules/layers.py CHANGED
@@ -339,7 +339,7 @@ class TimestepEmbedder(nn.Module):
339
  def forward(self, t):
340
  t_freq = self.timestep_embedding(
341
  t, self.frequency_embedding_size, self.max_period
342
- ).type(self.mlp[0].weight.dtype) # type: ignore
343
  t_emb = self.mlp(t_freq)
344
  return t_emb
345
 
 
339
  def forward(self, t):
340
  t_freq = self.timestep_embedding(
341
  t, self.frequency_embedding_size, self.max_period
342
+ ).float() # INT8 quantized Linear has no .weight.dtype
343
  t_emb = self.mlp(t_freq)
344
  return t_emb
345