diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..13f0b18ad88d9714f6ddf79a06560a0a4a0b8def 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/cartoon_boy.png filter=lfs diff=lfs merge=lfs -text
+assets/clock.jpg filter=lfs diff=lfs merge=lfs -text
+assets/monalisa.jpg filter=lfs diff=lfs merge=lfs -text
+assets/rc_car.jpg filter=lfs diff=lfs merge=lfs -text
+assets/room_corner.jpg filter=lfs diff=lfs merge=lfs -text
+assets/tshirt.jpg filter=lfs diff=lfs merge=lfs -text
+assets/vase_hq.jpg filter=lfs diff=lfs merge=lfs -text
+assets/demo/art1.png filter=lfs diff=lfs merge=lfs -text
+assets/demo/art2.png filter=lfs diff=lfs merge=lfs -text
+assets/demo/demo_this_is_omini_control.jpg filter=lfs diff=lfs merge=lfs -text
+assets/demo/dreambooth_res.jpg filter=lfs diff=lfs merge=lfs -text
+assets/demo/monalisa_omini.jpg filter=lfs diff=lfs merge=lfs -text
+assets/demo/scene_variation.jpg filter=lfs diff=lfs merge=lfs -text
+assets/demo/try_on.jpg filter=lfs diff=lfs merge=lfs -text
+assets/ominicontrol_art/DistractedBoyfriend.webp filter=lfs diff=lfs merge=lfs -text
+assets/ominicontrol_art/PulpFiction.jpg filter=lfs diff=lfs merge=lfs -text
+assets/ominicontrol_art/breakingbad.jpg filter=lfs diff=lfs merge=lfs -text
+assets/ominicontrol_art/oiiai.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2fdb2d7733471841f7ba4b7787da9a3da03e5916
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,229 @@
+wandb/*
+runs/*
+
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+# Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+# poetry.lock
+# poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+# pdm.lock
+# pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+# pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# Redis
+*.rdb
+*.aof
+*.pid
+
+# RabbitMQ
+mnesia/
+rabbitmq/
+rabbitmq-data/
+
+# ActiveMQ
+activemq-data/
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+# .idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# Streamlit
+.streamlit/secrets.toml
+
+
+# exps/
+# wandb/
+# *.ipynb
+# glue_exp/
+# logs_hyper/
+# # grid/
+# glue22_ex
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..7b2de994465d5d0c1f37fc6aecd065f8849ed2d8
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [2024] [Zhenxiong Tan]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0c9cd9b0159b7fb66d114665a2c3b25030a4e3b0
--- /dev/null
+++ b/README.md
@@ -0,0 +1,198 @@
+# OminiControl
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+> **OminiControl: Minimal and Universal Control for Diffusion Transformer**
+>
+> Zhenxiong Tan,
+> [Songhua Liu](http://121.37.94.87/),
+> [Xingyi Yang](https://adamdad.github.io/),
+> Qiaochu Xue,
+> and
+> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
+>
+> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
+>
+
+> **OminiControl2: Efficient Conditioning for Diffusion Transformers**
+>
+> Zhenxiong Tan,
+> Qiaochu Xue,
+> [Xingyi Yang](https://adamdad.github.io/),
+> [Songhua Liu](http://121.37.94.87/),
+> and
+> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
+>
+> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
+>
+
+
+
+## Features
+
+OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux).
+
+* **Universal Control π**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation).
+
+* **Minimal Design π**: Injects control signals while preserving original model structure. Only introduces 0.1% additional parameters to the base model.
+
+## News
+- **2025-05-12**: βοΈ The code of [OminiControl2](https://arxiv.org/abs/2503.08280) is released. It introduces a new efficient conditioning method for diffusion transformers. (Check out the training code [here](./train)).
+- **2025-05-12**: Support custom style LoRA. (Check out the [example](./examples/combine_with_style_lora.ipynb)).
+- **2025-04-09**: βοΈ [OminiControl Art](https://huggingface.co/spaces/Yuanshi/OminiControl_Art) is released. It can stylize any image with an artistic style. (Check out the [demo](https://huggingface.co/spaces/Yuanshi/OminiControl_Art) and [inference examples](./examples/ominicontrol_art.ipynb)).
+- **2024-12-26**: Training code is released. Now you can create your own OminiControl model by customizing any control tasks (3D, multi-view, pose-guided, try-on, etc.) with the FLUX model. Check the [training folder](./train) for more details.
+
+## Quick Start
+### Setup (Optional)
+1. **Environment setup**
+```bash
+conda create -n omini python=3.12
+conda activate omini
+```
+2. **Requirements installation**
+```bash
+pip install -r requirements.txt
+```
+### Usage example
+1. Subject-driven generation: `examples/subject.ipynb`
+2. In-painting: `examples/inpainting.ipynb`
+3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb`
+
+
+### Guidelines for subject-driven generation
+1. Input images are automatically center-cropped and resized to 512x512 resolution.
+2. When writing prompts, refer to the subject using phrases like `this item`, `the object`, or `it`. e.g.
+ 1. *A close up view of this item. It is placed on a wooden table.*
+ 2. *A young lady is wearing this shirt.*
+3. The model primarily works with objects rather than human subjects currently, due to the absence of human data in training.
+
+## Generated samples
+### Subject-driven generation
+
+
+**Demos** (Left: condition image; Right: generated image)
+
+
+
+
+Text Prompts
+
+- Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!.'*
+- Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.*
+- Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.*
+- Prompt4: *"On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple."*
+
+
+More results
+
+* Try on:
+
+* Scene variations:
+
+* Dreambooth dataset:
+
+* Oye-cartoon finetune:
+
+

+

+
+
+
+### Spatially aligned control
+1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image)
+ - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.*
+
+
+ - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.*
+
+
+2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring)
+
+
+ Click to show
+
+
+ Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.*
+
+
+### Stylize images
+
+
+
+
+
+
+
+
+## Models
+
+**Subject-driven control:**
+| Model | Base model | Description | Resolution |
+| ------------------------------------------------------------------------------------------------ | -------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------ |
+| [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) |
+| [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) |
+| [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. | (1024, 1024) |
+| [`oye-cartoon`](https://huggingface.co/saquiboye/oye-cartoon) | FLUX.1-dev | The model has been fine-tuned on [oye-cartoon](https://huggingface.co/datasets/saquiboye/oye-cartoon) dataset by [@saquib764](https://github.com/Saquib764) | (512, 512) |
+
+**Spatial aligned control:**
+| Model | Base model | Description | Resolution |
+| --------------------------------------------------------------------------------------------------------- | ---------- | -------------------------------------------------------------------------- | ------------ |
+| [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `` | FLUX.1     | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512)   |
+
+## Community Extensions
+- [ComfyUI-Diffusers-OminiControl](https://github.com/Macoron/ComfyUI-Diffusers-OminiControl) - ComfyUI integration by [@Macoron](https://github.com/Macoron)
+- [ComfyUI_RH_OminiControl](https://github.com/HM-RunningHub/ComfyUI_RH_OminiControl) - ComfyUI integration by [@HM-RunningHub](https://github.com/HM-RunningHub)
+
+## Limitations
+1. The model's subject-driven generation primarily works with objects rather than human subjects due to the absence of human data in training.
+2. The subject-driven generation model may not work well with `FLUX.1-dev`.
+3. The released model only supports the resolution of 512x512.
+
+## Training
+Training instructions can be found in this [folder](./train).
+
+
+## To-do
+- [x] Release the training code.
+- [x] Release the model for higher resolution (1024x1024).
+
+## Acknowledgment
+We would like to acknowledge that the computational work involved in this research work is partially supported by NUS IT's Research Computing group using grant numbers NUSREC-HPC-00001.
+
+## Citation
+```
+@inproceedings{tan2025ominicontrol,
+ title={OminiControl: Minimal and Universal Control for Diffusion Transformer},
+ author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao},
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+ year={2025}
+}
+
+@article{tan2025ominicontrol2,
+ title={OminiControl2: Efficient Conditioning for Diffusion Transformers},
+ author={Tan, Zhenxiong and Xue, Qiaochu and Yang, Xingyi and Liu, Songhua and Wang, Xinchao},
+ journal={arXiv preprint arXiv:2503.08280},
+ year={2025}
+}
+```
diff --git a/ablation_qkv.py b/ablation_qkv.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aaea3421b52a153f8657130bd76d201c71f4c0f
--- /dev/null
+++ b/ablation_qkv.py
@@ -0,0 +1,170 @@
+import torch
+
+from diffusers.pipelines import FluxPipeline
+from omini.pipeline.flux_omini_ablate_qkv import Condition, generate, seed_everything, convert_to_condition
+from omini.rotation import RotationConfig, RotationTuner
+from PIL import Image
+
+
+def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False):
+ """
+ Load rotation adapter weights.
+
+ Args:
+ path: Directory containing the saved adapter weights
+ adapter_name: Name of the adapter to load
+ strict: Whether to strictly match all keys
+ """
+ from safetensors.torch import load_file
+ import os
+ import yaml
+
+ device = transformer.device
+ print(f"device for loading: {device}")
+
+ # Try to load safetensors first, then fallback to .pth
+ safetensors_path = os.path.join(path, f"{adapter_name}.safetensors")
+ pth_path = os.path.join(path, f"{adapter_name}.pth")
+
+ if os.path.exists(safetensors_path):
+ state_dict = load_file(safetensors_path)
+ print(f"Loaded rotation adapter from {safetensors_path}")
+ elif os.path.exists(pth_path):
+ state_dict = torch.load(pth_path, map_location=device)
+ print(f"Loaded rotation adapter from {pth_path}")
+ else:
+ raise FileNotFoundError(
+ f"No adapter weights found for '{adapter_name}' in {path}\n"
+ f"Looking for: {safetensors_path} or {pth_path}"
+ )
+
+ # # Get the device and dtype of the transformer
+ transformer_device = next(transformer.parameters()).device
+ transformer_dtype = next(transformer.parameters()).dtype
+
+
+
+ state_dict_with_adapter = {}
+ for k, v in state_dict.items():
+ # Reconstruct the full key with adapter name
+ new_key = k.replace(".rotation.", f".rotation.{adapter_name}.")
+ if "_adapter_config" in new_key:
+ print(f"adapter_config key: {new_key}")
+
+
+ # Move to target device and dtype
+ # Check if this parameter should keep its original dtype (e.g., indices, masks)
+ if v.dtype in [torch.long, torch.int, torch.int32, torch.int64, torch.bool]:
+ # Keep integer/boolean dtypes, only move device
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device)
+ else:
+ # Convert floating point tensors to target dtype and device
+ state_dict_with_adapter[new_key] = v.to(device=transformer_device, dtype=transformer_dtype)
+
+ # Add adapter name back to keys (reverse of what we did in save)
+ state_dict_with_adapter = {
+ k.replace(".rotation.", f".rotation.{adapter_name}."): v
+ for k, v in state_dict.items()
+ }
+
+
+ # Load into the model
+ missing, unexpected = transformer.load_state_dict(
+ state_dict_with_adapter,
+ strict=strict
+ )
+
+ if missing:
+ print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
+ if unexpected:
+ print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
+
+ # Load config if available
+ config_path = os.path.join(path, f"{adapter_name}_config.yaml")
+ if os.path.exists(config_path):
+ with open(config_path, 'r') as f:
+ config = yaml.safe_load(f)
+ print(f"Loaded config: {config}")
+
+ total_params = sum(p.numel() for p in state_dict.values())
+ print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")
+
+ return state_dict
+
+
+# prepare input image and prompt
+image = Image.open("assets/coffee.png").convert("RGB")
+
+w, h, min_dim = image.size + (min(image.size),)
+image = image.crop(
+ ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)
+).resize((512, 512))
+
+prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table."
+
+canny_image = convert_to_condition("canny", image)
+condition = Condition(canny_image, "canny")
+
+seed_everything()
+
+
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+)
+
+
+# add adapter to the transformer
+transformer = pipe.transformer
+
+adapter_name = "default"
+transformer._hf_peft_config_loaded = True
+
+rotation_adapter_config = {
+ "r": 4,
+ "num_rotations": 4,
+ "target_modules": "(.*x_embedder|.*(? 5 else ''}")
+ if unexpected:
+ print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
+
+ # Load config if available
+ config_path = os.path.join(path, f"{adapter_name}_config.yaml")
+ if os.path.exists(config_path):
+ with open(config_path, 'r') as f:
+ config = yaml.safe_load(f)
+ print(f"Loaded config: {config}")
+
+ total_params = sum(p.numel() for p in state_dict.values())
+ print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")
+
+ return state_dict
+
+
+# prepare input image and prompt
+image = Image.open("assets/coffee.png").convert("RGB")
+
+w, h, min_dim = image.size + (min(image.size),)
+image = image.crop(
+ ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)
+).resize((512, 512))
+
+prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table."
+
+canny_image = convert_to_condition("canny", image)
+condition = Condition(canny_image, "canny")
+
+seed_everything()
+
+
+
+for i in range(40, 60):
+ pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ )
+
+
+ # add adapter to the transformer
+ transformer = pipe.transformer
+
+ adapter_name = "default"
+ transformer._hf_peft_config_loaded = True
+
+ rotation_adapter_config = {
+ "r": 4,
+ "num_rotations": 4,
+ "target_modules": "(.*x_embedder|.*(? 5 else ''}")
+ if unexpected:
+ print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
+
+ # Load config if available
+ config_path = os.path.join(path, f"{adapter_name}_config.yaml")
+ if os.path.exists(config_path):
+ with open(config_path, 'r') as f:
+ config = yaml.safe_load(f)
+ print(f"Loaded config: {config}")
+
+ total_params = sum(p.numel() for p in state_dict.values())
+ print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")
+
+ return state_dict
+
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Evaluate OminiControl on COCO dataset")
+ parser.add_argument("--start_index", type=int, default=0, help="Starting index for evaluation")
+ parser.add_argument("--num_images", type=int, default=500, help="Number of images to evaluate")
+ parser.add_argument("--condition_type", type=str, default="deblurring", help="Type of condition (e.g., 'deblurring', 'canny', 'depth')")
+ parser.add_argument("--adapter_path", type=str, default="runs/20251111-212406-deblurring/ckpt/25000", help="Path to the adapter checkpoint")
+ args = parser.parse_args()
+
+ START_INDEX = args.start_index
+ NUM_IMAGES = args.num_images
+
+ # Path to your captions file (change if needed)
+ CAPTION_FILE = "/home/work/koopman/oft/data/coco/annotations/captions_val2017.json"
+ IMAGE_DIR = "/home/work/koopman/oft/data/coco/images/val2017/"
+ CONDITION_TYPE = args.condition_type
+ SAVE_ROOT_DIR = f"./coco/results_{CONDITION_TYPE}_1000/"
+ ADAPTER_PATH = args.adapter_path
+
+ # Load your Flux pipeline
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float32) # Replace with your model path
+
+ # add adapter to the transformer
+ transformer = pipe.transformer
+
+ adapter_name = "default"
+ transformer._hf_peft_config_loaded = True
+
+ # make sure this is the same with your config.yaml used in training
+ rotation_adapter_config = {
+ "r": 1,
+ "num_rotations": 8,
+ "target_modules": "(.*x_embedder|.*(? 5 else ''}")
+ if unexpected:
+ print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
+
+ # Load config if available
+ config_path = os.path.join(path, f"{adapter_name}_config.yaml")
+ if os.path.exists(config_path):
+ with open(config_path, 'r') as f:
+ config = yaml.safe_load(f)
+ print(f"Loaded config: {config}")
+
+ total_params = sum(p.numel() for p in state_dict.values())
+ print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")
+
+ return state_dict
+
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Evaluate OminiControl on COCO dataset")
+ parser.add_argument("--start_index", type=int, default=0, help="Starting index for evaluation")
+ parser.add_argument("--num_images", type=int, default=500, help="Number of images to evaluate")
+ parser.add_argument("--condition_type", type=str, default="deblurring", help="Type of condition (e.g., 'deblurring', 'canny', 'depth')")
+ args = parser.parse_args()
+
+ START_INDEX = args.start_index
+ NUM_IMAGES = args.num_images
+
+ # Path to your captions file (change if needed)
+ CAPTION_FILE = "/home/work/koopman/oft/data/coco/annotations/captions_val2017.json"
+ IMAGE_DIR = "/home/work/koopman/oft/data/coco/images/val2017/"
+ CONDITION_TYPE = args.condition_type
+ SAVE_ROOT_DIR = f"./coco_baseline/results_{CONDITION_TYPE}_1000/"
+
+ # Load your Flux pipeline
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16) # Replace with your model path
+
+ ### FOR OMINI
+
+ pipe.load_lora_weights(
+ "Yuanshi/OminiControl",
+ weight_name=f"experimental/{CONDITION_TYPE}.safetensors",
+ adapter_name=CONDITION_TYPE,
+ )
+ pipe.fuse_lora(lora_scale=1.0)
+ pipe.unload_lora_weights()
+
+ # pipe.set_adapters([CONDITION_TYPE])
+ pipe = pipe.to("cuda")
+
+
+ # Evaluate on COCO
+ evaluate(
+ pipe,
+ condition_type=CONDITION_TYPE,
+ caption_file=CAPTION_FILE,
+ image_dir=IMAGE_DIR,
+ save_root_dir=SAVE_ROOT_DIR,
+ num_images=NUM_IMAGES,
+ start_index=START_INDEX,
+ )
\ No newline at end of file
diff --git a/evaluation_subject_driven.py b/evaluation_subject_driven.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eb245996f9b87188123218777221d0f9978aa60
--- /dev/null
+++ b/evaluation_subject_driven.py
@@ -0,0 +1,362 @@
+import openai
+import base64
+from pathlib import Path
+import random
+import os
+
+
+
+# Rubric prompts for GPT-4o-based scoring of subject-driven generations.
+# Each prompt asks for a 1-5 rating in a fixed "Score: / Reasoning:" layout so
+# that parse_score() can extract the integer with a regex. The "modification"
+# entry is a template: it must be .format(prompt=...)-ed before use.
+evaluation_prompts = {
+    "identity": """
+    Compare the original subject image with the generated image.
+    Rate on a scale of 1-5 how well the essential identifying features
+    are preserved (logos, brand marks, distinctive patterns).
+    Score: [1-5]
+    Reasoning: [explanation]
+    """,
+
+    "material": """
+    Evaluate the material quality and surface characteristics.
+    Rate on a scale of 1-5 how accurately materials are represented
+    (textures, reflections, surface properties).
+    Score: [1-5]
+    Reasoning: [explanation]
+    """,
+
+    "color": """
+    Assess color fidelity in regions NOT specified for modification.
+    Rate on a scale of 1-5 how consistent colors remain.
+    Score: [1-5]
+    Reasoning: [explanation]
+    """,
+
+    "appearance": """
+    Evaluate the overall realism and coherence of the generated image.
+    Rate on a scale of 1-5 how realistic and natural it appears.
+    Score: [1-5]
+    Reasoning: [explanation]
+    """,
+
+    "modification": """
+    Given the text prompt: "{prompt}"
+    Rate on a scale of 1-5 how well the specified changes are executed.
+    Score: [1-5]
+    Reasoning: [explanation]
+    """
+}
+
+
+def encode_image(image_path):
+ with open(image_path, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode('utf-8')
+
+def evaluate_subject_driven_generation(
+ original_image_path,
+ generated_image_path,
+ text_prompt,
+ client
+):
+ """
+ Evaluate a subject-driven generation using GPT-4o vision
+ """
+
+ # Encode images
+ original_img = encode_image(original_image_path)
+ generated_img = encode_image(generated_image_path)
+
+ results = {}
+
+ # 1. Identity Preservation
+ response = client.chat.completions.create(
+ model="gpt-4o",
+ messages=[{
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Original subject image:"},
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{original_img}"}},
+ {"type": "text", "text": "Generated image:"},
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
+ {"type": "text", "text": evaluation_prompts["identity"]}
+ ]
+ }],
+ max_tokens=300
+ )
+ results['identity'] = parse_score(response.choices[0].message.content)
+
+ # 2. Material Quality
+ response = client.chat.completions.create(
+ model="gpt-4o",
+ messages=[{
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Evaluate this generated image:"},
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
+ {"type": "text", "text": evaluation_prompts["material"]}
+ ]
+ }],
+ max_tokens=300
+ )
+ results['material'] = parse_score(response.choices[0].message.content)
+
+ # 3. Color Fidelity
+ response = client.chat.completions.create(
+ model="gpt-4o",
+ messages=[{
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Original:"},
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{original_img}"}},
+ {"type": "text", "text": "Generated:"},
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
+ {"type": "text", "text": evaluation_prompts["color"]}
+ ]
+ }],
+ max_tokens=300
+ )
+ results['color'] = parse_score(response.choices[0].message.content)
+
+ # 4. Natural Appearance
+ response = client.chat.completions.create(
+ model="gpt-4o",
+ messages=[{
+ "role": "user",
+ "content": [
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
+ {"type": "text", "text": evaluation_prompts["appearance"]}
+ ]
+ }],
+ max_tokens=300
+ )
+ results['appearance'] = parse_score(response.choices[0].message.content)
+
+ # 5. Modification Accuracy
+ response = client.chat.completions.create(
+ model="gpt-4o",
+ messages=[{
+ "role": "user",
+ "content": [
+ {"type": "text", "text": f"Text prompt: {text_prompt}"},
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}},
+ {"type": "text", "text": evaluation_prompts["modification"].format(prompt=text_prompt)}
+ ]
+ }],
+ max_tokens=300
+ )
+ results['modification'] = parse_score(response.choices[0].message.content)
+
+ return results
+
+def parse_score(response_text):
+ """Extract score from GPT-4o response"""
+ # Look for "Score: X" pattern
+ import re
+ match = re.search(r'Score:\s*(\d+)', response_text)
+ if match:
+ return int(match.group(1))
+ return None
+
+subject_names = [
+ "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
+ "candle", "cat", "cat2", "clock", "colorful_sneaker",
+ "dog", "dog2", "dog3", "dog5", "dog6",
+ "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
+ "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
+ "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie"
+]
+
+
+def get_prompt(subject_id, prompt_id):
+
+ # subject in 0|1|2|3|4|5|8|9|17|18|19|20|21|22|23|24|25|26|27|28|29
+ if subject_id in [0,1,2,3,4,5,8,9,17,18,19,20,21,22,23,24,25,26,27,28,29]:
+ subject_name = subject_names[subject_id]
+ prompts = [
+ f"a {subject_name} in the jungle",
+ f"a {subject_name} in the snow",
+ f"a {subject_name} on the beach",
+ f"a {subject_name} on a cobblestone street",
+ f"a {subject_name} on top of pink fabric",
+ f"a {subject_name} on top of a wooden floor",
+ f"a {subject_name} with a city in the background",
+ f"a {subject_name} with a mountain in the background",
+ f"a {subject_name} with a blue house in the background",
+ f"a {subject_name} on top of a purple rug in a forest",
+ f"a {subject_name} with a wheat field in the background",
+ f"a {subject_name} with a tree and autumn leaves in the background",
+ f"a {subject_name} with the Eiffel Tower in the background",
+ f"a {subject_name} floating on top of water",
+ f"a {subject_name} floating in an ocean of milk",
+ f"a {subject_name} on top of green grass with sunflowers around it",
+ f"a {subject_name} on top of a mirror",
+ f"a {subject_name} on top of the sidewalk in a crowded street",
+ f"a {subject_name} on top of a dirt road",
+ f"a {subject_name} on top of a white rug",
+ f"a red {subject_name}",
+ f"a purple {subject_name}",
+ f"a shiny {subject_name}",
+ f"a wet {subject_name}",
+ f"a cube shaped {subject_name}"
+ ]
+
+ else:
+ prompts = [
+ f"a {subject_name} in the jungle",
+ f"a {subject_name} in the snow",
+ f"a {subject_name} on the beach",
+ f"a {subject_name} on a cobblestone street",
+ f"a {subject_name} on top of pink fabric",
+ f"a {subject_name} on top of a wooden floor",
+ f"a {subject_name} with a city in the background",
+ f"a {subject_name} with a mountain in the background",
+ f"a {subject_name} with a blue house in the background",
+ f"a {subject_name} on top of a purple rug in a forest",
+ f"a {subject_name} wearing a red hat",
+ f"a {subject_name} wearing a santa hat",
+ f"a {subject_name} wearing a rainbow scarf",
+ f"a {subject_name} wearing a black top hat and a monocle",
+ f"a {subject_name} in a chef outfit",
+ f"a {subject_name} in a firefighter outfit",
+ f"a {subject_name} in a police outfit",
+ f"a {subject_name} wearing pink glasses",
+ f"a {subject_name} wearing a yellow shirt",
+ f"a {subject_name} in a purple wizard outfit",
+ f"a red {subject_name}",
+ f"a purple {subject_name}",
+ f"a shiny {subject_name}",
+ f"a wet {subject_name}",
+ f"a cube shaped {subject_name}"
+ ]
+
+ return prompts[prompt_id]
+
+
+
+
+
+def batch_evaluate_dreambooth(client, generate_fn, dataset_path, output_csv):
+ """
+ Evaluate 750 image pairs with 5 seeds each
+ """
+ import pandas as pd
+
+ results_list = []
+
+ # Iterate through DreamBooth dataset
+ for subject_id in range(30): # 30 subjects
+ subject_name = subject_names[subject_id]
+ for prompt_id in range(25): # 25 prompts per subject
+ original = f"{dataset_path}/{subject_name}"
+ # get a random file in this folder
+ original_files = list(Path(original).glob("*.png"))
+ if len(original_files) == 0:
+ raise ValueError(f"No original images found in {original}")
+
+ original = str(original_files[0])
+
+
+ for seed in range(5): # 5 different seeds
+ # take random file in the folder
+ prompt = get_prompt(subject_id, prompt_id)
+
+ # generated image path
+ generated_folder = f"{dataset_path}/{subject_name}/generated/"
+ os.makedirs(generated_folder, exist_ok=True)
+ generated = f"{generated_folder}/gen_seed{seed}_prompt{prompt_id}.png"
+
+ generate_fn(
+ prompt=prompt,
+ subject_image_path=original,
+ output_image_path=generated,
+ seed=seed
+ )
+
+ scores = evaluate_subject_driven_generation(
+ original, generated, prompt, client
+ )
+
+ results_list.append({
+ 'subject_id': subject_id,
+ 'subject_name': subject_name,
+ 'prompt_id': prompt_id,
+ 'seed': seed,
+ 'prompt': prompt,
+
+ **scores
+ })
+
+ # Save results
+ df = pd.DataFrame(results_list)
+ df.to_csv(output_csv, index=False)
+
+ # Calculate statistics
+ print(df.groupby('subject_id').mean())
+ print(f"\nOverall averages:")
+ print(df[['identity', 'material', 'color', 'appearance', 'modification']].mean())
+
+
+def evaluate_omini_control():
+
+ import torch
+ from diffusers.pipelines import FluxPipeline
+ from PIL import Image
+
+ from omini.pipeline.flux_omini import Condition, generate, seed_everything
+
+ pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+ )
+
+ pipe = pipe.to("cuda")
+ pipe.load_lora_weights(
+ "Yuanshi/OminiControl",
+ weight_name=f"omini/subject_512.safetensors",
+ adapter_name="subject",
+ )
+
+ def generate_fn(image_path, prompt, seed, output_path):
+ seed_everything(seed)
+
+ image = Image.open(image_path).convert("RGB").resize((512, 512))
+ condition = Condition.from_image(
+ image,
+ "subject", position_delta=(0, 32)
+ )
+
+ result_img = generate(
+ pipe,
+ prompt=prompt,
+ conditions=[condition],
+ ).images[0]
+
+ result_img.save(output_path)
+
+ return generate_fn
+
+
+if __name__ == "__main__":
+
+
+
+    # GPT-4o scoring authenticates via the OPENAI_API_KEY environment variable.
+    openai.api_key = os.getenv("OPENAI_API_KEY")
+    # client = openai.Client()
+
+    # generate_fn = evaluate_omini_control()
+
+    # dataset_path = "data/dreambooth"
+    # output_csv = "evaluation_subject_driven_omini_control.csv"
+
+    # batch_evaluate_dreambooth(
+    #     client,
+    #     generate_fn,
+    #     dataset_path,
+    #     output_csv
+    # )
+
+    # Smoke test: score a single DreamBooth image pair with GPT-4o.
+    result = evaluate_subject_driven_generation(
+        "data/dreambooth/backpack/00.jpg",
+        "data/dreambooth/backpack/01.jpg",
+        "a backpack in the jungle",
+        openai.Client()
+    )
+
+    print(result)
\ No newline at end of file
diff --git a/examples/combine_with_style_lora.ipynb b/examples/combine_with_style_lora.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..c212ca61014e42c6de001a70ae80d19ab5ce230b
--- /dev/null
+++ b/examples/combine_with_style_lora.ipynb
@@ -0,0 +1,235 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.chdir(\"..\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from diffusers.pipelines import FluxPipeline\n",
+ "from PIL import Image\n",
+ "\n",
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "pipe = FluxPipeline.from_pretrained(\n",
+ " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n",
+ ")\n",
+ "pipe = pipe.to(\"cuda\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe.unload_lora_weights()\n",
+ "\n",
+ "pipe.load_lora_weights(\n",
+ " \"Yuanshi/OminiControl\",\n",
+ " weight_name=f\"omini/subject_512.safetensors\",\n",
+ " adapter_name=\"subject\",\n",
+ ")\n",
+ "pipe.load_lora_weights(\"XLabs-AI/flux-RealismLora\", adapter_name=\"realism\")\n",
+ "\n",
+ "pipe.set_adapters([\"subject\", \"realism\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "# For this model, the position_delta is (0, 32).\n",
+ "# For more details of position_delta, please refer to:\n",
+ "# https://github.com/Yuanshi9815/OminiControl/issues/89#issuecomment-2827080344\n",
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
+ "\n",
+ "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n",
+ "\n",
+ "\n",
+ "seed_everything(0)\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ " main_adapter=\"realism\"\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
+ "\n",
+    "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard behind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n",
+ "\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ " main_adapter=\"realism\"\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(condition.condition, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
+ "\n",
+    "prompt = \"A film style shot. On the moon, this item drives across the moon surface. In the background, Earth looms large.\"\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ " main_adapter=\"realism\"\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(condition.condition, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
+ "\n",
+ "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ " main_adapter=\"realism\"\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(condition.condition, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "condition = Condition(image, \"subject\", position_delta=(0, 32))\n",
+ "\n",
+ "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ " num_inference_steps=8,\n",
+ " height=512,\n",
+ " width=512,\n",
+ " main_adapter=\"realism\"\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1024, 512))\n",
+ "concat_image.paste(condition.condition, (0, 0))\n",
+ "concat_image.paste(result_img, (512, 0))\n",
+ "concat_image"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.21"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/inpainting.ipynb b/examples/inpainting.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..21ac662cc6187188bcf799cb3919208198c41235
--- /dev/null
+++ b/examples/inpainting.ipynb
@@ -0,0 +1,135 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.chdir(\"..\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from diffusers.pipelines import FluxPipeline\n",
+ "from PIL import Image\n",
+ "\n",
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe = FluxPipeline.from_pretrained(\n",
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
+ ")\n",
+ "pipe = pipe.to(\"cuda\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe.load_lora_weights(\n",
+ " \"Yuanshi/OminiControl\",\n",
+ " weight_name=f\"experimental/fill.safetensors\",\n",
+ " adapter_name=\"fill\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "masked_image = image.copy()\n",
+ "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n",
+ "\n",
+ "condition = Condition(masked_image, \"fill\")\n",
+ "\n",
+ "seed_everything()\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/book.jpg\").convert(\"RGB\").resize((512, 512))\n",
+ "\n",
+ "w, h, min_dim = image.size + (min(image.size),)\n",
+ "image = image.crop(\n",
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
+ ").resize((512, 512))\n",
+ "\n",
+ "\n",
+ "masked_image = image.copy()\n",
+ "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n",
+ "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n",
+ "\n",
+ "condition = Condition(masked_image, \"fill\")\n",
+ "\n",
+ "seed_everything()\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=\"A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.\",\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/ominicontrol_art.ipynb b/examples/ominicontrol_art.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..70cfad066c997c38e30a6ab2dad5615ce5c3ded8
--- /dev/null
+++ b/examples/ominicontrol_art.ipynb
@@ -0,0 +1,218 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.chdir(\"..\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from diffusers.pipelines import FluxPipeline\n",
+ "from PIL import Image\n",
+ "\n",
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe = FluxPipeline.from_pretrained(\n",
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
+ ")\n",
+ "pipe = pipe.to(\"cuda\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe.unload_lora_weights()\n",
+ "\n",
+ "for style_type in [\"ghibli\", \"irasutoya\", \"simpsons\", \"snoopy\"]:\n",
+ " pipe.load_lora_weights(\n",
+ " \"Yuanshi/OminiControlArt\",\n",
+ " weight_name=f\"v0/{style_type}.safetensors\",\n",
+ " adapter_name=style_type,\n",
+ " )\n",
+ "\n",
+ "pipe.set_adapters([\"ghibli\", \"irasutoya\", \"simpsons\", \"snoopy\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def resize(img, factor=16):\n",
+ " # Resize the image to be divisible by the factor\n",
+ " w, h = img.size\n",
+ " new_w, new_h = w // factor * factor, h // factor * factor\n",
+ " padding_w, padding_h = (w - new_w) // 2, (h - new_h) // 2\n",
+ " img = img.crop((padding_w, padding_h, new_w + padding_w, new_h + padding_h))\n",
+ " return img\n",
+ "\n",
+ "\n",
+ "def bound_image(image):\n",
+ " factor = 512 / max(image.size)\n",
+ " image = resize(\n",
+ " image.resize(\n",
+ " (int(image.size[0] * factor), int(image.size[1] * factor)),\n",
+ " Image.LANCZOS,\n",
+ " )\n",
+ " )\n",
+ " delta = (0, -image.size[0] // 16)\n",
+ " return image, delta\n",
+ "\n",
+ "sizes = {\n",
+ " \"2:3\": (640, 960),\n",
+ " \"1:1\": (640, 640),\n",
+ " \"3:2\": (960, 640),\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/ominicontrol_art/DistractedBoyfriend.webp\").convert(\"RGB\")\n",
+ "image, delta = bound_image(image)\n",
+ "condition = Condition(image, \"ghibli\", position_delta=delta)\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "size = sizes[\"3:2\"]\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=\"\",\n",
+ " conditions=[condition],\n",
+ " max_sequence_length=32,\n",
+ " width=size[0],\n",
+ " height=size[1],\n",
+ " image_guidance_scale=1.5,\n",
+ ").images[0]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/ominicontrol_art/oiiai.png\").convert(\"RGB\")\n",
+ "image, delta = bound_image(image)\n",
+ "condition = Condition(image, \"irasutoya\", position_delta=delta)\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "size = sizes[\"1:1\"]\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=\"\",\n",
+ " conditions=[condition],\n",
+ " max_sequence_length=32,\n",
+ " width=size[0],\n",
+ " height=size[1],\n",
+ " image_guidance_scale=1.5,\n",
+ ").images[0]\n",
+ "\n",
+ "result_img"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/ominicontrol_art/breakingbad.jpg\").convert(\"RGB\")\n",
+ "image, delta = bound_image(image)\n",
+ "condition = Condition(image, \"simpsons\", position_delta=delta)\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "size = sizes[\"3:2\"]\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=\"\",\n",
+ " conditions=[condition],\n",
+ " max_sequence_length=32,\n",
+ " width=size[0],\n",
+ " height=size[1],\n",
+ " image_guidance_scale=1.5,\n",
+ ").images[0]\n",
+ "\n",
+ "result_img"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/ominicontrol_art/PulpFiction.jpg\").convert(\"RGB\")\n",
+ "image, delta = bound_image(image)\n",
+ "condition = Condition(image, \"snoopy\", position_delta=delta)\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "size = sizes[\"3:2\"]\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=\"\",\n",
+ " conditions=[condition],\n",
+ " max_sequence_length=32,\n",
+ " width=size[0],\n",
+ " height=size[1],\n",
+ " image_guidance_scale=1.5,\n",
+ ").images[0]\n",
+ "\n",
+ "result_img"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.21"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/spatial.ipynb b/examples/spatial.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..e22169b428b9ea771aaed7506e8298170f9e55a3
--- /dev/null
+++ b/examples/spatial.ipynb
@@ -0,0 +1,191 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.chdir(\"..\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from diffusers.pipelines import FluxPipeline\n",
+ "from PIL import Image\n",
+ "\n",
+ "from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe = FluxPipeline.from_pretrained(\n",
+ " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n",
+ ")\n",
+ "pipe = pipe.to(\"cuda\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pipe.unload_lora_weights()\n",
+ "\n",
+ "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n",
+ " pipe.load_lora_weights(\n",
+ " \"Yuanshi/OminiControl\",\n",
+ " weight_name=f\"experimental/{condition_type}.safetensors\",\n",
+ " adapter_name=condition_type,\n",
+ " )\n",
+ "\n",
+ "pipe.set_adapters([\"canny\", \"depth\", \"coloring\", \"deblurring\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n",
+ "\n",
+ "w, h, min_dim = image.size + (min(image.size),)\n",
+ "image = image.crop(\n",
+ " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n",
+ ").resize((512, 512))\n",
+ "\n",
+ "prompt = \"In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table.\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "canny_image = convert_to_condition(\"canny\", image)\n",
+ "condition = Condition(canny_image, \"canny\")\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "depth_image = convert_to_condition(\"depth\", image)\n",
+ "condition = Condition(depth_image, \"depth\")\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "blur_image = convert_to_condition(\"deblurring\", image)\n",
+ "condition = Condition(blur_image, \"deblurring\")\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "condition_image = convert_to_condition(\"coloring\", image)\n",
+ "condition = Condition(condition_image, \"coloring\")\n",
+ "\n",
+ "seed_everything()\n",
+ "\n",
+ "result_img = generate(\n",
+ " pipe,\n",
+ " prompt=prompt,\n",
+ " conditions=[condition],\n",
+ ").images[0]\n",
+ "\n",
+ "concat_image = Image.new(\"RGB\", (1536, 512))\n",
+ "concat_image.paste(image, (0, 0))\n",
+ "concat_image.paste(condition.condition, (512, 0))\n",
+ "concat_image.paste(result_img, (1024, 0))\n",
+ "concat_image"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6d0bc83f27098b2cf17e27de67797092ff6bbee1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+transformers
+diffusers
+peft
+opencv-python
+protobuf
+sentencepiece
+gradio
+jupyter
+torchao
\ No newline at end of file