| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import inspect |
| from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import torchvision.transforms as T |
| from diffusers import AutoencoderKL, DiffusionPipeline, FlowMatchEulerDiscreteScheduler |
| from diffusers.image_processor import PipelineImageInput, VaeImageProcessor |
| from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin |
| from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput |
| from diffusers.utils import logging |
| from diffusers.utils.torch_utils import randn_tensor |
| from PIL import Image, ImageFilter |
| from transformers import AutoTokenizer, PreTrainedModel |
|
|
| from diffusers_local.z_image_control_transformer_2d import ZImageControlTransformer2DModel |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| def calculate_shift( |
| image_seq_len, |
| base_seq_len: int = 256, |
| max_seq_len: int = 4096, |
| base_shift: float = 0.5, |
| max_shift: float = 1.15, |
| ): |
| """ |
| Calculates the shift value `mu` for the scheduler based on the image sequence length. |
| |
| This function implements a linear interpolation to determine the shift value based on the input |
| image's sequence length, scaling between a base and a maximum shift value. |
| |
| Args: |
| image_seq_len (`int`): |
| The sequence length of the image latents (height * width). |
| base_seq_len (`int`, *optional*, defaults to 256): |
| The base sequence length for the shift calculation. |
| max_seq_len (`int`, *optional*, defaults to 4096): |
| The maximum sequence length for the shift calculation. |
| base_shift (`float`, *optional*, defaults to 0.5): |
| The shift value corresponding to `base_seq_len`. |
| max_shift (`float`, *optional*, defaults to 1.15): |
| The shift value corresponding to `max_seq_len`. |
| |
| Returns: |
| `float`: The calculated shift value `mu`. |
| """ |
| m = (max_shift - base_shift) / (max_seq_len - base_seq_len) |
| b = base_shift - m * base_seq_len |
| mu = image_seq_len * m + b |
| return mu |
|
|
|
|
| def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"): |
| """ |
| Retrieves latents from a VAE encoder output. |
| |
| Args: |
| encoder_output (`torch.Tensor`): |
| The output of a VAE encoder. |
| generator (`torch.Generator`, *optional*): |
| A random number generator for sampling from the latent distribution. |
| sample_mode (`str`, *optional*, defaults to "sample"): |
| The method to retrieve latents. Can be "sample" to sample from the distribution or |
| "argmax" to take the mode. |
| |
| Returns: |
| `torch.Tensor`: The retrieved latents. |
| """ |
| if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": |
| return encoder_output.latent_dist.sample(generator) |
| elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": |
| return encoder_output.latent_dist.mode() |
| elif hasattr(encoder_output, "latents"): |
| return encoder_output.latents |
| else: |
| raise AttributeError("Could not access latents of provided encoder_output") |
|
|
|
|
| def retrieve_timesteps( |
| scheduler, |
| num_inference_steps: Optional[int] = None, |
| device: Optional[Union[str, torch.device]] = None, |
| timesteps: Optional[List[int]] = None, |
| sigmas: Optional[List[float]] = None, |
| **kwargs, |
| ): |
| """ |
| Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles |
| custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. |
| |
| Args: |
| scheduler (`SchedulerMixin`): |
| The scheduler to get timesteps from. |
| num_inference_steps (`int`): |
| The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` |
| must be `None`. |
| device (`str` or `torch.device`, *optional*): |
| The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. |
| timesteps (`List[int]`, *optional*): |
| Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, |
| `num_inference_steps` and `sigmas` must be `None`. |
| sigmas (`List[float]`, *optional*): |
| Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, |
| `num_inference_steps` and `timesteps` must be `None`. |
| |
| Returns: |
| `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the |
| second element is the number of inference steps. |
| """ |
| if timesteps is not None and sigmas is not None: |
| raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") |
| if timesteps is not None: |
| accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
| if not accepts_timesteps: |
| raise ValueError( |
| f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
| f" timestep schedules. Please check whether you are using the correct scheduler." |
| ) |
| scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) |
| timesteps = scheduler.timesteps |
| num_inference_steps = len(timesteps) |
| elif sigmas is not None: |
| accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) |
| if not accept_sigmas: |
| raise ValueError( |
| f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" |
| f" sigmas schedules. Please check whether you are using the correct scheduler." |
| ) |
| scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) |
| timesteps = scheduler.timesteps |
| num_inference_steps = len(timesteps) |
| else: |
| scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) |
| timesteps = scheduler.timesteps |
| return timesteps, num_inference_steps |
|
|
|
|
| class ZImageControlUnifiedPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): |
| model_cpu_offload_seq = "text_encoder->vae->transformer" |
| _optional_components = [] |
| _callback_tensor_inputs = ["latents", "prompt_embeds"] |
|
|
| def __init__( |
| self, |
| scheduler: FlowMatchEulerDiscreteScheduler, |
| vae: AutoencoderKL, |
| text_encoder: PreTrainedModel, |
| tokenizer: AutoTokenizer, |
| transformer: ZImageControlTransformer2DModel, |
| ): |
| """ |
| Initializes the ZImageControlUnifiedPipeline. |
| |
| Args: |
| scheduler (`FlowMatchEulerDiscreteScheduler`): |
| A scheduler to be used in combination with `transformer` to denoise the latents. |
| vae (`AutoencoderKL`): |
| Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. |
| text_encoder (`PreTrainedModel`): |
| A pretrained text encoder model. |
| tokenizer (`AutoTokenizer`): |
| A tokenizer to prepare text prompts for the `text_encoder`. |
| transformer (`ZImageControlTransformer2DModel`): |
| The main transformer model for the diffusion process. |
| """ |
| super().__init__() |
| self.register_modules( |
| vae=vae, |
| text_encoder=text_encoder, |
| tokenizer=tokenizer, |
| scheduler=scheduler, |
| transformer=transformer, |
| ) |
| self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 |
| self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) |
| self.mask_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True) |
|
|
| def encode_prompt( |
| self, |
| prompt: Union[str, List[str]], |
| device: Optional[torch.device] = None, |
| num_images_per_prompt: int = 1, |
| do_classifier_free_guidance: bool = True, |
| negative_prompt: Optional[Union[str, List[str]]] = None, |
| prompt_embeds: Optional[List[torch.FloatTensor]] = None, |
| negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
| max_sequence_length: int = 512, |
| ): |
| """ |
| Encodes the prompt into text embeddings. |
| |
| Args: |
| prompt (`Union[str, List[str]]`): |
| The prompt or prompts to guide the image generation. |
| device (`Optional[torch.device]`): |
| The device to move the embeddings to. |
| num_images_per_prompt (`int`): |
| The number of images to generate per prompt. |
| do_classifier_free_guidance (`bool`): |
| Whether to generate embeddings for classifier-free guidance. |
| negative_prompt (`Optional[Union[str, List[str]]]`): |
| The negative prompt or prompts. |
| prompt_embeds (`Optional[List[torch.FloatTensor]]`): |
| Pre-generated positive prompt embeddings. |
| negative_prompt_embeds (`Optional[torch.FloatTensor]`): |
| Pre-generated negative prompt embeddings. |
| max_sequence_length (`int`): |
| The maximum sequence length for tokenization. |
| |
| Returns: |
| `Tuple[List[torch.Tensor], List[torch.Tensor]]`: A tuple containing the positive and negative prompt embeddings. |
| """ |
| device = device or self._execution_device |
| prompt = [prompt] if isinstance(prompt, str) else prompt |
|
|
| if prompt_embeds is not None: |
| pass |
| else: |
| prompt_embeds = self._encode_prompt( |
| prompt=prompt, |
| device=device, |
| max_sequence_length=max_sequence_length, |
| ) |
| if num_images_per_prompt > 1: |
| prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] |
|
|
| if do_classifier_free_guidance: |
| if negative_prompt_embeds is not None: |
| pass |
| else: |
| if negative_prompt is None: |
| negative_prompt = [""] * len(prompt) |
| else: |
| negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt |
| assert len(prompt) == len(negative_prompt) |
| negative_prompt_embeds = self._encode_prompt( |
| prompt=negative_prompt, |
| device=device, |
| max_sequence_length=max_sequence_length, |
| ) |
|
|
| if num_images_per_prompt > 1: |
| negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] |
|
|
| return prompt_embeds, negative_prompt_embeds |
|
|
| def _encode_prompt(self, prompt: Union[str, List[str]], device: torch.device, max_sequence_length: int) -> List[torch.Tensor]: |
| """ |
| Internal helper to encode a list of prompts into embeddings, applying chat templates if available. |
| |
| Args: |
| prompt (`Union[str, List[str]]`): |
| A list of strings to be encoded. |
| device (`torch.device`): |
| The target device for the embeddings. |
| max_sequence_length (`int`): |
| The maximum length for tokenization. |
| |
| Returns: |
| `List[torch.Tensor]`: A list of embedding tensors, one for each prompt. |
| """ |
| formatted_prompts = [] |
| for p in prompt: |
| messages = [{"role": "user", "content": p}] |
| if hasattr(self.tokenizer, "apply_chat_template"): |
| formatted_prompts.append(self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True)) |
| else: |
| formatted_prompts.append(p) |
|
|
| text_inputs = self.tokenizer( |
| formatted_prompts, |
| padding="max_length", |
| max_length=max_sequence_length, |
| truncation=True, |
| return_tensors="pt", |
| ).to(device) |
|
|
| prompt_masks = text_inputs.attention_mask.bool() |
|
|
| with torch.no_grad(): |
| prompt_embeds_batch = self.text_encoder(input_ids=text_inputs.input_ids, attention_mask=prompt_masks, output_hidden_states=True).hidden_states[-2] |
|
|
| embeddings_list = [] |
| for i in range(prompt_embeds_batch.shape[0]): |
| embeddings_list.append(prompt_embeds_batch[i][prompt_masks[i]]) |
|
|
| return embeddings_list |
|
|
| def get_timesteps(self, num_inference_steps, strength, device): |
| """ |
| Calculates the timesteps for the scheduler based on the number of inference steps and strength. |
| This is primarily used for image-to-image pipelines. |
| |
| Args: |
| num_inference_steps (`int`): The total number of diffusion steps. |
| strength (`float`): The strength of the denoising process. A value of 1.0 means full denoising. |
| device (`torch.device`): The device to place the timesteps on. |
| |
| Returns: |
| `Tuple[torch.Tensor, int]`: A tuple containing the timesteps and the number of steps to run. |
| """ |
| init_timestep = min(num_inference_steps * strength, num_inference_steps) |
|
|
| t_start = int(max(num_inference_steps - init_timestep, 0)) |
| timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] |
| if hasattr(self.scheduler, "set_begin_index"): |
| self.scheduler.set_begin_index(t_start * self.scheduler.order) |
|
|
| return timesteps, num_inference_steps - t_start |
|
|
| def prepare_latents( |
| self, |
| batch_size: int, |
| num_channels_latents: int, |
| height: int, |
| width: int, |
| dtype: torch.dtype, |
| device: torch.device, |
| generator: torch.Generator, |
| image: Optional[PipelineImageInput] = None, |
| timestep: Optional[torch.Tensor] = None, |
| latents: Optional[torch.Tensor] = None, |
| ): |
| """ |
| Prepares the initial latents for the diffusion process. |
| |
| This function handles three cases: |
| 1. `latents` are provided: They are returned directly. |
| 2. `image` is None (Text-to-Image): Random noise is generated. |
| 3. `image` is provided (Image-to-Image): The image is encoded, and noise is added according to the timestep. |
| |
| Args: |
| batch_size (`int`): The number of latents to generate. |
| num_channels_latents (`int`): The number of channels in the latents. |
| height (`int`): The height of the output image in pixels. |
| width (`int`): The width of the output image in pixels. |
| dtype (`torch.dtype`): The data type for the latents. |
| device (`torch.device`): The device to create the latents on. |
| generator (`torch.Generator`): A random generator for creating the initial noise. |
| image (`Optional[PipelineImageInput]`): An initial image for img2img mode. |
| timestep (`Optional[torch.Tensor]`): The starting timestep for adding noise in img2img mode. |
| latents (`Optional[torch.Tensor]`): Pre-generated latents. |
| |
| Returns: |
| `torch.Tensor`: The prepared latents. |
| """ |
| latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) |
| latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) |
| shape = (batch_size, num_channels_latents, latent_height, latent_width) |
|
|
| if latents is not None: |
| return latents.to(device=device, dtype=dtype) |
|
|
| if image is None: |
| latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
| return latents |
|
|
| image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype) |
| with torch.no_grad(): |
| if image_tensor.shape[1] != num_channels_latents: |
| if isinstance(generator, list): |
| image_latents = [retrieve_latents(self.vae.encode(image_tensor[i : i + 1]), generator=generator[i]) for i in range(image_tensor.shape[0])] |
| image_latents = torch.cat(image_latents, dim=0) |
| else: |
| image_latents = retrieve_latents(self.vae.encode(image_tensor), generator=generator) |
|
|
| image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor |
| image_latents = image_latents.to(dtype) |
|
|
| if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: |
| additional_image_per_prompt = batch_size // image_latents.shape[0] |
| image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) |
| elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: |
| raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.") |
|
|
| noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
| latents = self.scheduler.scale_noise(image_latents, timestep, noise) |
|
|
| return latents |
|
|
| def _prepare_image_latents( |
| self, |
| image: PipelineImageInput, |
| mask_image: PipelineImageInput, |
| width: int, |
| height: int, |
| batch_size: int, |
| num_images_per_prompt: int, |
| device: torch.device, |
| dtype: torch.dtype, |
| do_preprocess: bool = True, |
| ) -> torch.Tensor: |
| """ |
| Generic function to encode an image into 5D latents for inpainting context. |
| |
| If `do_preprocess` is True, it processes the image (PIL/np). |
| If `do_preprocess` is False, it assumes 'image' is already a ready-to-use tensor. |
| |
| Args: |
| image (`PipelineImageInput`): The input image. Can be None to return zeros. |
| width (`int`): The target width. |
| height (`int`): The target height. |
| batch_size (`int`): The prompt batch size. |
| num_images_per_prompt (`int`): The number of images per prompt. |
| device (`torch.device`): The target device. |
| dtype (`torch.dtype`): The target data type. |
| do_preprocess (`bool`): Whether to preprocess the image. |
| |
| Returns: |
| `torch.Tensor`: A 5D tensor of the encoded image latents. |
| """ |
| if image is None: |
| latent_h = height // self.vae_scale_factor |
| latent_w = width // self.vae_scale_factor |
| shape = (batch_size * num_images_per_prompt, self.transformer.in_channels, 1, latent_h, latent_w) |
| return torch.zeros(shape, device=device, dtype=dtype) |
|
|
| if do_preprocess: |
| image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype) |
| else: |
| image_tensor = image.to(device=device, dtype=self.vae.dtype) |
|
|
| if mask_image is not None: |
| mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width).to(device=device, dtype=self.vae.dtype) |
| |
| mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]) |
| |
| image_tensor = image_tensor * (mask_condition < 0.5) |
|
|
| with torch.no_grad(): |
| latents = retrieve_latents(self.vae.encode(image_tensor), sample_mode="argmax") |
| latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor |
|
|
| effective_batch_size = batch_size * num_images_per_prompt |
| if latents.shape[0] != effective_batch_size: |
| repeat_by = effective_batch_size // latents.shape[0] |
| latents = latents.repeat_interleave(repeat_by, dim=0) |
|
|
| return latents.to(dtype=dtype).unsqueeze(2) |
|
|
| def _prepare_mask_latents( |
| self, |
| mask_image: PipelineImageInput, |
| width: int, |
| height: int, |
| batch_size: int, |
| num_images_per_prompt: int, |
| reference_latents_shape: Tuple, |
| device: torch.device, |
| dtype: torch.dtype, |
| invert_mask: bool = False, |
| do_unsqueeze: bool = True, |
| ) -> torch.Tensor: |
| """ |
| Processes a MASK using the mask_processor, inverts it, resizes it, and formats it for the control_context. |
| |
| Args: |
| mask_image (`PipelineImageInput`): The mask image. Can be None to return zeros. |
| width (`int`): The target width. |
| height (`int`): The target height. |
| batch_size (`int`): The prompt batch size. |
| num_images_per_prompt (`int`): The number of images per prompt. |
| reference_latents_shape (`Tuple`): The shape of the inpainting latents for resizing. |
| device (`torch.device`): The target device. |
| dtype (`torch.dtype`): The target data type. |
| |
| Returns: |
| `torch.Tensor`: A 5D tensor of the processed mask latents. |
| """ |
| if mask_image is None: |
| placeholder_shape = ( |
| batch_size * num_images_per_prompt, |
| 1, |
| 1, |
| reference_latents_shape[-2], |
| reference_latents_shape[-1], |
| ) |
| return torch.zeros(placeholder_shape, device=device, dtype=dtype) |
|
|
| mask_tensor = self.mask_processor.preprocess(mask_image, height=height, width=width) |
| mask_tensor = mask_tensor.to(device=device, dtype=dtype) |
| |
| if invert_mask: |
| mask_tensor = 1.0 - mask_tensor |
|
|
| mask_latents = F.interpolate(mask_tensor, size=reference_latents_shape[-2:], mode="nearest") |
| |
| if do_unsqueeze: |
| mask_latents = mask_latents.unsqueeze(2) |
| |
| return mask_latents |
|
|
| def prepare_control_latents( |
| self, image: PipelineImageInput, width: int, height: int, batch_size: int, num_images_per_prompt: int, device: torch.device, dtype: torch.dtype |
| ) -> torch.Tensor: |
| """ |
| Preprocesses a control image, ENCODES it with the VAE to latent space, |
| and returns a 5D tensor ready for the transformer model. |
| |
| Args: |
| image (`PipelineImageInput`): The control image. Can be None to return zeros. |
| width (`int`): The target width. |
| height (`int`): The target height. |
| batch_size (`int`): The prompt batch size. |
| num_images_per_prompt (`int`): The number of images per prompt. |
| device (`torch.device`): The target device. |
| dtype (`torch.dtype`): The target data type. |
| |
| Returns: |
| `torch.Tensor`: A 5D tensor of the control image latents. |
| """ |
| if image is None: |
| latent_h = 2 * (int(height) // (self.vae_scale_factor * 2)) |
| latent_w = 2 * (int(width) // (self.vae_scale_factor * 2)) |
| return torch.zeros( |
| (batch_size * num_images_per_prompt, self.transformer.in_channels, 1, latent_h, latent_w), |
| device=device, |
| dtype=dtype, |
| ) |
|
|
| image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype) |
| with torch.no_grad(): |
| latents = retrieve_latents(self.vae.encode(image_tensor), sample_mode="argmax") |
| latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor |
|
|
| effective_batch_size = batch_size * num_images_per_prompt |
| if latents.shape[0] < effective_batch_size: |
| latents = latents.repeat_interleave(effective_batch_size // latents.shape[0], dim=0) |
|
|
| return latents.to(dtype=dtype).unsqueeze(2) |
|
|
| def _expand_and_feather_mask(self, mask_image, expand_pixels=10, feather_radius=8, is_inpaint_mode=True): |
| """ |
| Expands the white area of a mask using PyTorch for performance and then smooths its edges with Pillow. |
| |
| Args: |
| mask_image (PIL.Image.Image | np.ndarray | torch.Tensor): The input mask. |
| expand_pixels (int): How many pixels to expand the white area. |
| feather_radius (int): The radius of the Gaussian blur for the gradient. |
| is_inpaint_mode (bool): Flag to enable/disable the operation. |
| |
| Returns: |
| PIL.Image.Image | np.ndarray | torch.Tensor: The processed mask, in the same format as the input. |
| """ |
| if not is_inpaint_mode or (expand_pixels <= 0 and feather_radius <= 0): |
| return mask_image |
|
|
| |
| input_type = type(mask_image) |
| |
| if isinstance(mask_image, Image.Image): |
| |
| mask_tensor = T.ToTensor()(mask_image.convert("L")) |
| elif isinstance(mask_image, np.ndarray): |
| |
| mask_tensor = torch.from_numpy(mask_image).permute(2, 0, 1) if mask_image.ndim == 3 else torch.from_numpy(mask_image).unsqueeze(0) |
| elif isinstance(mask_image, torch.Tensor): |
| mask_tensor = mask_image |
| else: |
| raise TypeError(f"Unsupported mask type: {input_type}") |
|
|
| |
| mask_tensor = mask_tensor.to(device=self.device, dtype=torch.float32) |
| if mask_tensor.ndim == 3: |
| mask_tensor = mask_tensor.unsqueeze(0) |
|
|
| |
| if expand_pixels > 0: |
| kernel_size = expand_pixels * 2 + 1 |
| padding = expand_pixels |
|
|
| |
| mask_tensor = F.max_pool2d( |
| mask_tensor, |
| kernel_size=kernel_size, |
| stride=1, |
| padding=padding |
| ) |
|
|
| |
| |
| to_pil = T.ToPILImage() |
| mask_pil = to_pil(mask_tensor.squeeze(0).cpu()) |
|
|
| |
| if feather_radius > 0: |
| mask_pil = mask_pil.filter(ImageFilter.GaussianBlur(radius=feather_radius)) |
|
|
| |
| if input_type is torch.Tensor: |
| |
| return T.ToTensor()(mask_pil).to(device=self.device, dtype=mask_image.dtype) |
| elif input_type is np.ndarray: |
| |
| return np.array(mask_pil) |
| else: |
| return mask_pil |
|
|
| def _apply_mask_blur(self, mask_image, mask_blur_radius, is_inpaint_mode): |
| """ |
| Apply Gaussian blur to a mask image for inpainting operations. |
| Args: |
| mask_image (Image.Image | np.ndarray | torch.Tensor): The mask image to be blurred. |
| Can be provided as a PIL Image, NumPy array, or PyTorch tensor. |
| mask_blur_radius (float): The radius of the Gaussian blur filter in pixels. |
| Only applied if is_inpaint_mode is True and mask_blur_radius > 0. |
| is_inpaint_mode (bool): Flag indicating whether the pipeline is in inpainting mode. |
| Blur is only applied when this is True. |
| Returns: |
| Image.Image | np.ndarray | torch.Tensor: The mask image with Gaussian blur applied |
| if is_inpaint_mode is True and mask_blur_radius > 0. Otherwise, returns the |
| original mask_image unchanged. The return type matches the input type. |
| """ |
| mask_to_use = mask_image |
| if is_inpaint_mode and mask_blur_radius > 0: |
| if isinstance(mask_image, Image.Image): |
| mask_pil = mask_image |
| elif isinstance(mask_image, np.ndarray): |
| mask_pil = Image.fromarray(mask_image) |
| elif isinstance(mask_image, torch.Tensor): |
| mask_pil = Image.fromarray(mask_image.cpu().numpy().astype(np.uint8)) |
| else: |
| mask_pil = mask_image |
|
|
| mask_to_use = mask_pil.filter(ImageFilter.GaussianBlur(radius=mask_blur_radius)) |
| return mask_to_use |
|
|
| @property |
| def guidance_scale(self): |
| return self._guidance_scale |
|
|
| @property |
| def do_classifier_free_guidance(self): |
| return self._guidance_scale > 1 |
|
|
| @property |
| def joint_attention_kwargs(self): |
| return self._joint_attention_kwargs |
|
|
| @property |
| def num_timesteps(self): |
| return self._num_timesteps |
|
|
| @property |
| def interrupt(self): |
| return self._interrupt |
|
|
| def __call__( |
| self, |
| prompt: Union[str, List[str]], |
| image: Optional[PipelineImageInput] = None, |
| mask_image: Optional[PipelineImageInput] = None, |
| inpaint_mode: Literal["default", "diff", "diff+inpaint"] = "default", |
| mask_blur_radius: float=8.0, |
| control_image: Optional[PipelineImageInput] = None, |
| height: Optional[int] = None, |
| width: Optional[int] = None, |
| num_inference_steps: int = 20, |
| sigmas: Optional[List[float]] = None, |
| strength: float = 1.0, |
| guidance_scale: float = 4.0, |
| cfg_normalization: bool = False, |
| cfg_truncation: float = 1.0, |
| negative_prompt: Optional[Union[str, List[str]]] = None, |
| num_images_per_prompt: int = 1, |
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| latents: Optional[torch.Tensor] = None, |
| prompt_embeds: Optional[List[torch.FloatTensor]] = None, |
| negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, |
| controlnet_conditioning_scale: float = 1.0, |
| controlnet_refiner_conditioning_scale: float = 1.0, |
| output_type: str = "pil", |
| return_dict: bool = True, |
| joint_attention_kwargs: Optional[Dict[str, Any]] = None, |
| callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
| callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
| max_sequence_length: int = 512, |
| ): |
| r""" |
| The main entry point for the Z-Image unified pipeline for generation. |
| |
| Args: |
| prompt (`str` or `List[str]`, *optional*): |
| The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. |
| image (`PipelineImageInput`, *optional*): |
| The initial image for image-to-image or inpainting modes. |
| mask_image (`PipelineImageInput`, *optional*): |
| The mask image for inpainting. White areas are preserved, black areas are inpainted. |
| inpaint_mode (`str`, *optional*, defaults to `"default"`): |
| The inpainting mode. Can be "default", "diff", or "diff+inpaint". Determines how the inpainting |
| process is handled. |
| mask_blur_radius (`float`, *optional*, defaults to 8.0): |
| The radius for blurring the edges of the inpainting mask to create a smoother transition. |
| control_image (`PipelineImageInput`, *optional*): |
| The conditioning image for control modes (e.g., Canny, depth). |
| height (`int`, *optional*, defaults to 1024): |
| The height in pixels of the generated image. |
| width (`int`, *optional*, defaults to 1024): |
| The width in pixels of the generated image. |
| num_inference_steps (`int`, *optional*, defaults to 20): |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| expense of slower inference. |
| sigmas (`List[float]`, *optional*): |
| Custom sigmas to use for the denoising process. If not defined, the scheduler's default behavior |
| will be used. |
| strength (`float`, *optional*, defaults to 1.0): |
| Denoising strength for image-to-image. A value of 1.0 means the initial image is fully replaced, |
| while a lower value preserves more of the original image structure. Only used in img2img mode. |
| guidance_scale (`float`, *optional*, defaults to 4.0): |
| The scale for classifier-free guidance. A value > 1 enables it. Higher values encourage images |
| closer to the prompt, potentially at the cost of quality. |
| cfg_normalization (`bool`, *optional*, defaults to False): |
| Whether to apply normalization to the guidance, which can prevent oversaturation. |
| cfg_truncation (`float`, *optional*, defaults to 1.0): |
| A value between 0.0 and 1.0 that disables CFG for the final portion of the denoising steps, |
| specified as a fraction of total steps. For example, 0.8 disables CFG for the last 20% of steps. |
| negative_prompt (`str` or `List[str]`, *optional*): |
| The prompt or prompts not to guide the image generation. |
| num_images_per_prompt (`int`, *optional*, defaults to 1): |
| The number of images to generate per prompt. |
| generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
| A torch generator to make generation deterministic. |
| latents (`torch.FloatTensor`, *optional*): |
| Pre-generated noisy latents to be used as inputs for image generation. |
| prompt_embeds (`List[torch.FloatTensor]`, *optional*): |
| Pre-generated positive text embeddings. |
| negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): |
| Pre-generated negative text embeddings. |
| controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): |
| The scale of the control conditioning influence. |
| controlnet_refiner_conditioning_scale (`float`, *optional*, defaults to 1.0): |
| The scale of the control refiner conditioning influence. |
| output_type (`str`, *optional*, defaults to `"pil"`): |
| The output format of the generated image. Choose between "pil" (`PIL.Image.Image`), "np.array", or "latent". |
| return_dict (`bool`, *optional*, defaults to `True`): |
| Whether to return a `ZImagePipelineOutput` instead of a plain tuple. |
| joint_attention_kwargs (`dict`, *optional*): |
| A kwargs dictionary for the `AttentionProcessor`. |
| callback_on_step_end (`Callable`, *optional*): |
| A function that is called at the end of each denoising step. |
| callback_on_step_end_tensor_inputs (`List`, *optional*): |
| The list of tensor inputs for the `callback_on_step_end` function. |
| max_sequence_length (`int`, *optional*, defaults to 512): |
| Maximum sequence length to use with the `prompt`. |
| |
| Examples: |
| |
| Returns: |
| [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: |
| If `return_dict` is True, a `ZImagePipelineOutput` is returned, otherwise a `tuple` with the generated images. |
| """ |
| self._guidance_scale = guidance_scale |
| self._joint_attention_kwargs = joint_attention_kwargs |
| self._interrupt = False |
| self._cfg_normalization = cfg_normalization |
| self._cfg_truncation = cfg_truncation |
| is_two_stage_control_model = self.transformer.control_in_dim > self.transformer.in_channels if hasattr(self.transformer, "control_in_dim") else False |
| device = self._execution_device |
| dtype = self.transformer.dtype |
| vae_scale = self.vae_scale_factor * 2 |
| has_inpaint_inputs = image is not None and mask_image is not None |
| is_inpaint_control_mode = has_inpaint_inputs and inpaint_mode in ["default", "diff+inpaint"] |
| is_diff_mode = has_inpaint_inputs and inpaint_mode in ["diff", "diff+inpaint"] |
| is_img2img_mode = image is not None and not has_inpaint_inputs |
| |
| ref_image = control_image or image |
| image_height = None |
| image_width = None |
| if ref_image is not None: |
| if isinstance(ref_image, Image.Image): |
| image_height, image_width = ref_image.height, ref_image.width |
| else: |
| image_height, image_width = ref_image.shape[-2], ref_image.shape[-1] |
|
|
| height = height or image_height or 1024 |
| width = width or image_width or 1024 |
|
|
| if height % vae_scale != 0 or width % vae_scale != 0: |
| raise ValueError(f"Height/width must be divisible by {vae_scale}.") |
|
|
| batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt else len(prompt_embeds) |
| effective_batch_size = batch_size * num_images_per_prompt |
|
|
| if prompt_embeds is not None and prompt is None: |
| if self.do_classifier_free_guidance and negative_prompt_embeds is None: |
| raise ValueError( |
| "When `prompt_embeds` is provided without `prompt`, `negative_prompt_embeds` must also be provided for classifier-free guidance." |
| ) |
| else: |
| ( |
| prompt_embeds, |
| negative_prompt_embeds, |
| ) = self.encode_prompt( |
| prompt=prompt, |
| num_images_per_prompt=num_images_per_prompt, |
| negative_prompt=negative_prompt, |
| do_classifier_free_guidance=self.do_classifier_free_guidance, |
| prompt_embeds=prompt_embeds, |
| negative_prompt_embeds=negative_prompt_embeds, |
| device=device, |
| max_sequence_length=max_sequence_length, |
| ) |
|
|
| if self.do_classifier_free_guidance: |
| prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds |
| else: |
| prompt_embeds_model_input = prompt_embeds |
| |
| if control_image is not None or is_inpaint_control_mode: |
| control_latents = self.prepare_control_latents(control_image, width, height, batch_size, num_images_per_prompt, device, dtype) |
|
|
| if is_two_stage_control_model: |
| image_for_inpaint = None if is_diff_mode and not is_inpaint_control_mode else image |
| mask_for_inpaint = None if is_diff_mode and not is_inpaint_control_mode else mask_image |
| |
| if is_inpaint_control_mode: |
| mask_for_inpaint = self._apply_mask_blur(mask_for_inpaint, mask_blur_radius, True) |
|
|
| inpaint_latents = self._prepare_image_latents( |
| image_for_inpaint, mask_for_inpaint, width, height, batch_size, num_images_per_prompt, device, dtype |
| ) |
| |
| mask_latents = self._prepare_mask_latents( |
| mask_for_inpaint, |
| width, |
| height, |
| batch_size, |
| num_images_per_prompt, |
| inpaint_latents.shape, |
| device, |
| dtype, |
| invert_mask=is_inpaint_control_mode, |
| do_unsqueeze=True, |
| ) |
| control_context = torch.cat([control_latents, mask_latents, inpaint_latents], dim=1) |
| else: |
| control_context = control_latents |
| else: |
| control_context = None |
|
|
| if self.do_classifier_free_guidance: |
| control_context_model_input = control_context.repeat(2, 1, 1, 1, 1) |
| else: |
| control_context_model_input = control_context |
|
|
| image_seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) |
| mu = calculate_shift(image_seq_len) |
| self.scheduler.sigma_min = 0.0 |
| timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas, mu=mu) |
| self._num_timesteps = len(timesteps) |
|
|
| if is_img2img_mode: |
| strength = min(strength, 1.0) |
| else: |
| strength = 1.0 |
|
|
| if strength < 1.0: |
| init_timestep = min(int(num_inference_steps * strength), num_inference_steps) |
| t_start = max(num_inference_steps - init_timestep, 0) |
| timesteps = timesteps[t_start * self.scheduler.order :] |
| num_steps_to_run = len(timesteps) // self.scheduler.order |
| else: |
| num_steps_to_run = num_inference_steps |
|
|
| latent_timestep = timesteps[:1].repeat(effective_batch_size) if strength < 1.0 else None |
|
|
| use_image_for_latents = is_img2img_mode |
| |
| latents = self.prepare_latents( |
| effective_batch_size, |
| self.transformer.in_channels, |
| height, |
| width, |
| torch.float32, |
| device, |
| generator, |
| image=image if use_image_for_latents else None, |
| timestep=latent_timestep if use_image_for_latents else None, |
| latents=latents, |
| ) |
| |
| if is_diff_mode: |
| original_image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype) |
| with torch.no_grad(): |
| original_clean_latents = retrieve_latents(self.vae.encode(original_image_tensor), sample_mode="argmax") |
| original_clean_latents = (original_clean_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor |
| original_clean_latents = original_clean_latents.to(dtype) |
| |
| noise = randn_tensor(original_clean_latents.shape, generator=generator, device=device, dtype=dtype) |
| latents_list = [] |
| step_indices = [(self.scheduler.timesteps == t).nonzero().item() for t in timesteps] |
| for i in step_indices: |
| sigma = self.scheduler.sigmas[i] |
| noisy_latent = (1.0 - sigma) * original_clean_latents + sigma * noise |
| latents_list.append(noisy_latent) |
| |
| original_latents_trajectory = torch.cat(latents_list, dim=0) |
| blurred_mask_image = self._apply_mask_blur(mask_image, mask_blur_radius, True) |
| map_processed = self._prepare_mask_latents( |
| blurred_mask_image, |
| width, |
| height, |
| batch_size, |
| num_images_per_prompt, |
| latents.shape, |
| device, |
| dtype, |
| invert_mask=True, |
| do_unsqueeze=False, |
| ) |
| |
| thresholds = torch.arange(len(timesteps), device=device, dtype=dtype) / len(timesteps) |
| thresholds = thresholds.view(-1, 1, 1, 1) |
| time_masks = map_processed > thresholds |
| |
| num_warmup_steps = len(timesteps) - num_steps_to_run * self.scheduler.order |
| with torch.inference_mode(): |
| with self.progress_bar(total=num_steps_to_run) as progress_bar: |
| for i, t in enumerate(timesteps): |
| if self.interrupt: |
| continue |
| |
| if is_diff_mode: |
| if i == 0: |
| latents = original_latents_trajectory[:1] |
| else: |
| current_mask = time_masks[i].to(latents.dtype) |
| current_original_latent = original_latents_trajectory[i:i+1] |
| |
| if current_mask.ndim == 3: |
| current_mask = current_mask.unsqueeze(1) |
| |
| latents = current_original_latent * current_mask + latents * (1 - current_mask) |
| |
| timestep = t.expand(latents.shape[0]) |
| timestep = (1000 - timestep) / 1000 |
| |
| t_norm = timestep[0].item() |
| current_guidance_scale = self.guidance_scale |
| if self.do_classifier_free_guidance and self._cfg_truncation is not None and float(self._cfg_truncation) <= 1: |
| if t_norm > self._cfg_truncation: |
| current_guidance_scale = 0.0 |
| apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 |
|
|
| if apply_cfg: |
| latent_model_input = latents.repeat(2, 1, 1, 1) |
| timestep_model_input = timestep.repeat(2) |
| else: |
| latent_model_input = latents |
| timestep_model_input = timestep |
|
|
| latent_model_input = latent_model_input.to(self.transformer.dtype) |
| latent_model_input = latent_model_input.unsqueeze(2) |
| latent_model_input_list = list(latent_model_input.unbind(dim=0)) |
|
|
| model_out_list = self.transformer( |
| x=latent_model_input_list, |
| t=timestep_model_input, |
| cap_feats=prompt_embeds_model_input, |
| control_context=control_context_model_input, |
| conditioning_scale=controlnet_conditioning_scale, |
| refiner_conditioning_scale=controlnet_refiner_conditioning_scale, |
| )[0] |
|
|
| if apply_cfg: |
| pos_out = model_out_list[:effective_batch_size] |
| neg_out = model_out_list[effective_batch_size:] |
|
|
| noise_pred = [] |
| for j in range(effective_batch_size): |
| pos = pos_out[j].float() |
| neg = neg_out[j].float() |
|
|
| pred = pos + current_guidance_scale * (pos - neg) |
|
|
| if self._cfg_normalization and float(self._cfg_normalization) > 0.0: |
| ori_pos_norm = torch.linalg.vector_norm(pos) |
| new_pos_norm = torch.linalg.vector_norm(pred) |
| max_new_norm = ori_pos_norm * float(self._cfg_normalization) |
| if new_pos_norm > max_new_norm: |
| pred = pred * (max_new_norm / new_pos_norm) |
|
|
| noise_pred.append(pred) |
|
|
| noise_pred = torch.stack(noise_pred, dim=0) |
| else: |
| noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) |
|
|
| noise_pred = noise_pred.squeeze(2) |
| noise_pred = -noise_pred |
|
|
| latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents).prev_sample |
|
|
| if callback_on_step_end is not None: |
| callback_kwargs = {} |
| for k in callback_on_step_end_tensor_inputs: |
| callback_kwargs[k] = locals()[k] |
| callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) |
|
|
| if isinstance(callback_outputs, dict): |
| latents = callback_outputs.pop("latents", latents) |
| prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) |
| negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) |
|
|
| if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
| progress_bar.update() |
|
|
| if output_type != "latent": |
| latents = latents.to(self.vae.dtype) |
| latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor |
| with torch.no_grad(): |
| image = self.vae.decode(latents, return_dict=False)[0] |
| image = self.image_processor.postprocess(image, output_type=output_type) |
| else: |
| image = latents |
|
|
| self.maybe_free_model_hooks() |
|
|
| if not return_dict: |
| return (image,) |
|
|
| return ZImagePipelineOutput(images=image) |
|
|