Instructions to use lsmpp/kontextrefiner with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use lsmpp/kontextrefiner with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("lsmpp/kontextrefiner", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import itertools | |
| import json | |
| import os | |
| import random | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from accelerate import Accelerator | |
| from accelerate.utils import ProjectConfiguration | |
| from ip_adapter.ip_adapter import ImageProjModel | |
| from ip_adapter.utils import is_torch2_available | |
| from PIL import Image | |
| from torchvision import transforms | |
| from transformers import ( | |
| CLIPImageProcessor, | |
| CLIPTextModel, | |
| CLIPTextModelWithProjection, | |
| CLIPTokenizer, | |
| CLIPVisionModelWithProjection, | |
| ) | |
| from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | |
| if is_torch2_available(): | |
| from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor | |
| from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor | |
| else: | |
| from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor | |
| # Dataset | |
class MyDataset(torch.utils.data.Dataset):
    """Image/caption dataset that yields diffusion and CLIP inputs per sample.

    Each sample is read from a JSON annotation list of the form
    ``[{"image_file": "1.png", "text": "A dog"}]``. For every item the dataset
    returns the resized-and-cropped image tensor, the caption tokenized with
    both tokenizers, the CLIP-preprocessed image, a flag telling the training
    loop to drop the image embedding, and the size/crop metadata tensors.

    Args:
        json_file: Path to the JSON annotation file.
        tokenizer: First tokenizer; must expose ``model_max_length``.
        tokenizer_2: Second tokenizer; must expose ``model_max_length``.
        size: Side length of the square crop returned under ``"image"``.
        center_crop: Crop around the centre if true, otherwise at a random
            offset.
        t_drop_rate: Probability of replacing the caption with ``""``.
        i_drop_rate: Probability of flagging the image embedding for dropping.
        ti_drop_rate: Probability of doing both at once.
        image_root_path: Directory that ``image_file`` entries are relative to.
    """

    def __init__(
        self,
        json_file,
        tokenizer,
        tokenizer_2,
        size=1024,
        center_crop=True,
        t_drop_rate=0.05,
        i_drop_rate=0.05,
        ti_drop_rate=0.05,
        image_root_path="",
    ):
        super().__init__()

        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.size = size
        self.center_crop = center_crop
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        # list of dict: [{"image_file": "1.png", "text": "A dog"}]
        self.data = self._load_annotations(json_file)

        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.clip_image_processor = CLIPImageProcessor()

    @staticmethod
    def _load_annotations(json_file):
        """Read the JSON annotation list, closing the file handle afterwards."""
        with open(json_file, encoding="utf-8") as f:
            return json.load(f)

    def _apply_condition_dropout(self, text):
        """Randomly drop the caption and/or the image condition.

        A single uniform draw selects at most one of three mutually exclusive
        events: drop the image embedding only, drop the text only, or drop
        both.

        Returns:
            A ``(text, drop_image_embed)`` tuple where ``text`` is either the
            input caption or ``""`` and ``drop_image_embed`` is ``0`` or ``1``.
        """
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_image_embed = 1
        return text, drop_image_embed

    def __getitem__(self, idx):
        """Load, crop and tokenize the sample at position ``idx``.

        Returns:
            dict with keys ``image``, ``text_input_ids``, ``text_input_ids_2``,
            ``clip_image``, ``drop_image_embed``, ``original_size``,
            ``crop_coords_top_left`` and ``target_size``.
        """
        item = self.data[idx]
        text = item["text"]
        image_file = item["image_file"]

        # read image; everything that needs the PIL object happens inside the
        # ``with`` block so the underlying file is closed deterministically
        with Image.open(os.path.join(self.image_root_path, image_file)) as raw_image:
            # original size
            original_width, original_height = raw_image.size
            image_tensor = self.transform(raw_image.convert("RGB"))
            clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
        original_size = torch.tensor([original_height, original_width])

        # crop: the resize transform is expected to leave at least one spatial
        # dimension equal to ``self.size``, so at most one delta is non-zero
        delta_h = image_tensor.shape[1] - self.size
        delta_w = image_tensor.shape[2] - self.size
        assert not all([delta_h, delta_w]), "expected at least one side to equal `size` after resizing"

        if self.center_crop:
            top = delta_h // 2
            left = delta_w // 2
        else:
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        image = transforms.functional.crop(image_tensor, top=top, left=left, height=self.size, width=self.size)
        crop_coords_top_left = torch.tensor([top, left])

        # drop
        text, drop_image_embed = self._apply_condition_dropout(text)

        # get text and tokenize
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        text_input_ids_2 = self.tokenizer_2(
            text,
            max_length=self.tokenizer_2.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "text_input_ids_2": text_input_ids_2,
            "clip_image": clip_image,
            "drop_image_embed": drop_image_embed,
            "original_size": original_size,
            "crop_coords_top_left": crop_coords_top_left,
            "target_size": torch.tensor([self.size, self.size]),
        }

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.data)
def collate_fn(data):
    """Merge a list of per-sample dicts into one batch dict.

    Entries that carry no batch dimension (``image``, ``original_size``,
    ``crop_coords_top_left``, ``target_size``) are stacked along a new leading
    axis. Entries that are concatenated (``text_input_ids``,
    ``text_input_ids_2``, ``clip_image``) are joined along their existing
    first axis. The per-sample ``drop_image_embed`` flags are returned as a
    plain Python list.

    Args:
        data: List of sample dicts with the keys named above.

    Returns:
        dict with keys ``images``, ``text_input_ids``, ``text_input_ids_2``,
        ``clip_images``, ``drop_image_embeds``, ``original_size``,
        ``crop_coords_top_left`` and ``target_size``.
    """
    images = torch.stack([example["image"] for example in data])
    text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
    text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0)
    clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
    drop_image_embeds = [example["drop_image_embed"] for example in data]
    original_size = torch.stack([example["original_size"] for example in data])
    crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data])
    target_size = torch.stack([example["target_size"] for example in data])

    return {
        "images": images,
        "text_input_ids": text_input_ids,
        "text_input_ids_2": text_input_ids_2,
        "clip_images": clip_images,
        "drop_image_embeds": drop_image_embeds,
        "original_size": original_size,
        "crop_coords_top_left": crop_coords_top_left,
        "target_size": target_size,
    }
