diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..13f0b18ad88d9714f6ddf79a06560a0a4a0b8def 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/cartoon_boy.png filter=lfs diff=lfs merge=lfs -text +assets/clock.jpg filter=lfs diff=lfs merge=lfs -text +assets/monalisa.jpg filter=lfs diff=lfs merge=lfs -text +assets/rc_car.jpg filter=lfs diff=lfs merge=lfs -text +assets/room_corner.jpg filter=lfs diff=lfs merge=lfs -text +assets/tshirt.jpg filter=lfs diff=lfs merge=lfs -text +assets/vase_hq.jpg filter=lfs diff=lfs merge=lfs -text +assets/demo/art1.png filter=lfs diff=lfs merge=lfs -text +assets/demo/art2.png filter=lfs diff=lfs merge=lfs -text +assets/demo/demo_this_is_omini_control.jpg filter=lfs diff=lfs merge=lfs -text +assets/demo/dreambooth_res.jpg filter=lfs diff=lfs merge=lfs -text +assets/demo/monalisa_omini.jpg filter=lfs diff=lfs merge=lfs -text +assets/demo/scene_variation.jpg filter=lfs diff=lfs merge=lfs -text +assets/demo/try_on.jpg filter=lfs diff=lfs merge=lfs -text +assets/ominicontrol_art/DistractedBoyfriend.webp filter=lfs diff=lfs merge=lfs -text +assets/ominicontrol_art/PulpFiction.jpg filter=lfs diff=lfs merge=lfs -text +assets/ominicontrol_art/breakingbad.jpg filter=lfs diff=lfs merge=lfs -text +assets/ominicontrol_art/oiiai.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2fdb2d7733471841f7ba4b7787da9a3da03e5916 --- /dev/null +++ b/.gitignore @@ -0,0 +1,229 @@ +wandb/* +runs/* + + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +# poetry.lock +# poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +# pdm.lock +# pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +# pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# Redis +*.rdb +*.aof +*.pid + +# RabbitMQ +mnesia/ +rabbitmq/ +rabbitmq-data/ + +# ActiveMQ +activemq-data/ + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +# .idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml + + +# exps/ +# wandb/ +# *.ipynb +# glue_exp/ +# logs_hyper/ +# # grid/ +# glue22_ex \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7b2de994465d5d0c1f37fc6aecd065f8849ed2d8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2024] [Zhenxiong Tan] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0c9cd9b0159b7fb66d114665a2c3b25030a4e3b0 --- /dev/null +++ b/README.md @@ -0,0 +1,198 @@ +# OminiControl + + + +
+ +HuggingFace +HuggingFace +HuggingFace +GitHub +HuggingFace +
+arXiv +arXiv + +> **OminiControl: Minimal and Universal Control for Diffusion Transformer** +>
+> Zhenxiong Tan, +> [Songhua Liu](http://121.37.94.87/), +> [Xingyi Yang](https://adamdad.github.io/), +> Qiaochu Xue, +> and +> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/) +>
+> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore +>
+ +> **OminiControl2: Efficient Conditioning for Diffusion Transformers** +>
+> Zhenxiong Tan, +> Qiaochu Xue, +> [Xingyi Yang](https://adamdad.github.io/), +> [Songhua Liu](http://121.37.94.87/), +> and +> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/) +>
+> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore +>
+ + + +## Features + +OminiControl is a minimal yet powerful universal control framework for Diffusion Transformer models like [FLUX](https://github.com/black-forest-labs/flux). + +* **Universal Control 🌐**: A unified control framework that supports both subject-driven control and spatial control (such as edge-guided and in-painting generation). + +* **Minimal Design πŸš€**: Injects control signals while preserving original model structure. Only introduces 0.1% additional parameters to the base model. + +## News +- **2025-05-12**: ⭐️ The code of [OminiControl2](https://arxiv.org/abs/2503.08280) is released. It introduces a new efficient conditioning method for diffusion transformers. (Check out the training code [here](./train)). +- **2025-05-12**: Support custom style LoRA. (Check out the [example](./examples/combine_with_style_lora.ipynb)). +- **2025-04-09**: ⭐️ [OminiControl Art](https://huggingface.co/spaces/Yuanshi/OminiControl_Art) is released. It can stylize any image with a artistic style. (Check out the [demo](https://huggingface.co/spaces/Yuanshi/OminiControl_Art) and [inference examples](./examples/ominicontrol_art.ipynb)). +- **2024-12-26**: Training code are released. Now you can create your own OminiControl model by customizing any control tasks (3D, multi-view, pose-guided, try-on, etc.) with the FLUX model. Check the [training folder](./train) for more details. + +## Quick Start +### Setup (Optional) +1. **Environment setup** +```bash +conda create -n omini python=3.12 +conda activate omini +``` +2. **Requirements installation** +```bash +pip install -r requirements.txt +``` +### Usage example +1. Subject-driven generation: `examples/subject.ipynb` +2. In-painting: `examples/inpainting.ipynb` +3. Canny edge to image, depth to image, colorization, deblurring: `examples/spatial.ipynb` + + +### Guidelines for subject-driven generation +1. Input images are automatically center-cropped and resized to 512x512 resolution. +2. When writing prompts, refer to the subject using phrases like `this item`, `the object`, or `it`. e.g. + 1. *A close up view of this item. It is placed on a wooden table.* + 2. *A young lady is wearing this shirt.* +3. The model primarily works with objects rather than human subjects currently, due to the absence of human data in training. + +## Generated samples +### Subject-driven generation +HuggingFace + +**Demos** (Left: condition image; Right: generated image) + +
+ + + + +
+ +
+Text Prompts + +- Prompt1: *A close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!.'* +- Prompt2: *A film style shot. On the moon, this item drives across the moon surface. A flag on it reads 'Omini'. The background is that Earth looms large in the foreground.* +- Prompt3: *In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.* +- Prompt4: *"On the beach, a lady sits under a beach umbrella with 'Omini' written on it. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple."* +
+
+More results + +* Try on: + +* Scene variations: + +* Dreambooth dataset: + +* Oye-cartoon finetune: +
+ + +
+
+ +### Spatially aligned control +1. **Image Inpainting** (Left: original image; Center: masked image; Right: filled image) + - Prompt: *The Mona Lisa is wearing a white VR headset with 'Omini' written on it.* +
+ + - Prompt: *A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.* +
+ +2. **Other spatially aligned tasks** (Canny edge to image, depth to image, colorization, deblurring) +
+
+ Click to show +
+ + + + +
+ + Prompt: *A light gray sofa stands against a white wall, featuring a black and white geometric patterned pillow. A white side table sits next to the sofa, topped with a white adjustable desk lamp and some books. Dark hardwood flooring contrasts with the pale walls and furniture.* +
+ +### Stylize images +HuggingFace +
+ + +
+ + + +## Models + +**Subject-driven control:** +| Model | Base model | Description | Resolution | +| ------------------------------------------------------------------------------------------------ | -------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------ | +| [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `subject` | FLUX.1-schnell | The model used in the paper. | (512, 512) | +| [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_512` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset. | (512, 512) | +| [`omini`](https://huggingface.co/Yuanshi/OminiControl/tree/main/omini) / `subject_1024` | FLUX.1-schnell | The model has been fine-tuned on a larger dataset and accommodates higher resolution. | (1024, 1024) | +| [`oye-cartoon`](https://huggingface.co/saquiboye/oye-cartoon) | FLUX.1-dev | The model has been fine-tuned on [oye-cartoon](https://huggingface.co/datasets/saquiboye/oye-cartoon) dataset by [@saquib764](https://github.com/Saquib764) | (512, 512) | + +**Spatial aligned control:** +| Model | Base model | Description | Resolution | +| --------------------------------------------------------------------------------------------------------- | ---------- | -------------------------------------------------------------------------- | ------------ | +| [`experimental`](https://huggingface.co/Yuanshi/OminiControl/tree/main/experimental) / `` | FLUX.1 | Canny edge to image, depth to image, colorization, deblurring, in-painting | (512, 512) |= + +## Community Extensions +- [ComfyUI-Diffusers-OminiControl](https://github.com/Macoron/ComfyUI-Diffusers-OminiControl) - ComfyUI integration by [@Macoron](https://github.com/Macoron) +- [ComfyUI_RH_OminiControl](https://github.com/HM-RunningHub/ComfyUI_RH_OminiControl) - ComfyUI integration by [@HM-RunningHub](https://github.com/HM-RunningHub) + +## Limitations +1. The model's subject-driven generation primarily works with objects rather than human subjects due to the absence of human data in training. +2. The subject-driven generation model may not work well with `FLUX.1-dev`. +3. The released model only supports the resolution of 512x512. + +## Training +Training instructions can be found in this [folder](./train). + + +## To-do +- [x] Release the training code. +- [x] Release the model for higher resolution (1024x1024). + +## Acknowledgment +We would like to acknowledge that the computational work involved in this research work is partially supported by NUS IT’s Research Computing group using grant numbers NUSREC-HPC-00001. + +## Citation +``` +@article{tan2025ominicontrol, + title={OminiControl: Minimal and Universal Control for Diffusion Transformer}, + author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao}, + booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, + year={2025} +} + +@article{tan2025ominicontrol2, + title={OminiControl2: Efficient Conditioning for Diffusion Transformers}, + author={Tan, Zhenxiong and Xue, Qiaochu and Yang, Xingyi and Liu, Songhua and Wang, Xinchao}, + journal={arXiv preprint arXiv:2503.08280}, + year={2025} +} +``` diff --git a/ablation_qkv.py b/ablation_qkv.py new file mode 100644 index 0000000000000000000000000000000000000000..6aaea3421b52a153f8657130bd76d201c71f4c0f --- /dev/null +++ b/ablation_qkv.py @@ -0,0 +1,170 @@ +import torch + +from diffusers.pipelines import FluxPipeline +from omini.pipeline.flux_omini_ablate_qkv import Condition, generate, seed_everything, convert_to_condition +from omini.rotation import RotationConfig, RotationTuner +from PIL import Image + + +def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False): + """ + Load rotation adapter weights. + + Args: + path: Directory containing the saved adapter weights + adapter_name: Name of the adapter to load + strict: Whether to strictly match all keys + """ + from safetensors.torch import load_file + import os + import yaml + + device = transformer.device + print(f"device for loading: {device}") + + # Try to load safetensors first, then fallback to .pth + safetensors_path = os.path.join(path, f"{adapter_name}.safetensors") + pth_path = os.path.join(path, f"{adapter_name}.pth") + + if os.path.exists(safetensors_path): + state_dict = load_file(safetensors_path) + print(f"Loaded rotation adapter from {safetensors_path}") + elif os.path.exists(pth_path): + state_dict = torch.load(pth_path, map_location=device) + print(f"Loaded rotation adapter from {pth_path}") + else: + raise FileNotFoundError( + f"No adapter weights found for '{adapter_name}' in {path}\n" + f"Looking for: {safetensors_path} or {pth_path}" + ) + + # # Get the device and dtype of the transformer + transformer_device = next(transformer.parameters()).device + transformer_dtype = next(transformer.parameters()).dtype + + + + state_dict_with_adapter = {} + for k, v in state_dict.items(): + # Reconstruct the full key with adapter name + new_key = k.replace(".rotation.", f".rotation.{adapter_name}.") + if "_adapter_config" in new_key: + print(f"adapter_config key: {new_key}") + + + # Move to target device and dtype + # Check if this parameter should keep its original dtype (e.g., indices, masks) + if v.dtype in [torch.long, torch.int, torch.int32, torch.int64, torch.bool]: + # Keep integer/boolean dtypes, only move device + state_dict_with_adapter[new_key] = v.to(device=transformer_device) + else: + # Convert floating point tensors to target dtype and device + state_dict_with_adapter[new_key] = v.to(device=transformer_device, dtype=transformer_dtype) + + # Add adapter name back to keys (reverse of what we did in save) + state_dict_with_adapter = { + k.replace(".rotation.", f".rotation.{adapter_name}."): v + for k, v in state_dict.items() + } + + + # Load into the model + missing, unexpected = transformer.load_state_dict( + state_dict_with_adapter, + strict=strict + ) + + if missing: + print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}") + if unexpected: + print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") + + # Load config if available + config_path = os.path.join(path, f"{adapter_name}_config.yaml") + if os.path.exists(config_path): + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + print(f"Loaded config: {config}") + + total_params = sum(p.numel() for p in state_dict.values()) + print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)") + + return state_dict + + +# prepare input image and prompt +image = Image.open("assets/coffee.png").convert("RGB") + +w, h, min_dim = image.size + (min(image.size),) +image = image.crop( + ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2) +).resize((512, 512)) + +prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table." + +canny_image = convert_to_condition("canny", image) +condition = Condition(canny_image, "canny") + +seed_everything() + + + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 +) + + +# add adapter to the transformer +transformer = pipe.transformer + +adapter_name = "default" +transformer._hf_peft_config_loaded = True + +rotation_adapter_config = { + "r": 4, + "num_rotations": 4, + "target_modules": "(.*x_embedder|.*(? 5 else ''}") + if unexpected: + print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") + + # Load config if available + config_path = os.path.join(path, f"{adapter_name}_config.yaml") + if os.path.exists(config_path): + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + print(f"Loaded config: {config}") + + total_params = sum(p.numel() for p in state_dict.values()) + print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)") + + return state_dict + + +# prepare input image and prompt +image = Image.open("assets/coffee.png").convert("RGB") + +w, h, min_dim = image.size + (min(image.size),) +image = image.crop( + ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2) +).resize((512, 512)) + +prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table." + +canny_image = convert_to_condition("canny", image) +condition = Condition(canny_image, "canny") + +seed_everything() + + + +for i in range(40, 60): + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 + ) + + + # add adapter to the transformer + transformer = pipe.transformer + + adapter_name = "default" + transformer._hf_peft_config_loaded = True + + rotation_adapter_config = { + "r": 4, + "num_rotations": 4, + "target_modules": "(.*x_embedder|.*(? 5 else ''}") + if unexpected: + print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") + + # Load config if available + config_path = os.path.join(path, f"{adapter_name}_config.yaml") + if os.path.exists(config_path): + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + print(f"Loaded config: {config}") + + total_params = sum(p.numel() for p in state_dict.values()) + print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)") + + return state_dict + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate OminiControl on COCO dataset") + parser.add_argument("--start_index", type=int, default=0, help="Starting index for evaluation") + parser.add_argument("--num_images", type=int, default=500, help="Number of images to evaluate") + parser.add_argument("--condition_type", type=str, default="deblurring", help="Type of condition (e.g., 'deblurring', 'canny', 'depth')") + parser.add_argument("--adapter_path", type=str, default="runs/20251111-212406-deblurring/ckpt/25000", help="Path to the adapter checkpoint") + args = parser.parse_args() + + START_INDEX = args.start_index + NUM_IMAGES = args.num_images + + # Path to your captions file (change if needed) + CAPTION_FILE = "/home/work/koopman/oft/data/coco/annotations/captions_val2017.json" + IMAGE_DIR = "/home/work/koopman/oft/data/coco/images/val2017/" + CONDITION_TYPE = args.condition_type + SAVE_ROOT_DIR = f"./coco/results_{CONDITION_TYPE}_1000/" + ADAPTER_PATH = args.adapter_path + + # Load your Flux pipeline + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float32) # Replace with your model path + + # add adapter to the transformer + transformer = pipe.transformer + + adapter_name = "default" + transformer._hf_peft_config_loaded = True + + # make sure this is the same with your config.yaml used in training + rotation_adapter_config = { + "r": 1, + "num_rotations": 8, + "target_modules": "(.*x_embedder|.*(? 5 else ''}") + if unexpected: + print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") + + # Load config if available + config_path = os.path.join(path, f"{adapter_name}_config.yaml") + if os.path.exists(config_path): + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + print(f"Loaded config: {config}") + + total_params = sum(p.numel() for p in state_dict.values()) + print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)") + + return state_dict + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate OminiControl on COCO dataset") + parser.add_argument("--start_index", type=int, default=0, help="Starting index for evaluation") + parser.add_argument("--num_images", type=int, default=500, help="Number of images to evaluate") + parser.add_argument("--condition_type", type=str, default="deblurring", help="Type of condition (e.g., 'deblurring', 'canny', 'depth')") + args = parser.parse_args() + + START_INDEX = args.start_index + NUM_IMAGES = args.num_images + + # Path to your captions file (change if needed) + CAPTION_FILE = "/home/work/koopman/oft/data/coco/annotations/captions_val2017.json" + IMAGE_DIR = "/home/work/koopman/oft/data/coco/images/val2017/" + CONDITION_TYPE = args.condition_type + SAVE_ROOT_DIR = f"./coco_baseline/results_{CONDITION_TYPE}_1000/" + + # Load your Flux pipeline + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16) # Replace with your model path + + ### FOR OMINI + + pipe.load_lora_weights( + "Yuanshi/OminiControl", + weight_name=f"experimental/{CONDITION_TYPE}.safetensors", + adapter_name=CONDITION_TYPE, + ) + pipe.fuse_lora(lora_scale=1.0) + pipe.unload_lora_weights() + + # pipe.set_adapters([CONDITION_TYPE]) + pipe = pipe.to("cuda") + + + # Evaluate on COCO + evaluate( + pipe, + condition_type=CONDITION_TYPE, + caption_file=CAPTION_FILE, + image_dir=IMAGE_DIR, + save_root_dir=SAVE_ROOT_DIR, + num_images=NUM_IMAGES, + start_index=START_INDEX, + ) \ No newline at end of file diff --git a/evaluation_subject_driven.py b/evaluation_subject_driven.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb245996f9b87188123218777221d0f9978aa60 --- /dev/null +++ b/evaluation_subject_driven.py @@ -0,0 +1,362 @@ +import openai +import base64 +from pathlib import Path +import random +import os + + + +evaluation_prompts = { + "identity": """ + Compare the original subject image with the generated image. + Rate on a scale of 1-5 how well the essential identifying features + are preserved (logos, brand marks, distinctive patterns). + Score: [1-5] + Reasoning: [explanation] + """, + + "material": """ + Evaluate the material quality and surface characteristics. + Rate on a scale of 1-5 how accurately materials are represented + (textures, reflections, surface properties). + Score: [1-5] + Reasoning: [explanation] + """, + + "color": """ + Assess color fidelity in regions NOT specified for modification. + Rate on a scale of 1-5 how consistent colors remain. + Score: [1-5] + Reasoning: [explanation] + """, + + "appearance": """ + Evaluate the overall realism and coherence of the generated image. + Rate on a scale of 1-5 how realistic and natural it appears. + Score: [1-5] + Reasoning: [explanation] + """, + + "modification": """ + Given the text prompt: "{prompt}" + Rate on a scale of 1-5 how well the specified changes are executed. + Score: [1-5] + Reasoning: [explanation] + """ +} + + +def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + +def evaluate_subject_driven_generation( + original_image_path, + generated_image_path, + text_prompt, + client +): + """ + Evaluate a subject-driven generation using GPT-4o vision + """ + + # Encode images + original_img = encode_image(original_image_path) + generated_img = encode_image(generated_image_path) + + results = {} + + # 1. Identity Preservation + response = client.chat.completions.create( + model="gpt-4o", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Original subject image:"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{original_img}"}}, + {"type": "text", "text": "Generated image:"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}}, + {"type": "text", "text": evaluation_prompts["identity"]} + ] + }], + max_tokens=300 + ) + results['identity'] = parse_score(response.choices[0].message.content) + + # 2. Material Quality + response = client.chat.completions.create( + model="gpt-4o", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Evaluate this generated image:"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}}, + {"type": "text", "text": evaluation_prompts["material"]} + ] + }], + max_tokens=300 + ) + results['material'] = parse_score(response.choices[0].message.content) + + # 3. Color Fidelity + response = client.chat.completions.create( + model="gpt-4o", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Original:"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{original_img}"}}, + {"type": "text", "text": "Generated:"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}}, + {"type": "text", "text": evaluation_prompts["color"]} + ] + }], + max_tokens=300 + ) + results['color'] = parse_score(response.choices[0].message.content) + + # 4. Natural Appearance + response = client.chat.completions.create( + model="gpt-4o", + messages=[{ + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}}, + {"type": "text", "text": evaluation_prompts["appearance"]} + ] + }], + max_tokens=300 + ) + results['appearance'] = parse_score(response.choices[0].message.content) + + # 5. Modification Accuracy + response = client.chat.completions.create( + model="gpt-4o", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": f"Text prompt: {text_prompt}"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{generated_img}"}}, + {"type": "text", "text": evaluation_prompts["modification"].format(prompt=text_prompt)} + ] + }], + max_tokens=300 + ) + results['modification'] = parse_score(response.choices[0].message.content) + + return results + +def parse_score(response_text): + """Extract score from GPT-4o response""" + # Look for "Score: X" pattern + import re + match = re.search(r'Score:\s*(\d+)', response_text) + if match: + return int(match.group(1)) + return None + +subject_names = [ + "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can", + "candle", "cat", "cat2", "clock", "colorful_sneaker", + "dog", "dog2", "dog3", "dog5", "dog6", + "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie", + "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon", + "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie" +] + + +def get_prompt(subject_id, prompt_id): + + # subject in 0|1|2|3|4|5|8|9|17|18|19|20|21|22|23|24|25|26|27|28|29 + if subject_id in [0,1,2,3,4,5,8,9,17,18,19,20,21,22,23,24,25,26,27,28,29]: + subject_name = subject_names[subject_id] + prompts = [ + f"a {subject_name} in the jungle", + f"a {subject_name} in the snow", + f"a {subject_name} on the beach", + f"a {subject_name} on a cobblestone street", + f"a {subject_name} on top of pink fabric", + f"a {subject_name} on top of a wooden floor", + f"a {subject_name} with a city in the background", + f"a {subject_name} with a mountain in the background", + f"a {subject_name} with a blue house in the background", + f"a {subject_name} on top of a purple rug in a forest", + f"a {subject_name} with a wheat field in the background", + f"a {subject_name} with a tree and autumn leaves in the background", + f"a {subject_name} with the Eiffel Tower in the background", + f"a {subject_name} floating on top of water", + f"a {subject_name} floating in an ocean of milk", + f"a {subject_name} on top of green grass with sunflowers around it", + f"a {subject_name} on top of a mirror", + f"a {subject_name} on top of the sidewalk in a crowded street", + f"a {subject_name} on top of a dirt road", + f"a {subject_name} on top of a white rug", + f"a red {subject_name}", + f"a purple {subject_name}", + f"a shiny {subject_name}", + f"a wet {subject_name}", + f"a cube shaped {subject_name}" + ] + + else: + prompts = [ + f"a {subject_name} in the jungle", + f"a {subject_name} in the snow", + f"a {subject_name} on the beach", + f"a {subject_name} on a cobblestone street", + f"a {subject_name} on top of pink fabric", + f"a {subject_name} on top of a wooden floor", + f"a {subject_name} with a city in the background", + f"a {subject_name} with a mountain in the background", + f"a {subject_name} with a blue house in the background", + f"a {subject_name} on top of a purple rug in a forest", + f"a {subject_name} wearing a red hat", + f"a {subject_name} wearing a santa hat", + f"a {subject_name} wearing a rainbow scarf", + f"a {subject_name} wearing a black top hat and a monocle", + f"a {subject_name} in a chef outfit", + f"a {subject_name} in a firefighter outfit", + f"a {subject_name} in a police outfit", + f"a {subject_name} wearing pink glasses", + f"a {subject_name} wearing a yellow shirt", + f"a {subject_name} in a purple wizard outfit", + f"a red {subject_name}", + f"a purple {subject_name}", + f"a shiny {subject_name}", + f"a wet {subject_name}", + f"a cube shaped {subject_name}" + ] + + return prompts[prompt_id] + + + + + +def batch_evaluate_dreambooth(client, generate_fn, dataset_path, output_csv): + """ + Evaluate 750 image pairs with 5 seeds each + """ + import pandas as pd + + results_list = [] + + # Iterate through DreamBooth dataset + for subject_id in range(30): # 30 subjects + subject_name = subject_names[subject_id] + for prompt_id in range(25): # 25 prompts per subject + original = f"{dataset_path}/{subject_name}" + # get a random file in this folder + original_files = list(Path(original).glob("*.png")) + if len(original_files) == 0: + raise ValueError(f"No original images found in {original}") + + original = str(original_files[0]) + + + for seed in range(5): # 5 different seeds + # take random file in the folder + prompt = get_prompt(subject_id, prompt_id) + + # generated image path + generated_folder = f"{dataset_path}/{subject_name}/generated/" + os.makedirs(generated_folder, exist_ok=True) + generated = f"{generated_folder}/gen_seed{seed}_prompt{prompt_id}.png" + + generate_fn( + prompt=prompt, + subject_image_path=original, + output_image_path=generated, + seed=seed + ) + + scores = evaluate_subject_driven_generation( + original, generated, prompt, client + ) + + results_list.append({ + 'subject_id': subject_id, + 'subject_name': subject_name, + 'prompt_id': prompt_id, + 'seed': seed, + 'prompt': prompt, + + **scores + }) + + # Save results + df = pd.DataFrame(results_list) + df.to_csv(output_csv, index=False) + + # Calculate statistics + print(df.groupby('subject_id').mean()) + print(f"\nOverall averages:") + print(df[['identity', 'material', 'color', 'appearance', 'modification']].mean()) + + +def evaluate_omini_control(): + + import torch + from diffusers.pipelines import FluxPipeline + from PIL import Image + + from omini.pipeline.flux_omini import Condition, generate, seed_everything + + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 + ) + + pipe = pipe.to("cuda") + pipe.load_lora_weights( + "Yuanshi/OminiControl", + weight_name=f"omini/subject_512.safetensors", + adapter_name="subject", + ) + + def generate_fn(image_path, prompt, seed, output_path): + seed_everything(seed) + + image = Image.open(image_path).convert("RGB").resize((512, 512)) + condition = Condition.from_image( + image, + "subject", position_delta=(0, 32) + ) + + result_img = generate( + pipe, + prompt=prompt, + conditions=[condition], + ).images[0] + + result_img.save(output_path) + + return generate_fn + + +if __name__ == "__main__": + + + + openai.api_key = os.getenv("OPENAI_API_KEY") + # client = openai.Client() + + # generate_fn = evaluate_omini_control() + + # dataset_path = "data/dreambooth" + # output_csv = "evaluation_subject_driven_omini_control.csv" + + # batch_evaluate_dreambooth( + # client, + # generate_fn, + # dataset_path, + # output_csv + # ) + + result = evaluate_subject_driven_generation( + "data/dreambooth/backpack/00.jpg", + "data/dreambooth/backpack/01.jpg", + "a backpack in the jungle", + openai.Client() + ) + + print(result) \ No newline at end of file diff --git a/examples/combine_with_style_lora.ipynb b/examples/combine_with_style_lora.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c212ca61014e42c6de001a70ae80d19ab5ce230b --- /dev/null +++ b/examples/combine_with_style_lora.ipynb @@ -0,0 +1,235 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers.pipelines import FluxPipeline\n", + "from PIL import Image\n", + "\n", + "from omini.pipeline.flux_omini import Condition, generate, seed_everything" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "pipe = FluxPipeline.from_pretrained(\n", + " \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n", + ")\n", + "pipe = pipe.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe.unload_lora_weights()\n", + "\n", + "pipe.load_lora_weights(\n", + " \"Yuanshi/OminiControl\",\n", + " weight_name=f\"omini/subject_512.safetensors\",\n", + " adapter_name=\"subject\",\n", + ")\n", + "pipe.load_lora_weights(\"XLabs-AI/flux-RealismLora\", adapter_name=\"realism\")\n", + "\n", + "pipe.set_adapters([\"subject\", \"realism\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/penguin.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "# For this model, the position_delta is (0, 32).\n", + "# For more details of position_delta, please refer to:\n", + "# https://github.com/Yuanshi9815/OminiControl/issues/89#issuecomment-2827080344\n", + "condition = Condition(image, \"subject\", position_delta=(0, 32))\n", + "\n", + "prompt = \"On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat.\"\n", + "\n", + "\n", + "seed_everything(0)\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + " main_adapter=\"realism\"\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/tshirt.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(image, \"subject\", position_delta=(0, 32))\n", + "\n", + "prompt = \"On the beach, a lady sits under a beach umbrella. She's wearing this shirt and has a big smile on her face, with her surfboard hehind her. The sun is setting in the background. The sky is a beautiful shade of orange and purple.\"\n", + "\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + " main_adapter=\"realism\"\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(condition.condition, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/rc_car.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(image, \"subject\", position_delta=(0, 32))\n", + "\n", + "prompt = \"A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.\"\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + " main_adapter=\"realism\"\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(condition.condition, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/clock.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(image, \"subject\", position_delta=(0, 32))\n", + "\n", + "prompt = \"In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.\"\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + " main_adapter=\"realism\"\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(condition.condition, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/oranges.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "condition = Condition(image, \"subject\", position_delta=(0, 32))\n", + "\n", + "prompt = \"A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show.\"\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + " num_inference_steps=8,\n", + " height=512,\n", + " width=512,\n", + " main_adapter=\"realism\"\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1024, 512))\n", + "concat_image.paste(condition.condition, (0, 0))\n", + "concat_image.paste(result_img, (512, 0))\n", + "concat_image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/inpainting.ipynb b/examples/inpainting.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..21ac662cc6187188bcf799cb3919208198c41235 --- /dev/null +++ b/examples/inpainting.ipynb @@ -0,0 +1,135 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers.pipelines import FluxPipeline\n", + "from PIL import Image\n", + "\n", + "from omini.pipeline.flux_omini import Condition, generate, seed_everything" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = FluxPipeline.from_pretrained(\n", + " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n", + ")\n", + "pipe = pipe.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe.load_lora_weights(\n", + " \"Yuanshi/OminiControl\",\n", + " weight_name=f\"experimental/fill.safetensors\",\n", + " adapter_name=\"fill\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/monalisa.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "masked_image = image.copy()\n", + "masked_image.paste((0, 0, 0), (128, 100, 384, 220))\n", + "\n", + "condition = Condition(masked_image, \"fill\")\n", + "\n", + "seed_everything()\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=\"The Mona Lisa is wearing a white VR headset with 'Omini' written on it.\",\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/book.jpg\").convert(\"RGB\").resize((512, 512))\n", + "\n", + "w, h, min_dim = image.size + (min(image.size),)\n", + "image = image.crop(\n", + " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n", + ").resize((512, 512))\n", + "\n", + "\n", + "masked_image = image.copy()\n", + "masked_image.paste((0, 0, 0), (150, 150, 350, 250))\n", + "masked_image.paste((0, 0, 0), (200, 380, 320, 420))\n", + "\n", + "condition = Condition(masked_image, \"fill\")\n", + "\n", + "seed_everything()\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=\"A yellow book with the word 'OMINI' in large font on the cover. The text 'for FLUX' appears at the bottom.\",\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/ominicontrol_art.ipynb b/examples/ominicontrol_art.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..70cfad066c997c38e30a6ab2dad5615ce5c3ded8 --- /dev/null +++ b/examples/ominicontrol_art.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers.pipelines import FluxPipeline\n", + "from PIL import Image\n", + "\n", + "from omini.pipeline.flux_omini import Condition, generate, seed_everything" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = FluxPipeline.from_pretrained(\n", + " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n", + ")\n", + "pipe = pipe.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe.unload_lora_weights()\n", + "\n", + "for style_type in [\"ghibli\", \"irasutoya\", \"simpsons\", \"snoopy\"]:\n", + " pipe.load_lora_weights(\n", + " \"Yuanshi/OminiControlArt\",\n", + " weight_name=f\"v0/{style_type}.safetensors\",\n", + " adapter_name=style_type,\n", + " )\n", + "\n", + "pipe.set_adapters([\"ghibli\", \"irasutoya\", \"simpsons\", \"snoopy\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def resize(img, factor=16):\n", + " # Resize the image to be divisible by the factor\n", + " w, h = img.size\n", + " new_w, new_h = w // factor * factor, h // factor * factor\n", + " padding_w, padding_h = (w - new_w) // 2, (h - new_h) // 2\n", + " img = img.crop((padding_w, padding_h, new_w + padding_w, new_h + padding_h))\n", + " return img\n", + "\n", + "\n", + "def bound_image(image):\n", + " factor = 512 / max(image.size)\n", + " image = resize(\n", + " image.resize(\n", + " (int(image.size[0] * factor), int(image.size[1] * factor)),\n", + " Image.LANCZOS,\n", + " )\n", + " )\n", + " delta = (0, -image.size[0] // 16)\n", + " return image, delta\n", + "\n", + "sizes = {\n", + " \"2:3\": (640, 960),\n", + " \"1:1\": (640, 640),\n", + " \"3:2\": (960, 640),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/ominicontrol_art/DistractedBoyfriend.webp\").convert(\"RGB\")\n", + "image, delta = bound_image(image)\n", + "condition = Condition(image, \"ghibli\", position_delta=delta)\n", + "\n", + "seed_everything()\n", + "\n", + "size = sizes[\"3:2\"]\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=\"\",\n", + " conditions=[condition],\n", + " max_sequence_length=32,\n", + " width=size[0],\n", + " height=size[1],\n", + " image_guidance_scale=1.5,\n", + ").images[0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/ominicontrol_art/oiiai.png\").convert(\"RGB\")\n", + "image, delta = bound_image(image)\n", + "condition = Condition(image, \"irasutoya\", position_delta=delta)\n", + "\n", + "seed_everything()\n", + "\n", + "size = sizes[\"1:1\"]\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=\"\",\n", + " conditions=[condition],\n", + " max_sequence_length=32,\n", + " width=size[0],\n", + " height=size[1],\n", + " image_guidance_scale=1.5,\n", + ").images[0]\n", + "\n", + "result_img" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/ominicontrol_art/breakingbad.jpg\").convert(\"RGB\")\n", + "image, delta = bound_image(image)\n", + "condition = Condition(image, \"simpsons\", position_delta=delta)\n", + "\n", + "seed_everything()\n", + "\n", + "size = sizes[\"3:2\"]\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=\"\",\n", + " conditions=[condition],\n", + " max_sequence_length=32,\n", + " width=size[0],\n", + " height=size[1],\n", + " image_guidance_scale=1.5,\n", + ").images[0]\n", + "\n", + "result_img" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/ominicontrol_art/PulpFiction.jpg\").convert(\"RGB\")\n", + "image, delta = bound_image(image)\n", + "condition = Condition(image, \"snoopy\", position_delta=delta)\n", + "\n", + "seed_everything()\n", + "\n", + "size = sizes[\"3:2\"]\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=\"\",\n", + " conditions=[condition],\n", + " max_sequence_length=32,\n", + " width=size[0],\n", + " height=size[1],\n", + " image_guidance_scale=1.5,\n", + ").images[0]\n", + "\n", + "result_img" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/spatial.ipynb b/examples/spatial.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e22169b428b9ea771aaed7506e8298170f9e55a3 --- /dev/null +++ b/examples/spatial.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.chdir(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from diffusers.pipelines import FluxPipeline\n", + "from PIL import Image\n", + "\n", + "from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = FluxPipeline.from_pretrained(\n", + " \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n", + ")\n", + "pipe = pipe.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipe.unload_lora_weights()\n", + "\n", + "for condition_type in [\"canny\", \"depth\", \"coloring\", \"deblurring\"]:\n", + " pipe.load_lora_weights(\n", + " \"Yuanshi/OminiControl\",\n", + " weight_name=f\"experimental/{condition_type}.safetensors\",\n", + " adapter_name=condition_type,\n", + " )\n", + "\n", + "pipe.set_adapters([\"canny\", \"depth\", \"coloring\", \"deblurring\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "image = Image.open(\"assets/coffee.png\").convert(\"RGB\")\n", + "\n", + "w, h, min_dim = image.size + (min(image.size),)\n", + "image = image.crop(\n", + " ((w - min_dim) // 2, (h - min_dim) // 2, (w + min_dim) // 2, (h + min_dim) // 2)\n", + ").resize((512, 512))\n", + "\n", + "prompt = \"In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table.\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "canny_image = convert_to_condition(\"canny\", image)\n", + "condition = Condition(canny_image, \"canny\")\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "depth_image = convert_to_condition(\"depth\", image)\n", + "condition = Condition(depth_image, \"depth\")\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "blur_image = convert_to_condition(\"deblurring\", image)\n", + "condition = Condition(blur_image, \"deblurring\")\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "condition_image = convert_to_condition(\"coloring\", image)\n", + "condition = Condition(condition_image, \"coloring\")\n", + "\n", + "seed_everything()\n", + "\n", + "result_img = generate(\n", + " pipe,\n", + " prompt=prompt,\n", + " conditions=[condition],\n", + ").images[0]\n", + "\n", + "concat_image = Image.new(\"RGB\", (1536, 512))\n", + "concat_image.paste(image, (0, 0))\n", + "concat_image.paste(condition.condition, (512, 0))\n", + "concat_image.paste(result_img, (1024, 0))\n", + "concat_image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6d0bc83f27098b2cf17e27de67797092ff6bbee1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +transformers +diffusers +peft +opencv-python +protobuf +sentencepiece +gradio +jupyter +torchao \ No newline at end of file