kelseye committed · verified
Commit 9355758 · 1 Parent(s): 3ad80be

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/cat_Inpaint_1.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/cat_Inpaint_2.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/cat_base.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,194 @@
---
license: apache-2.0
---
# Templates-Inpainting (FLUX.2-klein-base-4B)

This model is part of the open-source Diffusion Templates series from [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). Built specifically for inpainting, it takes an original image and a mask image and generates new content in the masked region from a natural-language prompt, blending it seamlessly with the surrounding unmasked background.
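
The mask is a plain black-and-white image at the size of the original; following the convention in the bundled `model.py` (pixels brighter than 127 mark the region to regenerate), white means "repaint" and black means "keep". A minimal sketch for drawing such a mask with PIL; the file names and rectangle coordinates are placeholders:

```python
from PIL import Image, ImageDraw

base = Image.open("assets/cat_base.jpg")        # placeholder source image
mask = Image.new("L", base.size, 0)             # all black: keep everything
draw = ImageDraw.Draw(mask)
draw.rectangle((120, 200, 420, 560), fill=255)  # white box: region to repaint
mask.convert("RGB").save("my_mask.jpg")
```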

## Results

| Reference | Prompt | Mask | Generated |
|:---:|:---|:---:|:---:|
| ![](./assets/cat_base.jpg) | An orange cat is sitting on a stone. | ![](./assets/cat_mask_1.jpg) | ![](./assets/cat_Inpaint_1.jpg) |
| ![](./assets/cat_base.jpg) | A cat wearing sunglasses is sitting on a stone. | ![](./assets/cat_mask_2.jpg) | ![](./assets/cat_Inpaint_2.jpg) |

| Reference | Prompt | Mask | Generated |
|:---:|:---|:---:|:---:|
| ![](./assets/girl_base.jpg) | A beautiful young woman wearing a woven straw hat with a ribbon standing in a sunflower field. | ![](./assets/girl_mask_1.jpg) | ![](./assets/girl_Inpaint_1.jpg) |
| ![](./assets/girl_base.jpg) | A beautiful young woman wearing an elegant white dress standing in a glowing sunflower field. | ![](./assets/girl_mask_2.jpg) | ![](./assets/girl_Inpaint_2.jpg) |

| Reference | Prompt | Mask | Generated |
|:---:|:---|:---:|:---:|
| ![](./assets/room_base.jpg) | A sleek glass vase with a single blooming white lily and an open minimalist art book resting on the circular white marble coffee table. | ![](./assets/room1.jpg) | ![](./assets/room_Inpaint_1.jpg) |
| ![](./assets/room_base.jpg) | A large, minimalist flower painting hanging on the clean off-white wall above the sofa, soft shadows. | ![](./assets/room2.jpg) | ![](./assets/room_Inpaint_2.jpg) |

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference (requires 40 GB of GPU memory):

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image
```

```python
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Inpaint")],
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="An orange cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_1.jpg"),
        "force_inpaint": True,
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_1.jpg"),
    }],
)
image.save("image_Inpaint_1.jpg")
image = template(
    pipe,
    prompt="A cat wearing sunglasses is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_2.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_2.jpg"),
    }],
)
image.save("image_Inpaint_2.jpg")
```
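
To run the same pipeline on your own picture, point the template inputs at local files (the paths and prompt below are placeholders). Per the bundled `model.py`, `"force_inpaint": True` additionally hands the original image and mask back to the pipeline so the unmasked background is preserved in the final output:

```python
# Placeholder paths: any RGB image plus a same-size black-and-white mask
# (white = region to regenerate).
my_image = Image.open("my_photo.jpg")
my_mask = Image.open("my_mask.jpg")
image = template(
    pipe,
    prompt="A red wooden bench in a park.",  # placeholder prompt
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{"image": my_image, "mask": my_mask, "force_inpaint": True}],
    negative_template_inputs=[{"image": my_image, "mask": my_mask}],
)
image.save("my_result.jpg")
```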

* Inference with lazy loading and VRAM management enabled (requires 24 GB of GPU memory):

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image
```

```python
# Weights rest on disk, are staged on the CPU in FP8, and are cast to BF16
# on the GPU for computation.
vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,  # total VRAM in GiB minus a 0.5 GiB safety margin
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Inpaint")],
    lazy_loading=True,
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="An orange cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_1.jpg"),
        "force_inpaint": True,
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_1.jpg"),
    }],
)
image.save("image_Inpaint_1.jpg")
image = template(
    pipe,
    prompt="A cat wearing sunglasses is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_2.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_2.jpg"),
    }],
)
image.save("image_Inpaint_2.jpg")
```

## Training Code

After installing DiffSynth-Studio, use the following script to start training. Masks do not need to be provided: the bundled `TrainDataProcessor` in `model.py` synthesizes a random rectangular mask for each training image on the fly. For more information, please refer to the [DiffSynth-Studio documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-Inpaint/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Inpaint \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Inpaint/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-Inpaint:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Inpaint_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
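
After training, the checkpoint written under `--output_path` should be usable in place of the released template weights. A sketch, assuming `ModelConfig` accepts a local `path` argument as in other DiffSynth-Studio training examples; the exact checkpoint file name is an assumption:

```python
# Assumption: training saves per-epoch safetensors files under --output_path;
# adjust the file name to whatever your run actually produced.
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(path="models/train/Template-KleinBase4B-Inpaint_full/epoch-1.safetensors")],
)
```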
README_from_modelscope.md ADDED
@@ -0,0 +1,196 @@
---
frameworks:
- Pytorch
license: Apache License 2.0
tags: []
tasks:
- text-to-image-synthesis
---

# Templates-Inpainting (FLUX.2-klein-base-4B)

This model is part of the open-source Diffusion Templates series from [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio). Built specifically for inpainting, it takes an original image and a mask image and generates new content in the masked region from a natural-language prompt, blending it seamlessly with the surrounding unmasked background.

## Results

| Reference | Prompt | Mask | Generated |
|:---:|:---|:---:|:---:|
| ![](./assets/cat_base.jpg) | An orange cat is sitting on a stone. | ![](./assets/cat_mask_1.jpg) | ![](./assets/cat_Inpaint_1.jpg) |
| ![](./assets/cat_base.jpg) | A cat wearing sunglasses is sitting on a stone. | ![](./assets/cat_mask_2.jpg) | ![](./assets/cat_Inpaint_2.jpg) |

| Reference | Prompt | Mask | Generated |
|:---:|:---|:---:|:---:|
| ![](./assets/girl_base.jpg) | A beautiful young woman wearing a woven straw hat with a ribbon standing in a sunflower field. | ![](./assets/girl_mask_1.jpg) | ![](./assets/girl_Inpaint_1.jpg) |
| ![](./assets/girl_base.jpg) | A beautiful young woman wearing an elegant white dress standing in a glowing sunflower field. | ![](./assets/girl_mask_2.jpg) | ![](./assets/girl_Inpaint_2.jpg) |

| Reference | Prompt | Mask | Generated |
|:---:|:---|:---:|:---:|
| ![](./assets/room_base.jpg) | A sleek glass vase with a single blooming white lily and an open minimalist art book resting on the circular white marble coffee table. | ![](./assets/room1.jpg) | ![](./assets/room_Inpaint_1.jpg) |
| ![](./assets/room_base.jpg) | A large, minimalist flower painting hanging on the clean off-white wall above the sofa, soft shadows. | ![](./assets/room2.jpg) | ![](./assets/room_Inpaint_2.jpg) |

## Inference Code

* Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):

```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

* Direct inference (requires 40 GB of GPU memory):

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image

pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Inpaint")],
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="An orange cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_1.jpg"),
        "force_inpaint": True,
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_1.jpg"),
    }],
)
image.save("image_Inpaint_1.jpg")
image = template(
    pipe,
    prompt="A cat wearing sunglasses is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_2.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_2.jpg"),
    }],
)
image.save("image_Inpaint_2.jpg")
```

* Inference with lazy loading and VRAM management enabled (requires 24 GB of GPU memory):

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from modelscope import dataset_snapshot_download
from PIL import Image

# Weights rest on disk, are staged on the CPU in FP8, and are cast to BF16
# on the GPU for computation.
vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,  # total VRAM in GiB minus a 0.5 GiB safety margin
)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Inpaint")],
    lazy_loading=True,
)
dataset_snapshot_download(
    "DiffSynth-Studio/examples_in_diffsynth",
    allow_file_pattern=["templates/*"],
    local_dir="data/examples",
)
image = template(
    pipe,
    prompt="An orange cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_1.jpg"),
        "force_inpaint": True,
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_1.jpg"),
    }],
)
image.save("image_Inpaint_1.jpg")
image = template(
    pipe,
    prompt="A cat wearing sunglasses is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_2.jpg"),
    }],
    negative_template_inputs=[{
        "image": Image.open("data/examples/templates/image_reference.jpg"),
        "mask": Image.open("data/examples/templates/image_mask_2.jpg"),
    }],
)
image.save("image_Inpaint_2.jpg")
```

## Training Code

After installing DiffSynth-Studio, use the following script to start training. Masks do not need to be provided: the bundled `TrainDataProcessor` in `model.py` synthesizes a random rectangular mask for each training image on the fly. For more information, please refer to the [DiffSynth-Studio documentation](https://diffsynth-studio-doc.readthedocs.io/zh-cn/latest/).

```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-Inpaint/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Inpaint \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Inpaint/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "DiffSynth-Studio/Template-KleinBase4B-Inpaint:" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Inpaint_full" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
assets/cat_Inpaint_1.jpg ADDED

Git LFS Details

  • SHA256: 61f144a198a21d0e552a6c259a2d8376e4edd9129b63c0b8d07772b2a3f9ffb8
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
assets/cat_Inpaint_2.jpg ADDED

Git LFS Details

  • SHA256: 8a75e21783a96c9738d6a68f69390ca62fd19711e084e887a105c4604efcf2dc
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
assets/cat_base.jpg ADDED

Git LFS Details

  • SHA256: f113000383ad9e079689cd5be415aea94e80ae5ef597ec062e5ad94f4c95a63a
  • Pointer size: 131 Bytes
  • Size of remote file: 129 kB
assets/cat_mask_1.jpg ADDED
assets/cat_mask_2.jpg ADDED
assets/girl_Inpaint_1.jpg ADDED
assets/girl_Inpaint_2.jpg ADDED
assets/girl_base.jpg ADDED
assets/girl_mask_1.jpg ADDED
assets/girl_mask_2.jpg ADDED
assets/room1.jpg ADDED
assets/room2.jpg ADDED
assets/room_Inpaint_1.jpg ADDED
assets/room_Inpaint_2.jpg ADDED
assets/room_base.jpg ADDED
configuration.json ADDED
@@ -0,0 +1 @@
{"framework":"Pytorch","task":"text-to-image-synthesis"}
model.py ADDED
@@ -0,0 +1,378 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import math
import torch.nn as nn
from einops import rearrange
from diffsynth.core.attention import attention_forward
from diffsynth.core.gradient import gradient_checkpoint_forward
from diffsynth.models.flux2_dit import apply_rotary_emb, Flux2PosEmbed
from diffsynth.models.general_modules import get_timestep_embedding
from PIL import Image
import numpy as np


class AdaLayerNormContinuous(nn.Module):
    def __init__(self, dim_in, dim_out, eps=1e-6):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out * 2, bias=False)
        self.norm = nn.LayerNorm(dim_in, eps=eps, elementwise_affine=False, bias=False)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        scale, shift = self.linear(torch.nn.functional.silu(conditioning_embedding)).chunk(2, dim=1)
        x = self.norm(x) * (1 + scale) + shift
        return x


class Flux2FeedForward(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear_in = nn.Linear(dim, dim * 3 * 2, bias=False)
        self.linear_out = nn.Linear(dim * 3, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating: one projected half gates the other.
        x1, x2 = self.linear_in(x).chunk(2, dim=-1)
        x = torch.nn.functional.silu(x1) * x2
        x = self.linear_out(x)
        return x


class Flux2TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, eps=1e-6):
        super().__init__()
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_ff = Flux2FeedForward(dim)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_ff = Flux2FeedForward(dim)

        self.num_heads = num_heads
        self.img_to_qkv = torch.nn.Linear(dim, 3 * dim, bias=False)
        self.img_norm_q = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.img_norm_k = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.img_to_out = torch.nn.Linear(dim, dim, bias=False)
        self.txt_to_qkv = torch.nn.Linear(dim, 3 * dim, bias=False)
        self.txt_norm_q = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.txt_norm_k = torch.nn.RMSNorm(dim // num_heads, eps=eps)
        self.txt_to_out = torch.nn.Linear(dim, dim, bias=False)

    def attention(self, img: torch.Tensor, txt: torch.Tensor, image_rotary_emb: torch.Tensor, **kwargs):
        img_q, img_k, img_v = self.img_to_qkv(img).chunk(3, dim=-1)
        txt_q, txt_k, txt_v = self.txt_to_qkv(txt).chunk(3, dim=-1)
        img_q, img_k, img_v, txt_q, txt_k, txt_v = tuple(map(lambda x: x.unflatten(-1, (self.num_heads, -1)), (img_q, img_k, img_v, txt_q, txt_k, txt_v)))
        img_q = self.img_norm_q(img_q)
        img_k = self.img_norm_k(img_k)
        txt_q = self.txt_norm_q(txt_q)
        txt_k = self.txt_norm_k(txt_k)

        # Joint attention over the concatenated [text, image] sequence.
        q = torch.cat([txt_q, img_q], dim=1)
        k = torch.cat([txt_k, img_k], dim=1)
        v = torch.cat([txt_v, img_v], dim=1)
        q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
        k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)

        img = attention_forward(q, k, v, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s (n d)")
        txt, img = img.split_with_sizes([txt.shape[1], img.shape[1] - txt.shape[1]], dim=1)
        txt = self.txt_to_out(txt)
        img = self.img_to_out(img)
        return img, txt, (k, v)

    def forward(self, img, txt, temb_mod_params_img, temb_mod_params_txt, image_rotary_emb):
        (img_shift_msa, img_scale_msa, img_gate_msa), (img_shift_mlp, img_scale_mlp, img_gate_mlp) = temb_mod_params_img
        (txt_shift_msa, txt_scale_msa, txt_gate_msa), (txt_shift_mlp, txt_scale_mlp, txt_gate_mlp) = temb_mod_params_txt

        norm_img = (1 + img_scale_msa) * self.img_norm1(img) + img_shift_msa
        norm_txt = (1 + txt_scale_msa) * self.txt_norm1(txt) + txt_shift_msa
        img_attn_out, txt_attn_out, kv_cache = self.attention(norm_img, norm_txt, image_rotary_emb)

        img = img + img_gate_msa * img_attn_out
        norm_img = self.img_norm2(img) * (1 + img_scale_mlp) + img_shift_mlp
        img = img + img_gate_mlp * self.img_ff(norm_img)

        txt = txt + txt_gate_msa * txt_attn_out
        norm_txt = self.txt_norm2(txt) * (1 + txt_scale_mlp) + txt_shift_mlp
        txt = txt + txt_gate_mlp * self.txt_ff(norm_txt)
        return txt, img, kv_cache


class Flux2SingleTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.dim = dim
        self.num_heads = num_heads
        self.norm_q = torch.nn.RMSNorm(dim // num_heads, eps=eps, elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim // num_heads, eps=eps, elementwise_affine=True)
        self.to_qkv_mlp_proj = torch.nn.Linear(dim, dim * 3 + dim * 3 * 2, bias=False)
        self.to_out = torch.nn.Linear(dim + dim * 3, dim, bias=False)

    def attention(self, x: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, **kwargs):
        # Fused projection: QKV and the parallel MLP branch share one matmul.
        x = self.to_qkv_mlp_proj(x)
        qkv, mlp_x = torch.split(x, [3 * self.dim, self.dim * 3 * 2], dim=-1)
        q, k, v = tuple(map(lambda x: x.unflatten(-1, (self.num_heads, -1)), qkv.chunk(3, dim=-1)))

        q = self.norm_q(q)
        k = self.norm_k(k)
        q = apply_rotary_emb(q, image_rotary_emb, sequence_dim=1)
        k = apply_rotary_emb(k, image_rotary_emb, sequence_dim=1)
        x = attention_forward(q, k, v, q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s (n d)")

        x1, x2 = mlp_x.chunk(2, dim=-1)
        x = torch.cat([x, torch.nn.functional.silu(x1) * x2], dim=-1)
        x = self.to_out(x)
        return x, (k, v)

    def forward(self, x, temb_mod_params, image_rotary_emb):
        mod_shift, mod_scale, mod_gate = temb_mod_params
        norm_x = (1 + mod_scale) * self.norm(x) + mod_shift
        attn_output, kv_cache = self.attention(x=norm_x, image_rotary_emb=image_rotary_emb)
        x = x + mod_gate * attn_output
        return x, kv_cache


class Flux2TimestepGuidanceEmbeddings(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.timestep_embedder = torch.nn.Sequential(nn.Linear(dim_in, dim_out, bias=False), nn.SiLU(), nn.Linear(dim_out, dim_out, bias=False))

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        timesteps_proj = get_timestep_embedding(timestep, self.dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))
        return timesteps_emb


class Flux2Modulation(nn.Module):
    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
        super().__init__()
        self.mod_param_sets = mod_param_sets
        self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)

    def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        # Produce mod_param_sets groups of (shift, scale, gate) from the timestep embedding.
        mod = torch.nn.functional.silu(temb)
        mod = self.linear(mod)
        mod = mod.unsqueeze(1)
        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))


class Flux2DiTVariantModel(torch.nn.Module):
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 128,
        out_channels: Optional[int] = None,
        num_layers: int = 5,
        num_single_layers: int = 20,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 7680,
        timestep_guidance_channels: int = 256,
        axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
        rope_theta: int = 2000,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # 1. Sinusoidal positional embedding for RoPE on image and text tokens
        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)

        # 2. Combined timestep + guidance embedding
        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
            dim_in=timestep_guidance_channels,
            dim_out=self.inner_dim,
        )

        # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
        # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)

        # 4. Input projections
        self.img_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
        self.txt_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)

        # 5. Double Stream Transformer Blocks
        self.transformer_blocks = nn.ModuleList([Flux2TransformerBlock(dim=self.inner_dim, num_heads=num_attention_heads) for _ in range(num_layers)])

        # 6. Single Stream Transformer Blocks
        self.single_transformer_blocks = nn.ModuleList([Flux2SingleTransformerBlock(dim=self.inner_dim, num_heads=num_attention_heads) for _ in range(num_single_layers)])

        # 7. Output layers
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

    def prepare_static_parameters(self, img, txt):
        # The template always runs at timestep 0; position ids give each
        # conditioning latent its own offset (10, 20, ...) on the first axis.
        timestep = torch.zeros((1,), dtype=txt.dtype, device=txt.device)
        img_ids = []
        for latent_id, latent in enumerate(img):
            _, _, height, width = latent.shape
            x_ids = torch.cartesian_prod(torch.tensor([(latent_id + 1) * 10]), torch.arange(height), torch.arange(width), torch.arange(1))
            img_ids.append(x_ids)
        img_ids = torch.cat(img_ids, dim=0).to(txt.device)
        txt_ids = torch.cartesian_prod(torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(txt.shape[1])).to(txt.device)
        return timestep, img_ids, txt_ids

    def patchify(self, img):
        # Flatten each latent's spatial grid into a token sequence and
        # concatenate all conditioning latents along the sequence axis.
        img_ = []
        for latent in img:
            latent = rearrange(latent, "B C H W -> B (H W) C")
            img_.append(latent)
        img_ = torch.concat(img_, dim=1)
        return img_

    def process_image(self, image, mask):
        # Blank out the masked region of the source image: pixels where the
        # (grayscale) mask is bright (> 127) are the area to be regenerated.
        mask = mask.convert("RGB").resize(image.size)
        mask = np.array(mask).mean(axis=-1)
        image = np.array(image)
        image[mask > 127] = 0
        return Image.fromarray(image), Image.fromarray(mask.astype(np.uint8)).convert("RGB")

    @torch.no_grad()
    def process_inputs(
        self,
        pipe,
        image,
        mask,
        prompt="Complete the content in the annotated region of the image.",
        force_inpaint=False,
        **kwargs
    ):
        # Encode the masked image and the mask with the VAE and embed the
        # template prompt; these become the context from which the KV cache is built.
        masked_image, mask = self.process_image(image, mask)
        images = [masked_image, mask]
        pipe.load_models_to_device(["vae"])
        kv_cache_input_latents = [pipe.vae.encode(pipe.preprocess_image(image)) for image in images]
        prompt_emb_unit = [unit for unit in pipe.units if unit.__class__.__name__ == "Flux2Unit_Qwen3PromptEmbedder"][0]
        kv_cache_prompt_emb = prompt_emb_unit.process(pipe, prompt)["prompt_embeds"]
        pipe.load_models_to_device([])
        return {
            "kv_cache_input_latents": kv_cache_input_latents,
            "kv_cache_prompt_emb": kv_cache_prompt_emb,
            "image": image,
            "mask": mask,
            "force_inpaint": force_inpaint,
        }

    def forward(
        self,
        kv_cache_input_latents,
        kv_cache_prompt_emb,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
        image=None,
        mask=None,
        force_inpaint=False,
        **kwargs,
    ):
        img = kv_cache_input_latents
        txt = kv_cache_prompt_emb
        num_txt_tokens = txt.shape[1]

        # 1. Calculate timestep embedding and modulation parameters
        timestep, img_ids, txt_ids = self.prepare_static_parameters(img, txt)
        img = self.patchify(img)

        temb = self.time_guidance_embed(timestep)
        double_stream_mod_img = self.double_stream_modulation_img(temb)
        double_stream_mod_txt = self.double_stream_modulation_txt(temb)
        single_stream_mod = self.single_stream_modulation(temb)[0]

        # 2. Input projection for image (img) and conditioning text (txt)
        img = self.img_embedder(img)
        txt = self.txt_embedder(txt)

        # 3. Calculate RoPE embeddings from image and text tokens
        image_rotary_emb = self.pos_embed(img_ids)
        text_rotary_emb = self.pos_embed(txt_ids)
        concat_rotary_emb = (
            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
        )

        # 4. Double Stream Transformer Blocks
        kv_cache = {}
        for block_id, block in enumerate(self.transformer_blocks):
            txt, img, kv_cache_ = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                img=img,
                txt=txt,
                temb_mod_params_img=double_stream_mod_img,
                temb_mod_params_txt=double_stream_mod_txt,
                image_rotary_emb=concat_rotary_emb,
            )
            kv_cache[f"double_{block_id}"] = kv_cache_
        # Concatenate text and image streams for single-block inference
        img = torch.cat([txt, img], dim=1)

        # 5. Single Stream Transformer Blocks
        for block_id, block in enumerate(self.single_transformer_blocks):
            img, kv_cache_ = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                x=img,
                temb_mod_params=single_stream_mod,
                image_rotary_emb=concat_rotary_emb,
            )
            kv_cache[f"single_{block_id}"] = kv_cache_
        # # Remove text tokens from concatenated stream
        # img = img[:, num_txt_tokens:, ...]

        # # 6. Output layers
        # img = self.norm_out(img, temb)
        # output = self.proj_out(img)

        # The template's product is the per-block KV cache; the denoising
        # pipeline consumes it as extra attention context.
        results = {"kv_cache": kv_cache}
        if force_inpaint:
            # Also return the original image and mask so the pipeline can
            # blend the generated region back into the untouched background.
            results.update({
                "input_image": image,
                "inpaint_mask": mask,
                "inpaint_blur_size": 1,
                "inpaint_blur_sigma": 1,
            })
        return results


class TrainDataProcessor:
    def __init__(self):
        from diffsynth.core import UnifiedDataset
        self.image_operator = UnifiedDataset.default_image_operator(
            base_path="",  # If your dataset contains relative paths, please specify the root path here.
            max_pixels=1024*1024,
            height_division_factor=16,
            width_division_factor=16,
        )

    def generate_bbox(self, height, width):
        # Sample a random box, at least 10 px per side, fully inside the image.
        h = torch.randint(10, height - 10, (1,)).item()
        w = torch.randint(10, width - 10, (1,)).item()
        x = torch.randint(0, height - h, (1,)).item()
        y = torch.randint(0, width - w, (1,)).item()
        return x, x + h, y, y + w

    def generate_mask(self, image):
        # Black out a random rectangle in the image and return the matching
        # white-on-black mask, mirroring the inference-time convention.
        image = np.array(image)
        height, width, _ = image.shape
        x, x_, y, y_ = self.generate_bbox(height, width)
        image[x: x_, y: y_] = 0

        mask = np.zeros_like(image)
        mask[x: x_, y: y_] = 255
        return Image.fromarray(image), Image.fromarray(mask)

    def __call__(self, image, **kwargs):
        image = self.image_operator(image)
        image, mask = self.generate_mask(image)
        return {
            "image": image,
            "mask": mask,
        }


TEMPLATE_MODEL = Flux2DiTVariantModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = TrainDataProcessor
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57ef6ed07ef3a159bb2c5424c12efef054b3250c2d25ba1ff1cf6960b6d4cb94
size 7751106784