class IPAdapter(torch.nn.Module):
    """IP-Adapter

    Bundles a UNet with an image projection model and the adapter attention
    modules so they can be called (and prepared for training) as one module.

    Args:
        unet: Callable denoiser; its output must expose a ``.sample`` attribute.
        image_proj_model: Module mapping image embeddings to extra context
            tokens that are appended to the text encoder hidden states.
        adapter_modules: Module holding the adapter attention processors.
        ckpt_path: Optional checkpoint to load into ``image_proj_model`` and
            ``adapter_modules`` at construction time.
    """

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
        """Run the UNet on text states extended with projected image tokens.

        Returns:
            The ``.sample`` attribute of the UNet output.
        """
        ip_tokens = self.image_proj_model(image_embeds)
        # image tokens are appended along the sequence axis (dim=1)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(
            noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs
        ).sample
        return noise_pred

    @staticmethod
    def _parameter_checksum(module):
        """Return the sum of all parameter values of ``module`` as a 0-dim tensor."""
        # no_grad: the checksum is diagnostic only and must not build a graph
        with torch.no_grad():
            return torch.sum(torch.stack([torch.sum(p) for p in module.parameters()]))

    def load_from_checkpoint(self, ckpt_path: str):
        """Load ``image_proj`` and ``ip_adapter`` weights from ``ckpt_path``.

        The checkpoint must be a dict with the keys ``"image_proj"`` and
        ``"ip_adapter"``, each holding a state dict that matches the
        corresponding module exactly (``strict=True``).

        Raises:
            AssertionError: If the parameter checksum of either module is the
                same before and after loading, i.e. the weights did not change.
        """
        # Calculate original checksums
        orig_ip_proj_sum = self._parameter_checksum(self.image_proj_model)
        orig_adapter_sum = self._parameter_checksum(self.adapter_modules)

        # NOTE(security): torch.load may unpickle arbitrary objects depending
        # on the installed torch version and its `weights_only` default; only
        # load checkpoints from trusted sources.
        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = self._parameter_checksum(self.image_proj_model)
        new_adapter_sum = self._parameter_checksum(self.adapter_modules)

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args(input_args=None):
    """Parse the command-line options of the training script.

    Args:
        input_args: Optional list of argument strings. When ``None`` (the
            default) the arguments are taken from ``sys.argv`` as before.

    Returns:
        argparse.Namespace with the parsed options. ``local_rank`` is
        overridden by the ``LOCAL_RANK`` environment variable when that
        variable is set and differs from the parsed value.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_ip_adapter_path",
        type=str,
        default=None,
        help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default=None,
        required=True,
        help="Training data",
    )
    parser.add_argument(
        "--data_root_path",
        type=str,
        default="",
        required=True,
        help="Training data root path",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to CLIP image encoder",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=("The resolution for input images"),
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--noise_offset", type=float, default=None, help="noise offset")
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=2000,
        help=("Save a checkpoint of the training state every X updates"),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    # argparse falls back to sys.argv[1:] when input_args is None
    args = parser.parse_args(input_args)

    # the environment variable wins over the command-line flag
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
def _build_attn_processors(unet, num_tokens):
    """Create one attention processor per entry of ``unet.attn_processors``.

    Processors whose name ends with ``attn1.processor`` get a plain
    ``AttnProcessor``. Every other processor gets an ``IPAttnProcessor`` whose
    ``to_k_ip`` / ``to_v_ip`` weights are initialised from the UNet's own
    ``to_k`` / ``to_v`` weights of the same layer.

    Raises:
        ValueError: If a processor name starts with none of ``mid_block``,
            ``up_blocks`` or ``down_blocks``, so its hidden size is unknown.
    """
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # the block index is the single character right after the prefix
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            # without this branch a stale hidden_size from the previous
            # iteration would be silently reused
            raise ValueError(f"Unexpected attention processor name: {name}")
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens
            )
            attn_procs[name].load_state_dict(weights)
    return attn_procs


def main():
    """Train the image projection model and adapter attention modules.

    Loads the frozen base models, installs the adapter attention processors
    on the UNet, and runs a noise-prediction training loop in which only the
    image projection model and the adapter modules are optimised. The
    accelerator state is saved every ``--save_steps`` steps.
    """
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2"
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    image_encoder.requires_grad_(False)

    # ip-adapter
    num_tokens = 4
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=image_encoder.config.projection_dim,
        clip_extra_context_tokens=num_tokens,
    )
    # init adapter modules
    attn_procs = _build_attn_processors(unet, num_tokens)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    # unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device)  # use fp32
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    text_encoder_2.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)

    # optimizer: only the projection model and the adapter modules are trained
    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)

    # dataloader
    train_dataset = MyDataset(
        args.data_json_file,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        size=args.resolution,
        image_root_path=args.data_root_path,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Prepare everything with our `accelerator`.
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

    global_step = 0
    for epoch in range(0, args.num_train_epochs):
        begin = time.perf_counter()
        for step, batch in enumerate(train_dataloader):
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space
                with torch.no_grad():
                    # vae of sdxl should use fp32
                    latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae.dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor
                    latents = latents.to(accelerator.device, dtype=weight_dtype)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(
                        accelerator.device, dtype=weight_dtype
                    )

                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                with torch.no_grad():
                    image_embeds = image_encoder(
                        batch["clip_images"].to(accelerator.device, dtype=weight_dtype)
                    ).image_embeds
                # zero out the embedding of every sample flagged by the dataset
                image_embeds = torch.stack(
                    [
                        torch.zeros_like(image_embed) if drop_image_embed == 1 else image_embed
                        for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"])
                    ]
                )

                with torch.no_grad():
                    encoder_output = text_encoder(
                        batch["text_input_ids"].to(accelerator.device), output_hidden_states=True
                    )
                    text_embeds = encoder_output.hidden_states[-2]
                    encoder_output_2 = text_encoder_2(
                        batch["text_input_ids_2"].to(accelerator.device), output_hidden_states=True
                    )
                    pooled_text_embeds = encoder_output_2[0]
                    text_embeds_2 = encoder_output_2.hidden_states[-2]
                    text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1)  # concat

                # add cond
                add_time_ids = [
                    batch["original_size"].to(accelerator.device),
                    batch["crop_coords_top_left"].to(accelerator.device),
                    batch["target_size"].to(accelerator.device),
                ]
                add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)
                unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}

                noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds)

                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()

                # Backpropagate
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                if accelerator.is_main_process:
                    print(
                        "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
                            epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
                        )
                    )

            global_step += 1

            if global_step % args.save_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

            # restart the timer so the next data_time excludes this step's compute
            begin = time.perf_counter()


if __name__ == "__main__":
    main()