| from typing import Dict, List, Any |
| import torch |
| import base64 |
| from PIL import Image |
| from io import BytesIO |
| from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, StableDiffusionXLImg2ImgPipeline, AutoencoderKL, DPMSolverMultistepScheduler |
| from controlnet_aux.pidi import PidiNetDetector |
|
|
| |
# This handler requires CUDA: the fp16 pipelines below cannot run on CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')

if not _cuda_ok:
    raise ValueError("need to run on GPU")
|
|
class EndpointHandler:
    """Inference Endpoints handler: sketch-conditioned SDXL generation.

    A base64 input image is converted to a sketch with PidiNet, used to
    condition an SDXL base pass via a T2I-Adapter, and the resulting latents
    are finished by the SDXL refiner.
    """

    def __init__(self, path=""):
        """Load adapter, VAE, scheduler and both SDXL pipelines.

        Args:
            path: model directory supplied by the serving toolkit (unused;
                all weights are fetched from the Hub by repo id).
        """
        # Sketch-conditioning T2I-Adapter for SDXL.
        adapter = T2IAdapter.from_pretrained(
            "Adapter/t2iadapter",
            subfolder="sketch_sdxl_1.0",
            torch_dtype=torch.float16,
            adapter_type="full_adapter_xl",
            use_safetensors=True,
        )

        # fp16-safe VAE replacement (the stock SDXL VAE is unstable in fp16).
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

        scheduler = DPMSolverMultistepScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="scheduler",
            use_lu_lambdas=True,
            euler_at_final=True,
        )

        self.pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            adapter=adapter,
            vae=vae,
            scheduler=scheduler,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )

        # The refiner shares the base pipeline's second text encoder and VAE
        # to avoid loading duplicate weights.
        self.refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipeline.text_encoder_2,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )

        # FIX: do NOT call .to("cuda") on the pipelines before this point —
        # enable_model_cpu_offload() manages device placement itself and the
        # diffusers docs require the pipeline to still be on CPU when it is
        # enabled; moving it first wastes memory and conflicts with offload.
        self.pipeline.enable_model_cpu_offload()
        self.refiner.enable_model_cpu_offload()

        # Sketch (soft-edge) detector used to derive the conditioning image.
        self.pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """Run sketch-conditioned generation for one request.

        data args:
            inputs (str): text prompt.
            image (str): base64-encoded input image (required).
            adapter_conditioning_scale (float, optional): adapter strength,
                default 1.0.
            adapter_conditioning_factor (float, optional): fraction of steps
                the adapter is applied for, default 1.0.
        Return:
            The refined PIL image (serialized by the serving toolkit).
        Raises:
            ValueError: if no "image" field is present in the request.
        """
        inputs = data.pop("inputs", "")
        encoded_image = data.pop("image", None)
        adapter_conditioning_scale = data.pop("adapter_conditioning_scale", 1.0)
        adapter_conditioning_factor = data.pop("adapter_conditioning_factor", 1.0)

        # Fail fast with a clear message instead of a TypeError deep inside
        # base64.b64decode(None).
        if encoded_image is None:
            raise ValueError("Request must include a base64-encoded 'image' field")

        decoded_image = self.decode_base64_image(encoded_image).convert('RGB')
        sketch_image = self.pidinet(
            decoded_image,
            detect_resolution=1024,
            image_resolution=1024,
            apply_filter=True,
        ).convert('L')

        negative_prompt = "extra digit, fewer digits, cropped, worst quality, low quality"
        num_inference_steps = 25
        # Fraction of the denoising schedule handled by the base model; the
        # refiner picks up from the same point (ensemble-of-experts split).
        high_noise_frac = 0.7

        base_image = self.pipeline(
            prompt=inputs,
            negative_prompt=negative_prompt,
            image=sketch_image,
            num_inference_steps=num_inference_steps,
            denoising_end=high_noise_frac,
            guidance_scale=7.5,
            adapter_conditioning_scale=adapter_conditioning_scale,
            adapter_conditioning_factor=adapter_conditioning_factor,
            output_type="latent",  # hand latents straight to the refiner
        ).images

        # BUG FIX: StableDiffusionXLImg2ImgPipeline does not accept the
        # adapter_conditioning_* kwargs — passing them raised TypeError on
        # every request. The adapter only conditions the base pass above.
        output_image = self.refiner(
            prompt=inputs,
            negative_prompt=negative_prompt,
            image=base_image,
            num_inference_steps=num_inference_steps,
            denoising_start=high_noise_frac,
            guidance_scale=7.5,
        ).images[0]

        return output_image

    def decode_base64_image(self, image_string):
        """Decode a base64 string into a PIL image (format auto-detected)."""
        raw_bytes = base64.b64decode(image_string)
        buffer = BytesIO(raw_bytes)
        return Image.open(buffer)