| import inspect |
| from typing import Any, Callable, Dict, List, Optional, Union |
|
|
| import numpy as np |
| import torch |
| from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast |
|
|
| from diffusers.image_processor import (VaeImageProcessor) |
| from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin |
| from diffusers.models.autoencoders import AutoencoderKL |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
| from diffusers.utils import ( |
| USE_PEFT_BACKEND, |
| is_torch_xla_available, |
| logging, |
| scale_lora_layers, |
| unscale_lora_layers, |
| ) |
| from diffusers.utils.torch_utils import randn_tensor |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
| from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput |
| from torchvision.transforms.functional import pad |
| from diffusers import FluxTransformer2DModel |
|
|
# torch_xla is an optional dependency: import it only when diffusers reports it
# as available, and remember the outcome so the denoising loop knows whether it
# has to call ``xm.mark_step()``.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm


# Module-level logger obtained through the diffusers logging helper.
logger = logging.get_logger(__name__)
|
|
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly map an image sequence length to a shift value.

    The mapping is the straight line through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Sequence lengths outside that interval are
    extrapolated along the same line; nothing is clamped.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length that maps to ``base_shift``.
        max_seq_len: Sequence length that maps to ``max_shift``.
        base_shift: Shift returned for ``base_seq_len``.
        max_shift: Shift returned for ``max_seq_len``.

    Returns:
        The interpolated (or extrapolated) shift value.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
|
def prepare_latent_image_ids_(height, width, device, dtype):
    """Build an un-flattened grid of positional ids.

    Args:
        height: Grid height before halving; the grid has ``height // 2`` rows.
        width: Grid width before halving; the grid has ``width // 2`` columns.
        device: Device the tensors are created on.
        dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``(height // 2, width // 2, 3)`` where channel 0 is
        zero, channel 1 holds the row index and channel 2 the column index.
    """
    rows, cols = height // 2, width // 2
    ids = torch.zeros(rows, cols, 3, device=device, dtype=dtype)
    # Broadcast the row index down each column and the column index along each row.
    ids[..., 1] = torch.arange(rows, device=device).unsqueeze(1)
    ids[..., 2] = torch.arange(cols, device=device).unsqueeze(0)
    return ids
|
|
def prepare_latent_subject_ids(height, width, device, dtype):
    """Build a flattened table of positional ids for a subject grid.

    Args:
        height: Grid height before halving; the grid has ``height // 2`` rows.
        width: Grid width before halving; the grid has ``width // 2`` columns.
        device: Device the tensors are created on.
        dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``((height // 2) * (width // 2), 3)`` in row-major
        order. Column 0 is zero, column 1 is the row index and column 2 the
        column index of each grid cell.
    """
    rows, cols = height // 2, width // 2
    grid = torch.zeros(rows, cols, 3, device=device, dtype=dtype)
    grid[..., 1] = torch.arange(rows, device=device).unsqueeze(1)
    grid[..., 2] = torch.arange(cols, device=device).unsqueeze(0)
    flat = grid.reshape(rows * cols, 3)
    return flat.to(device=device, dtype=dtype)
