File size: 15,418 Bytes
feca35a
96ba62d
feca35a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96ba62d
 
 
 
 
 
 
feca35a
 
 
 
 
 
 
 
 
 
96ba62d
 
feca35a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96ba62d
 
feca35a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96ba62d
feca35a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96ba62d
 
 
 
 
feca35a
 
 
 
96ba62d
feca35a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96ba62d
feca35a
 
 
 
 
 
 
 
 
 
 
 
 
96ba62d
 
 
 
 
feca35a
 
 
 
96ba62d
feca35a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
from contextlib import nullcontext
from typing import List, Optional, Tuple, Union

import torch
from einops import rearrange
from PIL import Image
from tqdm.auto import tqdm

from diffusers import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput

from .constants import SUPPORTED_IMAGE_SIZES


# Type of the ``prompt`` argument of ``BitDanceDiffusionPipeline.__call__``: a
# single prompt string, or a list of prompt strings (one generation per entry).
PromptType = Union[str, List[str]]


def _get_pkv_seq_len(past_key_values) -> int:
    """Get cached sequence length from past_key_values (supports tuple and DynamicCache)."""
    if hasattr(past_key_values, "get_seq_length"):
        return past_key_values.get_seq_length()
    return past_key_values[0][0].shape[2]


class BitDanceDiffusionPipeline(DiffusionPipeline):
    """Text-to-image pipeline that generates binary image tokens block by block.

    The prompt is wrapped in a chat template and embedded together with an
    image prefix (vision-start token, two resolution tokens and
    ``parallel_num - 1`` query tokens). The text encoder is then run
    incrementally with a KV cache. At every step, its last ``parallel_num``
    hidden states (plus 2D sin-cos position embeddings) go through
    ``diffusion_head.sample``. The sampled latents are binarized with
    ``torch.sign``, and ``projector`` maps them back to encoder inputs for the
    next step. The collected tokens are finally decoded to pixels by
    ``autoencoder``.
    """

    model_cpu_offload_seq = "text_encoder->projector->diffusion_head->autoencoder"

    def __init__(
        self,
        tokenizer,
        text_encoder,
        autoencoder,
        diffusion_head,
        projector,
        supported_image_sizes: Optional[List[List[int]]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Register the sub-models and derive the decoding geometry.

        Args:
            tokenizer: Provides ``encode`` and ``convert_tokens_to_ids``.
            text_encoder: Its ``.model`` must expose ``embed_tokens`` and a
                KV-cached forward; ``config.hidden_size`` sets the embedding width.
            autoencoder: Exposes ``patch_size`` and ``decode``.
            diffusion_head: Exposes ``config.parallel_num`` and ``sample``.
            projector: Maps sampled binary tokens back to encoder input embeddings.
            supported_image_sizes: ``[height, width]`` pairs accepted at
                inference; defaults to ``SUPPORTED_IMAGE_SIZES``.
            dtype: Accepted but not used anywhere in this class.

        Raises:
            ValueError: If ``diffusion_head.config.parallel_num`` is not a
                positive perfect square.
        """
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            autoencoder=autoencoder,
            diffusion_head=diffusion_head,
            projector=projector,
        )

        image_sizes = supported_image_sizes or SUPPORTED_IMAGE_SIZES
        self.register_to_config(supported_image_sizes=[list(size) for size in image_sizes])

        self.hidden_size = self.text_encoder.config.hidden_size
        self.vae_patch_size = self.autoencoder.patch_size
        self.parallel_num = int(self.diffusion_head.config.parallel_num)
        if self.parallel_num < 1:
            raise ValueError(f"parallel_num must be a positive integer (got {self.parallel_num}).")
        # Each decoding step fills one ps x ps tile of latent positions (see the
        # "(h w p1 p2)" orderings below), so parallel_num must be a perfect square.
        self.ps = int(self.parallel_num**0.5)
        if self.ps * self.ps != self.parallel_num:
            raise ValueError(
                f"parallel_num must be a perfect square (got {self.parallel_num})."
            )

        self._build_pos_embed()

    @property
    def supported_image_sizes(self) -> List[List[int]]:
        """``[height, width]`` pairs (in pixels) this pipeline accepts, as fresh lists."""
        return [list(size) for size in self.config.supported_image_sizes]

    def _execution_device_fallback(self) -> torch.device:
        """Return ``_execution_device`` if available, else the text encoder's parameter device."""
        if getattr(self, "_execution_device", None) is not None:
            return self._execution_device
        return next(self.text_encoder.parameters()).device

    def _build_pos_embed(self) -> None:
        """Precompute a 1D sin-cos table covering the longest supported side in latent positions."""
        max_resolution = max(max(size) for size in self.supported_image_sizes)
        max_len = max_resolution // self.vae_patch_size
        # Each spatial axis gets half of the hidden width (see _get_2d_embed).
        self.pos_embed_1d = self._get_1d_sincos_pos_embed(self.hidden_size // 2, max_len)

    @staticmethod
    def _get_1d_sincos_pos_embed(dim: int, max_len: int, pe_interpolation: float = 1.0) -> torch.Tensor:
        """Return a ``(max_len, dim)`` float32 sin-cos position table.

        Column ``k`` of the sin half (and of the cos half) uses frequency
        ``1 / 10000 ** (2k / dim)``. Positions are divided by
        ``pe_interpolation`` before the frequencies are applied.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        if dim % 2 != 0:
            raise ValueError(f"dim must be even, got {dim}")
        omega = torch.arange(dim // 2, dtype=torch.float32)
        omega /= dim / 2.0
        omega = 1.0 / 10000**omega
        pos = torch.arange(max_len, dtype=torch.float32) / pe_interpolation
        out = torch.einsum("m,d->md", pos, omega)
        return torch.cat([torch.sin(out), torch.cos(out)], dim=1)

    def _get_2d_embed(self, h: int, w: int, ps: int = 1) -> torch.Tensor:
        """Return ``(h * w, hidden_size)`` position embeddings for an ``h x w`` latent grid.

        Channels are ``[column embedding | row embedding]``. Rows are ordered
        tile by tile: each run of ``ps * ps`` consecutive rows covers one
        ``ps x ps`` tile, with tiles and positions inside a tile in row-major
        order. ``h`` and ``w`` must be divisible by ``ps``.
        """
        half = self.hidden_size // 2
        grid_v = self.pos_embed_1d[:h].view(h, 1, half).repeat(1, w, 1)
        grid_h = self.pos_embed_1d[:w].view(1, w, half).repeat(h, 1, 1)
        pos_embed = torch.cat([grid_h, grid_v], dim=-1)
        return rearrange(pos_embed, "(h p1) (w p2) c -> (h w p1 p2) c", p1=ps, p2=ps)

    def _encode_prompt_to_embeds(
        self,
        prompt: str,
        image_size: Tuple[int, int],
        num_images_per_prompt: int,
        guidance_scale: float,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Build the text-encoder input embeddings for one prompt.

        The image prefix appended after the chat-formatted prompt is
        ``<|vision_start|>``, ``<|res_H|>``, ``<|res_W|>``, followed by
        ``<|query_1|>`` ... ``<|query_{parallel_num-1}|>``. Here H and W are the
        image height and width divided by ``vae_patch_size``.

        Args:
            prompt: Text prompt.
            image_size: ``(height, width)`` in pixels.
            num_images_per_prompt: Number of times the sequence is repeated along
                a new leading batch dimension.
            guidance_scale: When ``> 1.0``, an unconditional sequence (an
                assistant-only template with no user turn) is built as well.

        Returns:
            ``(input_embeds_cond, input_embeds_uncond, img_start_emb)``. The first
            two are batched (``num_images_per_prompt`` rows);
            ``input_embeds_uncond`` is ``None`` when ``guidance_scale <= 1.0``.
            ``img_start_emb`` is the unbatched image-prefix embedding with
            ``parallel_num + 2`` positions.

        Raises:
            ValueError: If ``convert_tokens_to_ids`` returns ``None`` for one of
                the image-prefix tokens.
        """
        device = self._execution_device_fallback()
        model = self.text_encoder.model
        tokenizer = self.tokenizer

        def embed_ids(ids: List[int]) -> torch.Tensor:
            return model.embed_tokens(torch.tensor(ids, device=device, dtype=torch.long))

        cond_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        cond_emb = embed_ids(tokenizer.encode(cond_prompt))

        uncond_emb = None
        if guidance_scale > 1.0:
            uncond_prompt = "<|im_start|>assistant\n"
            uncond_emb = embed_ids(tokenizer.encode(uncond_prompt))

        image_h, image_w = image_size
        prefix_tokens = [
            "<|vision_start|>",
            f"<|res_{image_h // self.vae_patch_size}|>",
            f"<|res_{image_w // self.vae_patch_size}|>",
        ]
        prefix_tokens.extend(f"<|query_{i}|>" for i in range(1, self.parallel_num))

        prefix_ids = []
        for token in prefix_tokens:
            token_id = tokenizer.convert_tokens_to_ids(token)
            if token_id is None:
                raise ValueError(f"Tokenizer has no id for special token {token!r}.")
            prefix_ids.append(token_id)
        # A single embedding lookup for the whole prefix; an embedding is a row
        # gather, so this equals embedding each token separately and concatenating.
        img_start_emb = embed_ids(prefix_ids)

        def to_batch(text_emb: torch.Tensor) -> torch.Tensor:
            sequence = torch.cat([text_emb, img_start_emb], dim=0)
            return sequence.unsqueeze(0).repeat(num_images_per_prompt, 1, 1)

        input_embeds_cond = to_batch(cond_emb)
        input_embeds_uncond = None if uncond_emb is None else to_batch(uncond_emb)
        return input_embeds_cond, input_embeds_uncond, img_start_emb

    def _decode_tokens_to_image(self, image_latents: torch.Tensor, image_size: Tuple[int, int], ps: int = 1) -> torch.Tensor:
        """Reassemble tile-ordered tokens into a latent map and decode it.

        Args:
            image_latents: ``(b, h * w, c)`` tokens in the tile order produced by
                ``_get_2d_embed``.
            image_size: ``(h, w)`` size of the latent grid. The caller passes the
                pixel size divided by ``vae_patch_size``, not pixels.
            ps: Tile side length.

        Returns:
            Whatever ``autoencoder.decode`` returns for the ``(b, c, h, w)``
            latents cast to the autoencoder's parameter dtype.
        """
        h, w = image_size
        image_latents = rearrange(image_latents, "b (h w p1 p2) c -> b c (h p1) (w p2)", h=h // ps, w=w // ps, p1=ps, p2=ps)
        ae_dtype = next(self.autoencoder.parameters()).dtype
        return self.autoencoder.decode(image_latents.to(dtype=ae_dtype))

    @staticmethod
    def _append_to_kv_cache(model, inputs_embeds: torch.Tensor, past_key_values, device: torch.device):
        """Run ``inputs_embeds`` against an existing KV cache with an all-True attention mask.

        Every new position may attend to every cached position and to each other
        (bidirectional inside the new block).

        Returns:
            ``(past_key_values, hidden)``: the updated cache and the last hidden
            states of the new positions.
        """
        batch, new_len = inputs_embeds.shape[0], inputs_embeds.shape[1]
        attn_mask = torch.ones(
            (batch, 1, new_len, new_len + _get_pkv_seq_len(past_key_values)),
            dtype=torch.bool,
            device=device,
        )
        outputs = model(
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=True,
            attention_mask=attn_mask,
        )
        return outputs.past_key_values, outputs.last_hidden_state[:, -new_len:]

    @classmethod
    def _prefill_kv_cache(cls, model, input_embeds: torch.Tensor, step_width: int, device: torch.device):
        """Encode a full prompt sequence and return ``(past_key_values, hidden)``.

        All but the last ``step_width`` positions go through the model's default
        forward. The final ``step_width`` positions (the image queries) are then
        appended with a fully-visible mask via ``_append_to_kv_cache``.
        """
        outputs = model(inputs_embeds=input_embeds[:, :-step_width, :], use_cache=True)
        return cls._append_to_kv_cache(model, input_embeds[:, -step_width:, :], outputs.past_key_values, device)

    @torch.no_grad()
    def _generate_single_prompt(
        self,
        prompt: str,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        num_images_per_prompt: int,
        generator: Optional[torch.Generator],
        show_progress_bar: bool,
    ) -> torch.Tensor:
        """Generate ``num_images_per_prompt`` images for one prompt.

        Returns:
            The ``autoencoder.decode`` output for the generated tokens.

        Raises:
            ValueError: If ``(height, width)`` is not a supported size, or the
                latent grid cannot be tiled by ``ps x ps`` blocks.
        """
        image_size = (height, width)
        if list(image_size) not in self.supported_image_sizes:
            raise ValueError(
                f"image_size {list(image_size)} is not supported. "
                f"Please choose from {self.supported_image_sizes}"
            )

        h, w = height // self.vae_patch_size, width // self.vae_patch_size
        max_length = h * w
        step_width = self.parallel_num
        if max_length % step_width != 0:
            raise ValueError(
                f"max_length ({max_length}) must be divisible by parallel_num ({step_width})."
            )
        # The position embeddings and the final decode both tile the (h, w) grid
        # into ps x ps blocks; without this check einops fails with an opaque error.
        if h % self.ps != 0 or w % self.ps != 0:
            raise ValueError(
                f"latent grid ({h}, {w}) must be divisible by the parallel block size ({self.ps})."
            )
        num_steps = max_length // step_width

        device = self._execution_device_fallback()
        model = self.text_encoder.model
        dtype = next(self.text_encoder.parameters()).dtype

        input_embeds_cond, input_embeds_uncond, _ = self._encode_prompt_to_embeds(
            prompt=prompt,
            image_size=image_size,
            num_images_per_prompt=num_images_per_prompt,
            guidance_scale=guidance_scale,
        )
        pos_embed_for_diff = self._get_2d_embed(h, w, ps=self.ps).unsqueeze(0).to(device=device, dtype=dtype)

        autocast_ctx = (
            torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
            if device.type == "cuda"
            else nullcontext()
        )

        with autocast_ctx:
            pkv_c, hidden_c = self._prefill_kv_cache(model, input_embeds_cond, step_width, device)
            pkv_u, hidden_u = None, None
            if input_embeds_uncond is not None:
                pkv_u, hidden_u = self._prefill_kv_cache(model, input_embeds_uncond, step_width, device)

            out_tokens = []
            step_iter = range(num_steps)
            if show_progress_bar:
                step_iter = tqdm(step_iter, total=num_steps, desc="Decoding steps")

            last_step = num_steps - 1
            for step in step_iter:
                # With guidance, conditional rows come first and unconditional rows
                # follow; the slicing below assumes diffusion_head.sample keeps this
                # layout in its output (TODO confirm against diffusion_head.sample).
                h_fused = hidden_c if hidden_u is None else torch.cat([hidden_c, hidden_u], dim=0)
                pos_slice = pos_embed_for_diff[:, step * step_width : (step + 1) * step_width, :]
                pred_latents = self.diffusion_head.sample(
                    h_fused + pos_slice,
                    num_sampling_steps=num_inference_steps,
                    cfg=guidance_scale,
                    generator=generator,
                )
                curr_tokens = torch.sign(pred_latents)
                out_tokens.append(curr_tokens[:num_images_per_prompt])

                if step == last_step:
                    # Nothing consumes the hidden states after the final block, so
                    # skip the projector and the extra encoder forward passes.
                    # `continue` (not `break`) lets a tqdm bar reach 100%.
                    continue

                model_input = self.projector(curr_tokens) + pos_slice
                pkv_c, hidden_c = self._append_to_kv_cache(
                    model, model_input[:num_images_per_prompt], pkv_c, device
                )
                if hidden_u is not None:
                    pkv_u, hidden_u = self._append_to_kv_cache(
                        model, model_input[num_images_per_prompt:], pkv_u, device
                    )

        full_output = torch.cat(out_tokens, dim=1)
        return self._decode_tokens_to_image(full_output, image_size=(h, w), ps=self.ps)

    @torch.no_grad()
    def __call__(
        self,
        prompt: PromptType,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        show_progress_bar: bool = False,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Generate images for one or more prompts.

        Args:
            prompt: A prompt string or a list of prompt strings.
            height, width: Output size in pixels; must be in ``supported_image_sizes``.
            num_inference_steps: Forwarded to ``diffusion_head.sample`` as
                ``num_sampling_steps``.
            guidance_scale: Forwarded as ``cfg``; values ``> 1.0`` also run an
                unconditional branch.
            num_images_per_prompt: Images generated per prompt (>= 1).
            generator: One generator shared by all prompts, or a list with one
                generator per prompt.
            output_type: ``"pil"`` (list of ``PIL.Image``), ``"np"`` (NHWC float
                array in [0, 1]) or ``"pt"`` (NCHW tensor in [0, 1]).
            return_dict: Return an ``ImagePipelineOutput`` instead of a tuple.
            show_progress_bar: Show a tqdm bar over decoding steps.

        Returns:
            ``ImagePipelineOutput(images=...)``, or ``(images,)`` when
            ``return_dict`` is False.

        Raises:
            ValueError: On an empty prompt list, a generator list of the wrong
                length, an unknown ``output_type``, ``num_images_per_prompt < 1``,
                or an unsupported image size.
        """
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        if len(prompts) == 0:
            raise ValueError("prompt must be a non-empty string or list of strings.")

        if isinstance(generator, list) and len(generator) != len(prompts):
            raise ValueError("When passing a list of generators, its length must equal len(prompt).")

        # Validate before the (expensive) generation rather than after it.
        if output_type not in ("pil", "np", "pt"):
            raise ValueError(f"Unsupported output_type={output_type}. Expected 'pil', 'np', or 'pt'.")
        if num_images_per_prompt < 1:
            raise ValueError(f"num_images_per_prompt must be >= 1 (got {num_images_per_prompt}).")

        image_tensors = []
        for i, prompt_text in enumerate(prompts):
            prompt_generator = generator[i] if isinstance(generator, list) else generator
            image_tensors.append(
                self._generate_single_prompt(
                    prompt=prompt_text,
                    height=height,
                    width=width,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    num_images_per_prompt=num_images_per_prompt,
                    generator=prompt_generator,
                    show_progress_bar=show_progress_bar,
                )
            )

        images_pt = torch.cat(image_tensors, dim=0)

        if output_type == "pil":
            # 128 = 127.5 + 0.5: the offset makes the truncating uint8 cast round
            # (x + 1) * 127.5 to the nearest integer after clamping to [0, 255].
            images_uint8 = (
                torch.clamp(127.5 * images_pt + 128.0, 0, 255)
                .permute(0, 2, 3, 1)
                .to("cpu", dtype=torch.uint8)
                .numpy()
            )
            output_images = [Image.fromarray(image) for image in images_uint8]
        else:
            images_pt_01 = torch.clamp((images_pt + 1.0) / 2.0, 0.0, 1.0)
            if output_type == "np":
                output_images = images_pt_01.permute(0, 2, 3, 1).float().cpu().numpy()
            else:
                output_images = images_pt_01

        if not return_dict:
            return (output_images,)
        return ImagePipelineOutput(images=output_images)

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        height: int = 1024,
        width: int = 1024,
        num_sampling_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images: int = 1,
        seed: Optional[int] = None,
    ) -> List[Image.Image]:
        """Convenience wrapper around ``__call__`` that returns PIL images.

        When ``seed`` is given, a ``torch.Generator`` seeded with it is created
        on CUDA if the execution device is CUDA, otherwise on CPU. A progress
        bar is always shown.
        """
        generator = None
        if seed is not None:
            device = self._execution_device_fallback()
            generator_device = "cuda" if device.type == "cuda" else "cpu"
            generator = torch.Generator(device=generator_device).manual_seed(seed)
        output = self(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_sampling_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            generator=generator,
            output_type="pil",
            return_dict=True,
            show_progress_bar=True,
        )
        return output.images