| | import torch |
| | import os |
| | import random |
| |
|
| | from PIL import Image, ImageDraw |
| |
|
| | from datasets import load_dataset |
| |
|
| | from .trainer import OminiModel, get_config, train |
| | from ..pipeline.flux_omini import Condition, convert_to_condition, generate |
| | from .train_spatial_alignment import ImageConditionDataset |
| |
|
| |
|
| | class ImageMultiConditionDataset(ImageConditionDataset): |
| | def __getitem__(self, idx): |
| | image = self.base_dataset[idx]["jpg"] |
| | image = image.resize(self.target_size).convert("RGB") |
| | description = self.base_dataset[idx]["json"]["prompt"] |
| |
|
| | condition_size = self.condition_size |
| | position_scale = self.position_scale |
| |
|
| | condition_imgs, position_deltas = [], [] |
| | for c_type in self.condition_type: |
| | condition_img, position_delta = self.__get_condition__(image, c_type) |
| | condition_imgs.append(condition_img.convert("RGB")) |
| | position_deltas.append(position_delta) |
| |
|
| | |
| | drop_text = random.random() < self.drop_text_prob |
| | drop_image = random.random() < self.drop_image_prob |
| |
|
| | if drop_text: |
| | description = "" |
| | if drop_image: |
| | condition_imgs = [ |
| | Image.new("RGB", condition_size) |
| | for _ in range(len(self.condition_type)) |
| | ] |
| |
|
| | return_dict = { |
| | "image": self.to_tensor(image), |
| | "description": description, |
| | **({"pil_image": [image, condition_img]} if self.return_pil_image else {}), |
| | } |
| |
|
| | for i, c_type in enumerate(self.condition_type): |
| | return_dict[f"condition_{i}"] = self.to_tensor(condition_imgs[i]) |
| | return_dict[f"condition_type_{i}"] = self.condition_type[i] |
| | return_dict[f"position_delta_{i}"] = position_deltas[i] |
| | return_dict[f"position_scale_{i}"] = position_scale |
| |
|
| | return return_dict |
| |
|
| |
|
| | @torch.no_grad() |
| | def test_function(model, save_path, file_name): |
| | condition_size = model.training_config["dataset"]["condition_size"] |
| | target_size = model.training_config["dataset"]["target_size"] |
| |
|
| | position_delta = model.training_config["dataset"].get("position_delta", [0, 0]) |
| | position_scale = model.training_config["dataset"].get("position_scale", 1.0) |
| |
|
| | condition_type = model.training_config["condition_type"] |
| | test_list = [] |
| |
|
| | condition_list = [] |
| | for i, c_type in enumerate(condition_type): |
| | if c_type in ["canny", "coloring", "deblurring", "depth"]: |
| | image = Image.open("assets/vase_hq.jpg") |
| | image = image.resize(condition_size) |
| | condition_img = convert_to_condition(c_type, image, 5) |
| | elif c_type == "fill": |
| | condition_img = image.resize(condition_size).convert("RGB") |
| | w, h = image.size |
| | x1, x2 = sorted([random.randint(0, w), random.randint(0, w)]) |
| | y1, y2 = sorted([random.randint(0, h), random.randint(0, h)]) |
| | mask = Image.new("L", image.size, 0) |
| | draw = ImageDraw.Draw(mask) |
| | draw.rectangle([x1, y1, x2, y2], fill=255) |
| | if random.random() > 0.5: |
| | mask = Image.eval(mask, lambda a: 255 - a) |
| | condition_img = Image.composite( |
| | image, Image.new("RGB", image.size, (0, 0, 0)), mask |
| | ) |
| | else: |
| | raise NotImplementedError |
| | condition = Condition( |
| | condition_img, |
| | model.adapter_names[i + 2], |
| | position_delta, |
| | position_scale, |
| | ) |
| | condition_list.append(condition) |
| | test_list.append((condition_list, "A beautiful vase on a table.")) |
| | os.makedirs(save_path, exist_ok=True) |
| | for i, (condition, prompt) in enumerate(test_list): |
| | generator = torch.Generator(device=model.device) |
| | generator.manual_seed(42) |
| | |
| | res = generate( |
| | model.flux_pipe, |
| | prompt=prompt, |
| | conditions=condition_list, |
| | height=target_size[0], |
| | width=target_size[1], |
| | generator=generator, |
| | model_config=model.model_config, |
| | kv_cache=model.model_config.get("independent_condition", False), |
| | ) |
| | file_path = os.path.join( |
| | save_path, f"{file_name}_{'|'.join(condition_type)}_{i}.jpg" |
| | ) |
| | res.images[0].save(file_path) |
| |
|
| |
|
| | def main(): |
| | |
| | config = get_config() |
| | training_config = config["train"] |
| | torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0))) |
| |
|
| | |
| | dataset = load_dataset( |
| | "webdataset", |
| | data_files={"train": training_config["dataset"]["urls"]}, |
| | split="train", |
| | cache_dir="cache/t2i2m", |
| | num_proc=32, |
| | ) |
| | dataset = ImageMultiConditionDataset( |
| | dataset, |
| | condition_size=training_config["dataset"]["condition_size"], |
| | target_size=training_config["dataset"]["target_size"], |
| | condition_type=training_config["condition_type"], |
| | drop_text_prob=training_config["dataset"]["drop_text_prob"], |
| | drop_image_prob=training_config["dataset"]["drop_image_prob"], |
| | position_scale=training_config["dataset"].get("position_scale", 1.0), |
| | ) |
| |
|
| | cond_n = len(training_config["condition_type"]) |
| |
|
| | |
| | trainable_model = OminiModel( |
| | flux_pipe_id=config["flux_path"], |
| | lora_config=training_config["lora_config"], |
| | device=f"cuda", |
| | dtype=getattr(torch, config["dtype"]), |
| | optimizer_config=training_config["optimizer"], |
| | model_config=config.get("model", {}), |
| | gradient_checkpointing=training_config.get("gradient_checkpointing", False), |
| | adapter_names=[None, None, *["default"] * cond_n], |
| | |
| | ) |
| |
|
| | train(dataset, trainable_model, config, test_function) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|