# Training for FLUX

## Table of Contents
- [Training for FLUX](#training-for-flux)
  - [Table of Contents](#table-of-contents)
  - [Environment Setup](#environment-setup)
  - [Dataset Preparation](#dataset-preparation)
  - [Quick Start](#quick-start)
  - [Basic Training](#basic-training)
    - [Tasks from OminiControl](#tasks-from-ominicontrol)
    - [Creating Your Own Task](#creating-your-own-task)
    - [Training Configuration](#training-configuration)
      - [Batch Size](#batch-size)
      - [Optimizer](#optimizer)
      - [LoRA Configuration](#lora-configuration)
      - [Trainable Modules](#trainable-modules)
  - [Advanced Training](#advanced-training)
    - [Multi-condition](#multi-condition)
    - [Efficient Generation (OminiControl2)](#efficient-generation-ominicontrol2)
      - [Feature Reuse (KV-Cache)](#feature-reuse-kv-cache)
      - [Compact Encoding Representation](#compact-encoding-representation)
      - [Token Integration (for Fill task)](#token-integration-for-fill-task)
  - [Citation](#citation)

## Environment Setup

1. Create and activate a new conda environment:
   ```bash
   conda create -n omini python=3.10
   conda activate omini
   ```

2. Install required packages:
   ```bash
   pip install -r requirements.txt
   ```

## Dataset Preparation

1. Download the [Subjects200K](https://huggingface.co/datasets/Yuanshi/Subjects200K) dataset for subject-driven generation:
   ```bash
   bash train/script/data_download/data_download1.sh
   ```

2. Download the [text-to-image-2M](https://huggingface.co/datasets/jackyhate/text-to-image-2M) dataset for spatial alignment control tasks:
   ```bash
   bash train/script/data_download/data_download2.sh
   ```
   
   **Note:** By default, only a few files are downloaded. Edit `data_download2.sh` to download more data, and update the config file accordingly.

## Quick Start

Use these scripts to start training immediately:

1. **Subject-driven generation**:
   ```bash
   bash train/script/train_subject.sh
   ```

2. **Spatial control tasks** (Canny-to-image, colorization, depth map, etc.):
   ```bash
   bash train/script/train_spatial_alignment.sh
   ```

3. **Multi-condition training**:
   ```bash
   bash train/script/train_multi_condition.sh
   ```

4. **Feature reuse** (OminiControl2):
   ```bash
   bash train/script/train_feature_reuse.sh
   ```

5. **Compact token representation** (OminiControl2):
   ```bash
   bash train/script/train_compact_token_representation.sh
   ```

6. **Token integration** (OminiControl2):
   ```bash
   bash train/script/train_token_intergration.sh
   ```

## Basic Training

### Tasks from OminiControl
<a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ariXv-2411.15098-A42C25.svg" alt="arXiv"></a>

1. Subject-driven generation:
   ```bash
   bash train/script/train_subject.sh
   ```

2. Spatial control tasks (using canny-to-image as an example):
   ```bash
   bash train/script/train_spatial_alignment.sh
   ```

   <details>
   <summary>Supported tasks</summary>

   * Canny edge to image (`canny`)
   * Image colorization (`coloring`)
   * Image deblurring (`deblurring`)
   * Depth map to image (`depth`)
   * Image to depth map (`depth_pred`)
   * Image inpainting (`fill`)
   * Super resolution (`sr`)
   
   🌟 Change the `condition_type` parameter in the config file to switch between tasks.
   </details>

**Note**: Check the **script files** (`train/script/`) and **config files** (`train/configs/`) for WandB and GPU settings.

### Creating Your Own Task

You can create a custom task by building a new dataset and modifying the test code:

1. **Create a custom dataset:**
   Your custom dataset should follow the format of `Subject200KDataset` in `omini/train_flux/train_subject.py` (see the sketch after this list). Each sample should contain:

   - Image: the target image (`image`)
   - Text: description of the image (`description`)
   - Conditions: image conditions for generation
   - Position delta:
     - Use `position_delta = (0, 0)` to align the condition with the generated image
     - Use `position_delta = (0, -a)` to place the condition beside the generated image (where `a` = condition width in pixels / 16)

   > **Explanation:**  
   > The model places both the condition and generated image in a shared coordinate system. `position_delta` shifts the condition image in this space.
   > 
   > Each unit equals one patch (16 pixels). For a 512px-wide condition image (32 patches), `position_delta = (0, -32)` moves it fully to the left.
   > 
   > This controls whether conditions and generated images share space or appear side-by-side.

2. **Modify the test code:**
   Define `test_function()` in `train_custom.py`. Refer to the function in `train_subject.py` for examples. Make sure to keep the `position_delta` parameter consistent with your dataset.

### Training Configuration

#### Batch Size
We recommend a batch size of 1 for stable training. To simulate an effective batch size of n, set `accumulate_grad_batches` to n.
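
For example, accumulating gradients over 4 steps (the exact placement of these keys depends on your config file; an assumed top-level `train` section is shown here):
```yaml
train:
  batch_size: 1
  accumulate_grad_batches: 4  # gradients accumulate over 4 steps -> effective batch size 4
```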

#### Optimizer
The default optimizer is `Prodigy`. To use `AdamW` instead, modify the config file:
```yaml
optimizer:
  type: AdamW
  lr: 1e-4
  weight_decay: 0.001
```

#### LoRA Configuration
The default LoRA rank is 4. Increase it for complex tasks (keep `r` and `lora_alpha` equal):
```yaml
lora_config:
  r: 128
  lora_alpha: 128
```

#### Trainable Modules
The `target_modules` parameter uses regex patterns to specify which modules to train. See [PEFT Documentation](https://huggingface.co/docs/peft/package_reference/lora) for details.

Default configuration trains all modules affecting image tokens:
```yaml
target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_out)"
```

To train only attention components (`to_q`, `to_k`, `to_v`), use:
```yaml
target_modules: "(.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_v)"
```
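
To sanity-check a pattern before training, note that PEFT matches module names with `re.fullmatch` when `target_modules` is a string. A small standalone check (the sample names below are illustrative FLUX module paths, not read from a loaded model):

```python
import re

# Verify which module names a target_modules pattern selects.
pattern = re.compile(r".*(?<!single_)transformer_blocks\.[0-9]+\.attn\.to_q")

sample_names = [
    "transformer_blocks.0.attn.to_q",         # double-stream block: should match
    "single_transformer_blocks.0.attn.to_q",  # excluded by the (?<!single_) lookbehind
    "transformer_blocks.3.ff.net.2",          # different submodule: no match
]
for name in sample_names:
    print(f"{name}: {bool(pattern.fullmatch(name))}")
```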

## Advanced Training

### Multi-condition
A basic multi-condition implementation is available in `train_multi_condition.py`:
```bash
bash train/script/train_multi_condition.sh
```

### Efficient Generation (OminiControl2)
<a href="https://arxiv.org/abs/2503.08280"><img src="https://img.shields.io/badge/ariXv-2503.08280-A42C25.svg" alt="arXiv"></a>

[OminiControl2](https://arxiv.org/abs/2503.08280) introduces techniques to improve generation efficiency:

#### Feature Reuse (KV-Cache)
1. Enable `independent_condition` in the config file during training:
   ```yaml
   model:
     independent_condition: true
   ```

2. During inference, set `kv_cache = True` in the `generate` function to speed up generation.

*Example:*
```bash
bash train/script/train_feature_reuse.sh
```

**Note:** Feature reuse speeds up generation, but may slightly reduce output quality and increase training time.

#### Compact Encoding Representation
Reduce the condition image resolution and use `position_scale` to align it with the output image. For example, a 256×256 condition with `position_scale: 2` spans a 512×512 output while using only a quarter of the condition tokens:

```diff
train:
  dataset:
    condition_size: 
-     - 512
-     - 512
+     - 256
+     - 256
+   position_scale: 2
    target_size: 
      - 512
      - 512
```

*Example:*
```bash
bash train/script/train_compact_token_representation.sh
```

#### Token Integration (for Fill task)
Further reduce tokens by merging condition and generation tokens into a unified sequence. (Refer to [the paper](https://arxiv.org/abs/2503.08280) for details.)

*Example:*
```bash
bash train/script/train_token_intergration.sh
```

## Citation

If you find this code useful, please cite our papers:

```bibtex
@article{tan2024ominicontrol,
  title={OminiControl: Minimal and Universal Control for Diffusion Transformer},
  author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao},
  journal={arXiv preprint arXiv:2411.15098},
  year={2024}
}

@article{tan2025ominicontrol2,
  title={OminiControl2: Efficient Conditioning for Diffusion Transformers},
  author={Tan, Zhenxiong and Xue, Qiaochu and Yang, Xingyi and Liu, Songhua and Wang, Xinchao},
  journal={arXiv preprint arXiv:2503.08280},
  year={2025}
}
```