|
|
def resize_position_encoding(batch_size, original_height, original_width, target_height, target_width, device, dtype):
    """Build positional ids for an original grid and for a rescaled target grid.

    The target grid has ``(target_height // 2) * (target_width // 2)`` cells,
    but its row/column coordinates are multiplied by
    ``original_height / target_height`` and ``original_width / target_width``
    so that they are expressed in the coordinate range of the original grid.

    Args:
        batch_size: Unused; kept for interface compatibility.
        original_height: Height of the original grid before halving.
        original_width: Width of the original grid before halving.
        target_height: Height of the target grid before halving.
        target_width: Width of the target grid before halving.
        device: Device the tensors are created on.
        dtype: Dtype of the returned tensors.

    Returns:
        ``(original_ids, target_ids)``: two tensors with 3 columns each
        (zero, row coordinate, column coordinate), flattened in row-major order.
    """

    def _flat_ids(rows, cols, row_scale=None, col_scale=None):
        # Integer cell indices; only multiplied by a factor when one is supplied.
        row_pos = torch.arange(rows, device=device).unsqueeze(1)
        col_pos = torch.arange(cols, device=device).unsqueeze(0)
        if row_scale is not None:
            row_pos = row_pos * row_scale
        if col_scale is not None:
            col_pos = col_pos * col_scale
        ids = torch.zeros(rows, cols, 3, device=device, dtype=dtype)
        ids[..., 1] = row_pos
        ids[..., 2] = col_pos
        return ids.reshape(rows * cols, 3)

    original_ids = _flat_ids(original_height // 2, original_width // 2)

    height_ratio = original_height / target_height
    width_ratio = original_width / target_width
    target_ids = _flat_ids(target_height // 2, target_width // 2, height_ratio, width_ratio)

    return original_ids, target_ids
| |
| |
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from an encoder output object.

    Args:
        encoder_output: Object exposing either a ``latent_dist`` attribute
            (with ``sample(generator)`` and ``mode()``) or a ``latents``
            attribute.
        generator: Passed to ``latent_dist.sample`` when sampling.
        sample_mode: ``"sample"`` draws from ``latent_dist``; ``"argmax"``
            returns ``latent_dist.mode()``. Any other value falls through to
            the ``latents`` attribute.

    Returns:
        The sampled, mode, or pre-computed latents.

    Raises:
        AttributeError: If none of the access paths above applies.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")

    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
|
|
|
|
| |
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure the scheduler's timesteps and return them.

    Calls ``scheduler.set_timesteps`` with exactly one of a custom ``timesteps``
    list, a custom ``sigmas`` list, or ``num_inference_steps``, then reads the
    schedule back from ``scheduler.timesteps``.

    Args:
        scheduler: Object with a ``set_timesteps`` method and a ``timesteps``
            attribute.
        num_inference_steps: Step count, used only when neither ``timesteps``
            nor ``sigmas`` is given.
        device: Forwarded to ``set_timesteps``.
        timesteps: Custom timestep schedule.
        sigmas: Custom sigma schedule.
        **kwargs: Forwarded to ``set_timesteps``.

    Returns:
        ``(timesteps, num_inference_steps)``. With a custom schedule the step
        count is the length of the resulting schedule.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are given, or the
            scheduler's ``set_timesteps`` does not accept the one supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default path: let the scheduler derive the schedule from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    supported_params = inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if "timesteps" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in supported_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
|
|
|
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
    """Flux text-to-image pipeline that also accepts subject and spatial condition images.

    Besides the usual prompt inputs, ``__call__`` takes ``subject_images`` and
    ``spatial_images``. These are VAE-encoded, packed into token sequences and
    fed once through the transformer before denoising starts; the denoising
    loop itself only passes the noise latents together with positional ids
    covering noise and condition tokens.
    """

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
    ):
        """Register the sub-models and derive the size constants used by the pipeline."""
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = 64
        # `prepare_latents` reads `self.cond_size`; give it the same default as
        # `__call__` so the method also works before the first call.
        self.cond_size = 512

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Encode ``prompt`` with the second tokenizer / text encoder.

        Args:
            prompt: A prompt or a list of prompts.
            num_images_per_prompt: How many times each prompt's embedding is repeated.
            max_sequence_length: Token length the prompt is padded/truncated to.
            device: Target device; defaults to the pipeline's execution device.
            dtype: Accepted for interface compatibility only; the embeddings
                are always cast to ``self.text_encoder_2.dtype``.

        Returns:
            Tensor of shape ``(batch_size * num_images_per_prompt, seq_len, dim)``.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            # Report what was cut at `max_sequence_length` (the limit applied
            # above), not at the CLIP tokenizer's length.
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # Duplicate the embeddings once per requested image.
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        """Encode ``prompt`` with the first tokenizer / text encoder and return the pooled output.

        Args:
            prompt: A prompt or a list of prompts.
            num_images_per_prompt: How many times each prompt's embedding is repeated.
            device: Target device; defaults to the pipeline's execution device.

        Returns:
            Tensor of shape ``(batch_size * num_images_per_prompt, dim)``.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Only the pooled output is used.
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # Duplicate the embeddings once per requested image.
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        """Produce the text conditioning tensors for the transformer.

        Args:
            prompt: Prompt(s) for the first text encoder.
            prompt_2: Prompt(s) for the second text encoder; falls back to ``prompt`` when falsy.
            device: Target device; defaults to the pipeline's execution device.
            num_images_per_prompt: How many times each embedding is repeated.
            prompt_embeds: Pre-computed sequence embeddings. When given, no
                encoding is performed and both embedding arguments are returned as passed.
            pooled_prompt_embeds: Pre-computed pooled embeddings.
            max_sequence_length: Token length used by the second tokenizer.
            lora_scale: Scale applied to the text encoders' LoRA layers while
                encoding (only with the PEFT backend).

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds, text_ids)`` where ``text_ids``
            is a zero tensor of shape ``(prompt_embeds.shape[1], 3)``.
        """
        device = device or self._execution_device

        # Remember the LoRA scale and apply it to the text encoders for the
        # duration of the encoding.
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        # Undo the LoRA scaling applied above.
        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        """Encode ``image`` with the VAE and normalise the latents.

        With a list of generators, each batch element is encoded separately
        with its own generator. The result is shifted by
        ``vae.config.shift_factor`` and multiplied by ``vae.config.scaling_factor``.
        """
        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        return image_latents

    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        """Validate the user-facing arguments of ``__call__``.

        ``callback_on_step_end_tensor_inputs`` is accepted but not validated here.

        Raises:
            ValueError: On an unsupported size, conflicting or missing
                prompt/embedding arguments, wrong prompt types, or a
                ``max_sequence_length`` above 512.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        """Return flattened positional ids of shape ``((height // 2) * (width // 2), 3)``.

        Column 1 holds the row index and column 2 the column index; column 0 is
        zero. ``batch_size`` is unused.
        """
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        """Fold 2x2 spatial patches into the channel dimension.

        Turns ``(batch, C, height, width)`` into
        ``(batch, (height // 2) * (width // 2), C * 4)``.
        """
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        """Inverse of ``_pack_latents``.

        ``height`` and ``width`` are divided by ``vae_scale_factor`` to get the
        patch grid; the result has shape
        ``(batch, channels // 4, 2 * grid_height, 2 * grid_width)``.
        """
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)

        return latents

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        subject_image,
        condition_image,
        latents=None,
        cond_number=1,
        sub_number=1,
    ):
        """Create the noise latents, the condition latents and the joint positional ids.

        Args:
            batch_size: Effective batch size of the noise latents.
            num_channels_latents: Latent channels before packing.
            height: Output image height in pixels.
            width: Output image width in pixels.
            dtype: Dtype of the created tensors.
            device: Device of the created tensors.
            generator: Random generator(s) for the noise and for VAE sampling.
            subject_image: Pre-processed subject images stacked along the
                height axis, or ``None``.
            condition_image: Pre-processed spatial condition images stacked
                along the height axis, or ``None``.
            latents: Optional starting noise. Must already be packed, i.e. have
                the shape produced by ``_pack_latents``. When ``None``, fresh
                noise is drawn.
            cond_number: Number of stacked spatial condition images.
            sub_number: Number of stacked subject images.

        Returns:
            ``(cond_latents, latent_image_ids, noise_latents)``. ``cond_latents``
            is the concatenation of the packed subject and spatial latents
            along the token axis, or ``None`` when neither kind of condition
            image was supplied. ``latent_image_ids`` holds the ids of the noise
            tokens followed by those of the subject and spatial tokens.
        """
        height_cond = 2 * (self.cond_size // self.vae_scale_factor)
        width_cond = 2 * (self.cond_size // self.vae_scale_factor)
        height = 2 * (int(height) // self.vae_scale_factor)
        width = 2 * (int(width) // self.vae_scale_factor)

        if latents is not None:
            # Honour caller-supplied starting noise instead of silently discarding it.
            noise_latents = latents.to(device=device, dtype=dtype)
        else:
            shape = (batch_size, num_channels_latents, height, width)
            noise_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            noise_latents = self._pack_latents(noise_latents, batch_size, num_channels_latents, height, width)

        noise_latent_image_ids, cond_latent_image_ids = resize_position_encoding(
            batch_size,
            height,
            width,
            height_cond,
            width_cond,
            device,
            dtype,
        )

        latents_to_concat = []
        latents_ids_to_concat = [noise_latent_image_ids]

        if subject_image is not None:
            subject_image = subject_image.to(device=device, dtype=dtype)
            subject_image_latents = self._encode_vae_image(image=subject_image, generator=generator)
            subject_latents = self._pack_latents(
                subject_image_latents, batch_size, num_channels_latents, height_cond * sub_number, width_cond
            )
            latent_subject_ids = prepare_latent_subject_ids(height_cond, width_cond, device, dtype)
            # Fixed offset on the row coordinate of the subject ids.
            # NOTE(review): presumably keeps subject positions apart from the
            # noise grid - confirm the intended value.
            latent_subject_ids[:, 1] += 64
            # Every subject image reuses the same id table.
            subject_latent_image_ids = torch.concat([latent_subject_ids for _ in range(sub_number)], dim=-2)
            latents_to_concat.append(subject_latents)
            latents_ids_to_concat.append(subject_latent_image_ids)

        if condition_image is not None:
            condition_image = condition_image.to(device=device, dtype=dtype)
            image_latents = self._encode_vae_image(image=condition_image, generator=generator)
            cond_latents = self._pack_latents(
                image_latents, batch_size, num_channels_latents, height_cond * cond_number, width_cond
            )
            # Every spatial condition image reuses the same rescaled id table.
            cond_latent_image_ids = torch.concat([cond_latent_image_ids for _ in range(cond_number)], dim=-2)
            latents_ids_to_concat.append(cond_latent_image_ids)
            latents_to_concat.append(cond_latents)

        # `torch.concat` rejects an empty list, so report "no conditions" as None.
        cond_latents = torch.concat(latents_to_concat, dim=-2) if latents_to_concat else None
        latent_image_ids = torch.concat(latents_ids_to_concat, dim=-2)
        return cond_latents, latent_image_ids, noise_latents

    @property
    def guidance_scale(self):
        """Guidance scale of the current/last call (set in ``__call__``)."""
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        """Attention kwargs of the current/last call (set in ``__call__``)."""
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        """Number of timesteps of the current/last call (set in ``__call__``)."""
        return self._num_timesteps

    @property
    def interrupt(self):
        """Flag read by the denoising loop; while true, remaining steps are skipped."""
        return self._interrupt

    def _preprocess_subject_images(self, subject_images, cond_size):
        """Turn subject images into one tensor stacked along the height axis.

        Each image is resized so that its longer side is at most ``cond_size``
        (aspect ratio kept), then zero-padded to ``cond_size`` in both
        dimensions. Images are expected to expose a PIL-style ``size`` of
        ``(width, height)``.
        """
        processed = []
        for subject in subject_images:
            w, h = subject.size[:2]
            scale = cond_size / max(h, w)
            new_h, new_w = int(h * scale), int(w * scale)
            tensor = self.image_processor.preprocess(subject, height=new_h, width=new_w)
            tensor = tensor.to(dtype=torch.float32)
            pad_h = cond_size - tensor.shape[-2]
            pad_w = cond_size - tensor.shape[-1]
            # Padding order is (left, top, right, bottom). Give the remainder of
            # an odd amount to the right/bottom so the result is exactly
            # cond_size x cond_size.
            left, top = pad_w // 2, pad_h // 2
            tensor = pad(tensor, padding=(left, top, pad_w - left, pad_h - top), fill=0)
            processed.append(tensor)
        return torch.concat(processed, dim=-2)

    def _preprocess_spatial_images(self, spatial_images, cond_size):
        """Resize spatial condition images to ``cond_size`` squares and stack them along the height axis."""
        processed = [
            self.image_processor.preprocess(img, height=cond_size, width=cond_size).to(dtype=torch.float32)
            for img in spatial_images
        ]
        return torch.concat(processed, dim=-2)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        spatial_images=[],
        subject_images=[],
        cond_size=512,
    ):
        """Generate images from a prompt and optional subject / spatial condition images.

        Args:
            prompt: Prompt(s) for the first text encoder.
            prompt_2: Prompt(s) for the second text encoder; defaults to ``prompt``.
            height: Output height in pixels; defaults to
                ``default_sample_size * vae_scale_factor``.
            width: Output width in pixels; same default as ``height``.
            num_inference_steps: Number of denoising steps.
            timesteps: Custom timestep schedule. When given, the default sigma
                schedule is not generated (the scheduler helper rejects both
                together).
            guidance_scale: Value fed to the transformer's guidance input when
                its config enables guidance embeddings.
            num_images_per_prompt: Images generated per prompt.
            generator: Random generator(s).
            latents: Optional packed starting noise.
            prompt_embeds: Pre-computed sequence embeddings (instead of ``prompt``).
            pooled_prompt_embeds: Pre-computed pooled embeddings.
            output_type: ``"latent"`` returns the packed latents; anything else
                is handed to the image processor's ``postprocess``.
            return_dict: Return a ``FluxPipelineOutput`` rather than a tuple.
            joint_attention_kwargs: Forwarded to the transformer; its
                ``"scale"`` entry is used as the LoRA scale.
            callback_on_step_end: Called as ``callback(self, step, timestep, kwargs)``
                after every step; may return replacements for ``latents`` and
                ``prompt_embeds``.
            callback_on_step_end_tensor_inputs: Names of local variables passed
                to the callback.
            max_sequence_length: Token length for the second tokenizer (at most 512).
            spatial_images: Spatial condition images; ``None`` is treated as empty.
            subject_images: Subject condition images; ``None`` is treated as empty.
            cond_size: Side length in pixels that condition images are resized/padded to.

        Returns:
            ``FluxPipelineOutput`` with the images, or ``(images,)`` when
            ``return_dict`` is false.
        """
        # The list defaults in the signature are never mutated; `None` is
        # additionally accepted as "no images".
        if spatial_images is None:
            spatial_images = []
        if subject_images is None:
            subject_images = []

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        self.cond_size = cond_size

        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        cond_number = len(spatial_images)
        sub_number = len(subject_images)

        subject_image = self._preprocess_subject_images(subject_images, cond_size) if sub_number > 0 else None
        condition_image = self._preprocess_spatial_images(spatial_images, cond_size) if cond_number > 0 else None

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # Latent variables: noise tokens, condition tokens and their positional ids.
        num_channels_latents = self.transformer.config.in_channels // 4
        cond_latents, latent_image_ids, noise_latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            subject_image,
            condition_image,
            latents,
            cond_number,
            sub_number,
        )
        latents = noise_latents

        # Timesteps. `retrieve_timesteps` refuses custom timesteps together
        # with sigmas, so only build the default sigma schedule when the caller
        # did not provide timesteps.
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if timesteps is None else None
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # Guidance input, only for transformers configured with guidance embeddings.
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # Best-effort reset of per-processor `bank_kv` state from a previous
        # call; processors without that attribute are skipped.
        for attn_processor in getattr(self.transformer, "attn_processors", {}).values():
            bank_kv = getattr(attn_processor, "bank_kv", None)
            if bank_kv is not None:
                bank_kv.clear()

        # Warm-up pass: run the condition tokens alone through the transformer
        # at timestep 0, with the positional ids that follow the noise ids.
        # NOTE(review): presumably this lets the attention processors fill
        # `bank_kv` for reuse in the denoising loop - confirm against the
        # processor implementation.
        if cond_latents is not None:
            warmup_timestep = torch.zeros(cond_latents.shape[0], device=device, dtype=latents.dtype)
            warmup_image_ids = latent_image_ids[latents.shape[1] :, :]
            self.transformer(
                hidden_states=cond_latents,
                timestep=warmup_timestep,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=warmup_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )

        # The condition tensors are not needed during denoising; drop the
        # references before the loop to lower peak memory.
        del cond_latents, subject_image, condition_image
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Denoising loop.
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Broadcast to the batch dimension.
                timestep = t.expand(latents.shape[0]).to(latents.dtype)
                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # Restore the dtype if the scheduler step changed it (done on MPS only).
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    # Plain loop on purpose: inside a comprehension `locals()`
                    # would not see this method's variables.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models.
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
|
|