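"""Emu visual generation pipeline.

A diffusers DiffusionPipeline that conditions an SDXL-style UNet2DConditionModel
(via `encoder_hidden_states` plus the `time_ids`/`text_embeds` added conditions)
on prompt embeddings produced by a multimodal encoder, so generation can be
driven by free-form interleaved text and images.
"""
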
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import torch
from PIL import Image
from torchvision import transforms as TF
from tqdm import tqdm

from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
    EulerDiscreteScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import BaseOutput
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor

EVA_IMAGE_SIZE = 448
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
DEFAULT_IMG_PLACEHOLDER = "<image>"


@dataclass
class EmuVisualGenerationPipelineOutput(BaseOutput):
    image: Image.Image
    nsfw_content_detected: Optional[bool]


class EmuVisualGenerationPipeline(DiffusionPipeline):
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        multimodal_encoder: AutoModelForCausalLM,
        scheduler: EulerDiscreteScheduler,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        feature_extractor: CLIPImageProcessor,
        safety_checker: StableDiffusionSafetyChecker,
        eva_size=EVA_IMAGE_SIZE,
        eva_mean=OPENAI_DATASET_MEAN,
        eva_std=OPENAI_DATASET_STD,
    ):
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            multimodal_encoder=multimodal_encoder,
            scheduler=scheduler,
            unet=unet,
            vae=vae,
            feature_extractor=feature_extractor,
            safety_checker=safety_checker,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        # Preprocessing for images fed to the multimodal encoder's vision tower.
        self.transform = TF.Compose([
            TF.Resize((eva_size, eva_size), interpolation=TF.InterpolationMode.BICUBIC),
            TF.ToTensor(),
            TF.Normalize(mean=eva_mean, std=eva_std),
        ])
|
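        # Cache of unconditional embeddings for classifier-free guidance,
        # keyed by prompt type (see _prepare_and_encode_inputs).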
        self.negative_prompt = {}

    def device(self, module):
        return next(module.parameters()).device

    def dtype(self, module):
        return next(module.parameters()).dtype

    @torch.no_grad()
    def __call__(
        self,
        inputs: List[Image.Image | str] | str | Image.Image,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.0,
        crop_info: List[int] = [0, 0],
        original_size: List[int] = [1024, 1024],
    ):
        """Generate one image from a prompt of interleaved text and PIL images."""
        if not isinstance(inputs, list):
            inputs = [inputs]

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        device = self.device(self.unet)
        dtype = self.dtype(self.unet)

        do_classifier_free_guidance = guidance_scale > 1.0

        # 1. Encode input prompt
        prompt_embeds = self._prepare_and_encode_inputs(
            inputs,
            do_classifier_free_guidance,
        ).to(dtype).to(device)
        batch_size = prompt_embeds.shape[0] // 2 if do_classifier_free_guidance else prompt_embeds.shape[0]
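
        # SDXL-style micro-conditioning: the UNet additionally receives the
        # original size, crop offsets, and target size as "time ids", plus a
        # pooled prompt embedding (here, the mean over the prompt tokens).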
        unet_added_conditions = {}
        time_ids = torch.LongTensor(original_size + crop_info + [height, width]).to(device)
        if do_classifier_free_guidance:
            unet_added_conditions["time_ids"] = torch.cat([time_ids, time_ids], dim=0)
        else:
            unet_added_conditions["time_ids"] = time_ids
        unet_added_conditions["text_embeds"] = torch.mean(prompt_embeds, dim=1)

        # 2. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 3. Prepare latent variables
        shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        latents = torch.randn(shape, device=device, dtype=dtype)
        latents = latents * self.scheduler.init_noise_sigma

        # 4. Denoising loop
        for t in tqdm(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=unet_added_conditions,
            ).sample

            # perform guidance; the conditional half of the batch comes first
            # (see _prepare_and_encode_inputs), hence this chunk order
            if do_classifier_free_guidance:
                noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # 5. Post-processing: decode latents to images
        images = self.decode_latents(latents)

        # 6. Run safety checker
        images, has_nsfw_concept = self.run_safety_checker(images)

        # 7. Convert to PIL and wrap in the output dataclass
        images = self.numpy_to_pil(images)
        return EmuVisualGenerationPipelineOutput(
            image=images[0],
            nsfw_content_detected=None if has_nsfw_concept is None else has_nsfw_concept[0],
        )

    def _prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        do_classifier_free_guidance: bool = False,
        placeholder: str = DEFAULT_IMG_PLACEHOLDER,
    ):
        device = self.device(self.multimodal_encoder.model)
        dtype = self.dtype(self.multimodal_encoder.model)

        # Split the interleaved inputs into a single text prompt (with one
        # placeholder per image) and a list of preprocessed image tensors.
        has_image, has_text = False, False
        text_prompt, image_prompt = "", []
        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                text_prompt += placeholder
                image_prompt.append(self.transform(x))

        if len(image_prompt) == 0:
            image_prompt = None
        else:
            image_prompt = torch.stack(image_prompt)
            image_prompt = image_prompt.type(dtype).to(device)

        if has_image and not has_text:
            # Image-only prompt: encode with the vision encoder; the cached
            # unconditional embedding comes from an all-zeros image.
            prompt = self.multimodal_encoder.model.encode_image(image=image_prompt)
            if do_classifier_free_guidance:
                key = "[NULL_IMAGE]"
                if key not in self.negative_prompt:
                    negative_image = torch.zeros_like(image_prompt)
                    self.negative_prompt[key] = self.multimodal_encoder.model.encode_image(image=negative_image)
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
        else:
            # Text or mixed prompt: the multimodal encoder produces the prompt
            # embedding; the cached unconditional embedding uses an empty string.
            prompt = self.multimodal_encoder.generate_image(text=[text_prompt], image=image_prompt, tokenizer=self.tokenizer)
            if do_classifier_free_guidance:
                key = ""
                if key not in self.negative_prompt:
                    self.negative_prompt[key] = self.multimodal_encoder.generate_image(text=[""], tokenizer=self.tokenizer)
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)

        return prompt

    def decode_latents(self, latents: torch.Tensor) -> np.ndarray:
        # Undo the VAE scaling factor, decode, and map from [-1, 1] to [0, 1].
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def numpy_to_pil(self, images: np.ndarray) -> List[Image.Image]:
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    def run_safety_checker(self, images: np.ndarray):
        if self.safety_checker is not None:
            device = self.device(self.safety_checker)
            dtype = self.dtype(self.safety_checker)
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(images), return_tensors="pt").to(device)
            images, has_nsfw_concept = self.safety_checker(
                images=images, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return images, has_nsfw_concept
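

if __name__ == "__main__":
    # Minimal usage sketch, not a definitive entry point. Loading the
    # individual components (tokenizer, multimodal encoder, UNet, VAE,
    # scheduler, feature extractor, safety checker) is checkpoint-specific,
    # so load_emu_components() is a hypothetical helper standing in for that
    # code; everything below it uses only the API defined above.
    pipe = EmuVisualGenerationPipeline(**load_emu_components()).to("cuda")

    # Text-only prompt.
    out = pipe("a watercolor painting of a lighthouse at dusk")
    out.image.save("lighthouse.png")
    print("nsfw_content_detected:", out.nsfw_content_detected)

    # Interleaved image + text prompt: the PIL image is spliced into the text
    # at a "<image>" placeholder by _prepare_and_encode_inputs.
    reference = Image.open("reference.jpg").convert("RGB")
    out = pipe([reference, "re-imagined as a pencil sketch"], guidance_scale=3.0)
    out.image.save("sketch.png")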