from dataclasses import dataclass
from typing import List, Optional

from PIL import Image
import numpy as np
import torch
from torchvision import transforms as TF
from tqdm import tqdm

from diffusers import DiffusionPipeline, UNet2DConditionModel, EulerDiscreteScheduler, AutoencoderKL
from diffusers.utils import BaseOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer

# Preprocessing constants for the EVA visual encoder (CLIP-style normalization).
EVA_IMAGE_SIZE = 448
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
DEFAULT_IMG_PLACEHOLDER = "<image>"

@dataclass
class EmuVisualGenerationPipelineOutput(BaseOutput):
    image: Image.Image
    nsfw_content_detected: Optional[bool]

class EmuVisualGenerationPipeline(DiffusionPipeline):
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        multimodal_encoder: AutoModelForCausalLM,
        scheduler: EulerDiscreteScheduler,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        feature_extractor: CLIPImageProcessor,
        safety_checker: StableDiffusionSafetyChecker,
        eva_size=EVA_IMAGE_SIZE,
        eva_mean=OPENAI_DATASET_MEAN,
        eva_std=OPENAI_DATASET_STD,
    ):
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            multimodal_encoder=multimodal_encoder,
            scheduler=scheduler,
            unet=unet,
            vae=vae,
            feature_extractor=feature_extractor,
            safety_checker=safety_checker,
        )
        # Each VAE down block halves the spatial resolution.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        # Image preprocessing for the visual encoder.
        self.transform = TF.Compose([
            TF.Resize((eva_size, eva_size), interpolation=TF.InterpolationMode.BICUBIC),
            TF.ToTensor(),
            TF.Normalize(mean=eva_mean, std=eva_std),
        ])

        # Cache of unconditional embeddings used for classifier-free guidance.
        self.negative_prompt = {}
    def device(self, module):
        return next(module.parameters()).device

    def dtype(self, module):
        return next(module.parameters()).dtype
    @torch.no_grad()
    def __call__(
        self,
        inputs: List[Image.Image | str] | str | Image.Image,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.0,
        crop_info: List[int] = [0, 0],
        original_size: List[int] = [1024, 1024],
    ):
        if not isinstance(inputs, list):
            inputs = [inputs]
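
        # Fall back to the UNet's native output resolution if height/width are falsy.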
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        device = self.device(self.unet)
        dtype = self.dtype(self.unet)

        do_classifier_free_guidance = guidance_scale > 1.0
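
        # Encode the interleaved text/image prompt; under classifier-free guidance
        # the unconditional embedding is concatenated along the batch dimension.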
        prompt_embeds = self._prepare_and_encode_inputs(
            inputs,
            do_classifier_free_guidance,
        ).to(dtype).to(device)
        batch_size = prompt_embeds.shape[0] // 2 if do_classifier_free_guidance else prompt_embeds.shape[0]

        # SDXL-style micro-conditioning: original size, crop offsets, target size.
        unet_added_conditions = {}
        time_ids = torch.LongTensor(original_size + crop_info + [height, width]).to(device)
        if do_classifier_free_guidance:
            unet_added_conditions["time_ids"] = torch.cat([time_ids, time_ids], dim=0)
        else:
            unet_added_conditions["time_ids"] = time_ids
        unet_added_conditions["text_embeds"] = torch.mean(prompt_embeds, dim=1)
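
        # Prepare the scheduler's timestep schedule.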
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
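
        # Sample initial noise latents at 1/vae_scale_factor of the target resolution.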
        shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        latents = torch.randn(shape, device=device, dtype=dtype)
        latents = latents * self.scheduler.init_noise_sigma
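
        # Denoising loop.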
        for t in tqdm(timesteps):
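            # Duplicate the latents so the conditional and unconditional passes
            # run in a single batch under classifier-free guidance.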
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=unet_added_conditions,
            ).sample
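
            # Apply classifier-free guidance: push the conditional prediction
            # away from the unconditional one by guidance_scale.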
            if do_classifier_free_guidance:
                noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
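
            # Step the scheduler to the previous (less noisy) latent sample.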
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
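
        # Decode the latents to images in [0, 1].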
        images = self.decode_latents(latents)
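
        # Screen the decoded images for NSFW content (if a checker is configured).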
        images, has_nsfw_concept = self.run_safety_checker(images)
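
        # Convert to PIL and wrap in the pipeline's output dataclass.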
        images = self.numpy_to_pil(images)
        return EmuVisualGenerationPipelineOutput(
            image=images[0],
            nsfw_content_detected=None if has_nsfw_concept is None else has_nsfw_concept[0],
        )

    def _prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        do_classifier_free_guidance: bool = False,
        placeholder: str = DEFAULT_IMG_PLACEHOLDER,
    ):
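        # The conditioning comes from the multimodal encoder, so follow its device/dtype.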
        device = self.device(self.multimodal_encoder.model)
        dtype = self.dtype(self.multimodal_encoder.model)
        # Flatten the interleaved inputs into one text string, replacing each
        # image with a placeholder token and collecting the image tensors.
        has_image, has_text = False, False
        text_prompt, image_prompt = "", []
        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                text_prompt += placeholder
                image_prompt.append(self.transform(x))

        if len(image_prompt) == 0:
            image_prompt = None
        else:
            image_prompt = torch.stack(image_prompt)
            image_prompt = image_prompt.type(dtype).to(device)
        if has_image and not has_text:
            # Image-only prompt: encode directly with the visual encoder.
            prompt = self.multimodal_encoder.model.encode_image(image=image_prompt)
            if do_classifier_free_guidance:
                key = "[NULL_IMAGE]"
                if key not in self.negative_prompt:
                    negative_image = torch.zeros_like(image_prompt)
                    self.negative_prompt[key] = self.multimodal_encoder.model.encode_image(image=negative_image)
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
        else:
            # Text or mixed prompt: the multimodal LLM produces the conditioning.
            prompt = self.multimodal_encoder.generate_image(text=[text_prompt], image=image_prompt, tokenizer=self.tokenizer)
            if do_classifier_free_guidance:
                key = ""
                if key not in self.negative_prompt:
                    self.negative_prompt[key] = self.multimodal_encoder.generate_image(text=[""], tokenizer=self.tokenizer)
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)

        return prompt

    def decode_latents(self, latents: torch.Tensor) -> np.ndarray:
        # Undo the VAE scaling, decode, and map from [-1, 1] to [0, 1].
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def numpy_to_pil(self, images: np.ndarray) -> List[Image.Image]:
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # Special case for grayscale (single-channel) images.
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    def run_safety_checker(self, images: np.ndarray):
        if self.safety_checker is not None:
            device = self.device(self.safety_checker)
            dtype = self.dtype(self.safety_checker)
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(images), return_tensors="pt").to(device)
            images, has_nsfw_concept = self.safety_checker(
                images=images, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return images, has_nsfw_concept
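
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipeline). Assumes the
# component models -- tokenizer, multimodal_encoder, scheduler, unet, vae,
# feature_extractor, safety_checker -- have been loaded elsewhere; "dog.png"
# and the prompt text below are placeholder inputs.
#
#   pipe = EmuVisualGenerationPipeline(
#       tokenizer=tokenizer,
#       multimodal_encoder=multimodal_encoder,
#       scheduler=scheduler,
#       unet=unet,
#       vae=vae,
#       feature_extractor=feature_extractor,
#       safety_checker=safety_checker,
#   )
#   out = pipe(["a photo of", Image.open("dog.png"), "wearing a hat"])
#   out.image.save("result.png")
# ---------------------------------------------------------------------------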