fix: truncated `self.lmm` attribute references that made inference fail
Browse files — deepgen_pipeline.py (+6 −6)
deepgen_pipeline.py
CHANGED
|
@@ -1185,15 +1185,15 @@ class DeepGenPipeline(DiffusionPipeline):
|
|
| 1185 |
else:
|
| 1186 |
input_ids = input_ids[:, :-l]
|
| 1187 |
if image_embeds is None:
|
| 1188 |
-
inputs_embeds = self.
|
| 1189 |
else:
|
| 1190 |
inputs_embeds = torch.zeros(
|
| 1191 |
-
*input_ids.shape, self.
|
| 1192 |
device=self._gpu_device, dtype=self.transformer.dtype)
|
| 1193 |
inputs_embeds[input_ids == self.image_token_id] = \
|
| 1194 |
-
image_embeds.contiguous().view(-1, self.
|
| 1195 |
inputs_embeds[input_ids != self.image_token_id] = \
|
| 1196 |
-
self.
|
| 1197 |
inputs_embeds = torch.cat([inputs_embeds, query_embeds], dim=1)
|
| 1198 |
|
| 1199 |
return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
|
|
@@ -1334,7 +1334,7 @@ class DeepGenPipeline(DiffusionPipeline):
|
|
| 1334 |
hidden_states = self.connector_module.meta_queries[None].expand(
|
| 1335 |
2 * b, self.num_queries, -1)
|
| 1336 |
inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
|
| 1337 |
-
output = self.
|
| 1338 |
|
| 1339 |
# SCB: extract multi-layer hidden states
|
| 1340 |
hidden_states = output.hidden_states
|
|
@@ -1391,4 +1391,4 @@ class DeepGenPipeline(DiffusionPipeline):
|
|
| 1391 |
img = torch.clamp(127.5 * img + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
|
| 1392 |
images.append(Image.fromarray(img))
|
| 1393 |
|
| 1394 |
-
return SimpleNamespace(images=images)
|
|
|
|
| 1185 |
else:
|
| 1186 |
input_ids = input_ids[:, :-l]
|
| 1187 |
if image_embeds is None:
|
| 1188 |
+
inputs_embeds = self.lmm.get_input_embeddings()(input_ids)
|
| 1189 |
else:
|
| 1190 |
inputs_embeds = torch.zeros(
|
| 1191 |
+
*input_ids.shape, self.lmm.config.hidden_size,
|
| 1192 |
device=self._gpu_device, dtype=self.transformer.dtype)
|
| 1193 |
inputs_embeds[input_ids == self.image_token_id] = \
|
| 1194 |
+
image_embeds.contiguous().view(-1, self.lmm.config.hidden_size)
|
| 1195 |
inputs_embeds[input_ids != self.image_token_id] = \
|
| 1196 |
+
self.lmm.get_input_embeddings()(input_ids[input_ids != self.image_token_id])
|
| 1197 |
inputs_embeds = torch.cat([inputs_embeds, query_embeds], dim=1)
|
| 1198 |
|
| 1199 |
return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
|
|
|
|
| 1334 |
hidden_states = self.connector_module.meta_queries[None].expand(
|
| 1335 |
2 * b, self.num_queries, -1)
|
| 1336 |
inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
|
| 1337 |
+
output = self.lmm(**inputs, return_dict=True, output_hidden_states=True)
|
| 1338 |
|
| 1339 |
# SCB: extract multi-layer hidden states
|
| 1340 |
hidden_states = output.hidden_states
|
|
|
|
| 1391 |
img = torch.clamp(127.5 * img + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
|
| 1392 |
images.append(Image.fromarray(img))
|
| 1393 |
|
| 1394 |
+
return SimpleNamespace(images=images)
|