BiliSakura committed on
Commit
97e78d6
·
verified ·
1 Parent(s): a5b028f

Update all files for BitDance-ImageNet-diffusers

Browse files
BitDance_B_1x/pipeline_bitdance_imagenet.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Sequence, Union
4
+
5
+ import torch
6
+
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
9
+
10
+
11
class BitDanceImageNetPipeline(DiffusionPipeline):
    """Class-conditional ImageNet generation pipeline for BitDance.

    Wraps a ``transformer`` module exposing a ``sample(...)`` method and an
    optional ``autoencoder``. Sampling is delegated entirely to the
    transformer; this pipeline only builds the class-id tensor and converts
    the result into numpy arrays or PIL images.
    """

    model_cpu_offload_seq = "transformer"

    def __init__(self, transformer, autoencoder=None):
        super().__init__()
        self.register_modules(transformer=transformer, autoencoder=autoencoder)

    @torch.no_grad()
    def __call__(
        self,
        class_labels: Union[int, Sequence[int]] = 0,
        num_images_per_label: int = 1,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """Generate images for the given ImageNet class labels.

        Args:
            class_labels: A single class id or a sequence of class ids.
            num_images_per_label: How many samples to draw per label
                (each id is repeated this many times when > 1).
            sample_steps: Number of sampling steps forwarded to the model.
            cfg_scale: Classifier-free guidance scale forwarded to the model.
            chunk_size: Chunking parameter forwarded to the model.
            output_type: ``"pil"`` for PIL images; anything else yields the
                raw NHWC float numpy array.
            return_dict: If ``True`` return an ``ImagePipelineOutput``,
                otherwise a one-element tuple.
        """
        device = self._execution_device

        # Normalize the label argument to a flat Python list of ints.
        if isinstance(class_labels, int):
            label_list = [class_labels]
        else:
            label_list = list(class_labels)

        class_ids = torch.tensor(label_list, dtype=torch.long, device=device)
        if num_images_per_label > 1:
            class_ids = class_ids.repeat_interleave(num_images_per_label)

        # The transformer owns the full sampling loop.
        samples = self.transformer.sample(
            class_ids=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

        # Map model output (presumably in [-1, 1], NCHW — matches the usual
        # diffusers convention) to [0, 1] NHWC float32 on CPU.
        samples = (samples / 2 + 0.5).clamp(0, 1)
        images = samples.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            images = self.numpy_to_pil(images)

        if return_dict:
            return ImagePipelineOutput(images=images)
        return (images,)