| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import warnings |
| from functools import partial |
| from typing import Dict, List, Optional, Union |
|
|
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from flax.core.frozen_dict import FrozenDict |
| from flax.jax_utils import unreplicate |
| from flax.training.common_utils import shard |
| from packaging import version |
| from PIL import Image |
| from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel |
|
|
| from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel |
| from ...schedulers import ( |
| FlaxDDIMScheduler, |
| FlaxDPMSolverMultistepScheduler, |
| FlaxLMSDiscreteScheduler, |
| FlaxPNDMScheduler, |
| ) |
| from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring |
| from ..pipeline_flax_utils import FlaxDiffusionPipeline |
| from .pipeline_output import FlaxStableDiffusionPipelineOutput |
| from .safety_checker_flax import FlaxStableDiffusionSafetyChecker |
|
|
|
|
logger = logging.get_logger(__name__)

# Set to True to run the denoising loop of `FlaxStableDiffusionInpaintPipeline._generate` as a plain Python
# `for` loop instead of `jax.lax.fori_loop`, which makes step-by-step debugging possible.
DEBUG = False

# Usage example injected into `FlaxStableDiffusionInpaintPipeline.__call__`'s docstring by
# `replace_example_docstring` (it replaces the `Examples:` line).
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import jax
        >>> import numpy as np
        >>> from flax.jax_utils import replicate
        >>> from flax.training.common_utils import shard
        >>> import PIL
        >>> import requests
        >>> from io import BytesIO
        >>> from diffusers import FlaxStableDiffusionInpaintPipeline


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

        >>> init_image = download_image(img_url).resize((512, 512))
        >>> mask_image = download_image(mask_url).resize((512, 512))

        >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(
        ...     "xvjiarui/stable-diffusion-2-inpainting"
        ... )

        >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
        >>> prng_seed = jax.random.PRNGKey(0)
        >>> num_inference_steps = 50

        >>> num_samples = jax.device_count()
        >>> prompt = num_samples * [prompt]
        >>> init_image = num_samples * [init_image]
        >>> mask_image = num_samples * [mask_image]
        >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(
        ...     prompt, init_image, mask_image
        ... )

        >>> # shard inputs and rng
        >>> params = replicate(params)
        >>> prng_seed = jax.random.split(prng_seed, jax.device_count())
        >>> prompt_ids = shard(prompt_ids)
        >>> processed_masked_images = shard(processed_masked_images)
        >>> processed_masks = shard(processed_masks)

        >>> images = pipeline(
        ...     prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
        ... ).images
        >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
        ```
"""
|
|
|
|
class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
    r"""
    Flax-based pipeline for text-guided image inpainting using Stable Diffusion.

    <Tip warning={true}>

    🧪 This is an experimental feature!

    </Tip>

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.FlaxCLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`FlaxUNet2DConditionModel`]):
            A `FlaxUNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms. Passing `None` disables the check (a warning is logged).
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The dtype used when sampling the initial noise latents.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # UNet configs saved by diffusers < 0.9.0 may declare `sample_size < 64`; emit a deprecation notice and
        # patch the in-memory config to 64.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Spatial downsampling factor between pixel space and latent space: a factor of 2 per VAE block
        # after the first.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_inputs(
        self,
        prompt: Union[str, List[str]],
        image: Union[Image.Image, List[Image.Image]],
        mask: Union[Image.Image, List[Image.Image]],
    ):
        """
        Tokenize the prompt(s) and convert the PIL image(s)/mask(s) into arrays accepted by `__call__`.

        Args:
            prompt (`str` or `List[str]`): prompt(s) to tokenize.
            image (`PIL.Image.Image` or `List[PIL.Image.Image]`): image(s) to inpaint.
            mask (`PIL.Image.Image` or `List[PIL.Image.Image]`): mask(s); pixels >= 0.5 (after scaling to
                `[0, 1]`) mark the region to repaint.

        Returns:
            `tuple`: `(prompt_ids, masked_images, masks)`.
            - `prompt_ids` is a NumPy array of token ids padded/truncated to `tokenizer.model_max_length`.
            - `masked_images` is the preprocessed image batch in `[-1, 1]` with the repaint region set to 0.
            - `masks` is the binarized (`0.0` / `1.0`) single-channel mask batch.

        Raises:
            ValueError: if `prompt`, `image` or `mask` has an unsupported type.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if not isinstance(image, (Image.Image, list)):
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
        if isinstance(image, Image.Image):
            image = [image]

        if not isinstance(mask, (Image.Image, list)):
            raise ValueError(f"mask has to be of type `PIL.Image.Image` or list but is {type(mask)}")
        if isinstance(mask, Image.Image):
            mask = [mask]

        processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image])
        processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask])
        # Binarize the mask at 0.5 in a single pass.
        processed_masks = jnp.where(processed_masks < 0.5, 0.0, 1.0).astype(processed_masks.dtype)
        # Zero out the pixels to be repainted so only the kept region reaches the VAE.
        processed_masked_images = processed_images * (processed_masks < 0.5)

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids, processed_masked_images, processed_masks

    def _get_has_nsfw_concepts(self, features, params):
        """Run the safety checker on pre-extracted CLIP `features` with the given `params`."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """
        Flag NSFW images and replace them with black images.

        Args:
            images (`np.ndarray`): uint8 images with a leading image axis (presumably `(N, H, W, 3)` -- confirm).
            safety_model_params: safety-checker parameters; replicated across devices when `jit=True`.
            jit (`bool`): shard the features and use the pmapped checker.

        Returns:
            `tuple`: `(images, has_nsfw_concepts)`. Flagged entries are zeroed in a copy of `images`; the
            caller's array is never modified in place.
        """
        # The CLIP feature extractor consumes PIL images.
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if not images_was_copied:
                    # Copy lazily so nothing is copied when no image is flagged.
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)

        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.ndarray,
        mask: jnp.ndarray,
        masked_image: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: float,
        latents: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
    ):
        """
        Encode the prompt, run the classifier-free-guided denoising loop and decode the result.

        Called directly when `jit=False`, and per device through `_p_generate` when `jit=True`.

        Returns:
            `jnp.ndarray` of shape `(batch, H, W, C)` with values clipped to `[0, 1]`.

        Raises:
            ValueError: if `height`/`width` are not divisible by 8, if `latents` has an unexpected shape, or if
                the latent + mask + masked-image channel count does not match `unet.config.in_channels`.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Conditional text embeddings.
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        batch_size = prompt_ids.shape[0]
        max_length = prompt_ids.shape[-1]

        # Unconditional embeddings: empty prompts unless negative prompt ids are supplied.
        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        # Unconditional first, conditional second -- matches the `jnp.split` order in `loop_body`.
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        latents_shape = (
            batch_size,
            self.vae.config.latent_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        # NOTE(review): `prng_seed` was already consumed by `jax.random.normal` above and is split here as well,
        # i.e. the same key is used twice. Left as-is so outputs stay reproducible for existing seeds.
        prng_seed, mask_prng_seed = jax.random.split(prng_seed)

        # Encode the masked image; the VAE sample comes back channels-last and is transposed to NCHW.
        masked_image_latent_dist = self.vae.apply(
            {"params": params["vae"]}, masked_image, method=self.vae.encode
        ).latent_dist
        masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2))
        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
        del mask_prng_seed

        # Bring the mask to latent resolution.
        mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest")

        # The inpainting UNet takes latents, mask and masked-image latents concatenated along channels.
        num_channels_latents = self.vae.config.latent_channels
        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]
        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                " `pipeline.unet` or your `mask_image` or `image` input."
            )

        def loop_body(step, args):
            """One denoising step; the carry is `(latents, mask, masked_image_latents, scheduler_state)`."""
            latents, mask, masked_image_latents, scheduler_state = args
            # Duplicate everything so the unconditional and conditional branches run in one UNet call.
            latents_input = jnp.concatenate([latents] * 2)
            mask_input = jnp.concatenate([mask] * 2)
            masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
            # Mask and masked-image latents are appended after scaling, so only the noisy latents are scaled.
            latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1)

            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # Classifier-free guidance.
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, mask, masked_image_latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
        )

        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # Python loop: slower, but each step can be inspected.
            for i in range(num_inference_steps):
                latents, mask, masked_image_latents, scheduler_state = loop_body(
                    i, (latents, mask, masked_image_latents, scheduler_state)
                )
        else:
            latents, _, _, _ = jax.lax.fori_loop(
                0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state)
            )

        # Undo the VAE latent scaling and decode.
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        # Map from [-1, 1] to [0, 1] and move channels last.
        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.ndarray,
        mask: jnp.ndarray,
        masked_image: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int = 50,
        height: Optional[int] = None,
        width: Optional[int] = None,
        guidance_scale: Union[float, jnp.ndarray] = 7.5,
        latents: jnp.ndarray = None,
        neg_prompt_ids: jnp.ndarray = None,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.ndarray`):
                Token ids of the prompt(s), e.g. the first output of `prepare_inputs`. Must carry a leading device
                axis (sharded) when `jit=True`.
            mask (`jnp.ndarray`):
                Binary inpainting mask, e.g. the third output of `prepare_inputs`. Resized to `(height, width)` with
                nearest-neighbour interpolation.
            masked_image (`jnp.ndarray`):
                Image with the region to repaint zeroed out, e.g. the second output of `prepare_inputs`. Resized to
                `(height, width)` with bicubic interpolation.
            params (`Dict` or `FrozenDict`):
                Model parameters keyed by `"text_encoder"`, `"vae"`, `"unet"`, `"scheduler"` and, when a safety
                checker is configured, `"safety_checker"`. Must be replicated across devices when `jit=True`.
            prng_seed (`jax.Array`):
                PRNG key used to sample the initial latents and the masked-image latents (one key per device when
                `jit=True`).
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            guidance_scale (`float` or `jnp.ndarray`, *optional*, defaults to 7.5):
                Classifier-free guidance scale. Higher values tie the result more closely to the prompt, usually at
                the expense of image quality. Guidance is always applied; `1.0` yields the purely text-conditioned
                prediction. Python `int`/`float` values are converted to an array matching `prompt_ids`' layout.
            latents (`jnp.ndarray`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. If not provided, latents are sampled from `prng_seed`.
            neg_prompt_ids (`jnp.ndarray`, *optional*):
                Token ids of negative prompt(s) used for the unconditional branch; empty prompts when omitted.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions.

                <Tip warning={true}>

                This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
                future release.

                </Tip>

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is
                returned, otherwise a `tuple` whose first element is a NumPy array with the generated images and whose
                second element holds, per image, whether it was flagged as "not-safe-for-work" (nsfw) content
                (`False` when no safety checker is configured).
        """
        # Default to the UNet's native resolution.
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic")
        mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest")

        if isinstance(guidance_scale, (int, float)):
            # Follow `prompt_ids`' layout: sharded (3-D) when `jit=True`, plain (2-D) otherwise.
            guidance_scale = jnp.array([float(guidance_scale)] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Sharded: one value per device; pmap strips the device axis, leaving a broadcastable (1,).
                guidance_scale = guidance_scale[:, None]
            else:
                # Unsharded: one value per sample, broadcastable against the (batch, C, h, w) noise prediction.
                guidance_scale = guidance_scale[:, None, None, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                mask,
                masked_image,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )
        else:
            images = self._generate(
                prompt_ids,
                mask,
                masked_image,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            # `np.array` (not `np.asarray`): JAX arrays are immutable, so we need a writable host copy.
            images = np.array(images)
            output_shape = images.shape
            # Flatten any leading (device, batch) axes so both sharded and unsharded outputs are handled.
            num_images = int(np.prod(output_shape[:-3]))
            flat_images = images.reshape(num_images, *output_shape[-3:])

            images_uint8_casted = (flat_images * 255).round().astype("uint8")
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)

            # Replace flagged images with the blacked-out version returned by the safety checker.
            if any(has_nsfw_concept):
                for i, is_nsfw in enumerate(has_nsfw_concept):
                    if is_nsfw:
                        flat_images[i] = np.asarray(images_uint8_casted[i])

            images = flat_images.reshape(output_shape)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
|
|
|
| |
| |
@partial(
    jax.pmap,
    # One entry per positional argument; `None` = broadcast to every device, `0` = split along the leading axis.
    in_axes=(
        None,  # pipe
        0,  # prompt_ids
        0,  # mask
        0,  # masked_image
        0,  # params
        0,  # prng_seed
        None,  # num_inference_steps
        None,  # height
        None,  # width
        0,  # guidance_scale
        0,  # latents
        0,  # neg_prompt_ids
    ),
    # The pipeline object, step count and output size are hashed as compile-time constants.
    static_broadcasted_argnums=(0, 6, 7, 8),
)
def _p_generate(
    pipe,
    prompt_ids,
    mask,
    masked_image,
    params,
    prng_seed,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    latents,
    neg_prompt_ids,
):
    """pmapped entry point: run `pipe._generate` on each device's shard of the inputs."""
    device_inputs = (prompt_ids, mask, masked_image, params, prng_seed)
    static_settings = (num_inference_steps, height, width)
    sampling_inputs = (guidance_scale, latents, neg_prompt_ids)
    return pipe._generate(*device_inputs, *static_settings, *sampling_inputs)
|
|
|
|
@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    """pmapped wrapper: score each device's shard of `features` with `pipe`'s safety checker."""
    score_shard = pipe._get_has_nsfw_concepts
    return score_shard(features, params)
|
|
|
|
def unshard(x: jnp.ndarray):
    """Merge the leading (devices, per-device batch) axes of a sharded array into one batch axis."""
    devices, per_device, *tail = x.shape
    return x.reshape(devices * per_device, *tail)
|
|
|
|
def preprocess_image(image, dtype):
    """
    Convert a PIL image into a normalized NCHW array for the VAE.

    The image is converted to RGB when needed, its width and height are rounded down to multiples of 32, and it is
    resized with Lanczos resampling.

    Args:
        image (`PIL.Image.Image`): the input image (any mode).
        dtype: dtype of the returned array.

    Returns:
        `jnp.ndarray` of shape `(1, 3, h, w)` with values in `[-1, 1]`.
    """
    if image.mode != "RGB":
        # Without this, RGBA inputs yield 4 channels and "L" inputs a 2-D array that breaks the transpose below.
        image = image.convert("RGB")
    w, h = image.size
    w, h = (x - x % 32 for x in (w, h))
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = jnp.array(image).astype(dtype) / 255.0
    # HWC -> NCHW with a singleton batch axis.
    image = image[None].transpose(0, 3, 1, 2)
    return 2.0 * image - 1.0
|
|
|
|
def preprocess_mask(mask, dtype):
    """
    Convert a PIL mask into a single-channel array scaled to `[0, 1]`.

    Width and height are rounded down to multiples of 32 and the mask is converted to grayscale ("L").

    Returns:
        `jnp.ndarray` of shape `(1, 1, h, w)`.
    """
    width, height = mask.size
    target_size = (width - width % 32, height - height % 32)
    grayscale = mask.resize(target_size).convert("L")
    scaled = jnp.array(grayscale).astype(dtype) / 255.0
    # Prepend batch and channel axes.
    return scaled[None, None]
|
|