| | from typing import List |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from controlnet_aux import OpenposeDetector |
| | from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline, |
| | UniPCMultistepScheduler) |
| | from PIL import Image |
| | from util.cache import clear_cuda_and_gc |
| | from util.commons import disable_safety_checker, download_image |
| |
|
| |
|
class ControlNet:
    """Stable Diffusion + ControlNet wrapper that switches between a Canny-edge
    ControlNet and an OpenPose ControlNet on one shared pipeline.

    Usage: ``load(model_dir)`` builds the pipeline with the Canny ControlNet
    attached; ``load_canny()`` / ``load_pose()`` swap the active ControlNet;
    ``process_canny`` / ``process_pose`` run inference and refuse to run when
    the matching ControlNet is not the active one; ``cleanup()`` drops the
    active ControlNet.
    """

    # Name of the active ControlNet: "canny", "pose", or "" when none is active.
    __current_task_name = ""
    # OpenPose annotator, created on first detect_pose() call and then reused.
    __pose_detector = None

    def load(self, model_dir: str):
        """Build the ControlNet pipeline from ``model_dir`` with the Canny model.

        The pipeline is loaded in float16 and configured with a UniPC
        scheduler, model CPU offload, xformers attention, and the safety
        checker disabled.
        """
        self.load_canny()

        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_dir, controlnet=self.controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        disable_safety_checker(pipe)
        self.pipe = pipe

    def load_canny(self):
        """Make the Canny-edge ControlNet active (no-op if it already is)."""
        self.__load_controlnet("canny", "lllyasviel/control_v11p_sd15_canny")

    def load_pose(self):
        """Make the OpenPose ControlNet active (no-op if it already is)."""
        self.__load_controlnet("pose", "lllyasviel/control_v11p_sd15_openpose")

    def __load_controlnet(self, task_name: str, repo_id: str):
        """Load the ControlNet at ``repo_id`` and activate it as ``task_name``.

        The currently active ControlNet is released *before* the new one is
        loaded, so the two are never resident on the GPU simultaneously.
        The task name is recorded only after the new model is attached; if
        loading raises, no ControlNet is active and process_* will refuse
        to run rather than use the wrong model.
        """
        if self.__current_task_name == task_name:
            return

        self.__release_controlnet()

        controlnet = ControlNetModel.from_pretrained(
            repo_id, torch_dtype=torch.float16
        ).to("cuda")
        self.controlnet = controlnet
        # Before load() has run there is no pipeline yet; load() passes
        # self.controlnet into the pipeline it builds.
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet
        self.__current_task_name = task_name

    def __release_controlnet(self):
        """Detach and drop the active ControlNet, then run clear_cuda_and_gc()."""
        if hasattr(self, "pipe"):
            self.pipe.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""
        clear_cuda_and_gc()

    def cleanup(self):
        """Release the active ControlNet. Safe to call before load()."""
        self.__release_controlnet()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on the Canny edges of the image at ``imageUrl``.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            RuntimeError: if the Canny ControlNet is not the active model.
        """
        if self.__current_task_name != "canny":
            raise RuntimeError("ControlNet is not loaded with canny model")

        # The pipeline is not given an explicit generator, so seeding torch's
        # global RNG is what makes `seed` take effect.
        torch.manual_seed(seed)

        init_image = self.__canny_detect_edge(download_image(imageUrl))

        return self.pipe(
            prompt=prompt,
            image=init_image,
            guidance_scale=9,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        ).images

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on pre-computed pose image(s).

        ``image`` is passed to the pipeline unchanged; presumably it comes
        from detect_pose() — confirm against callers.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            RuntimeError: if the OpenPose ControlNet is not the active model.
        """
        if self.__current_task_name != "pose":
            raise RuntimeError("ControlNet is not loaded with pose model")

        # See process_canny: seeds the global RNG used by the pipeline.
        torch.manual_seed(seed)

        return self.pipe(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
        ).images

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Run the OpenPose annotator on the image at ``imageUrl``.

        The annotator is loaded on first use and cached on the instance,
        instead of being re-created from the hub on every call.
        """
        if self.__pose_detector is None:
            self.__pose_detector = OpenposeDetector.from_pretrained(
                "lllyasviel/ControlNet"
            )
        return self.__pose_detector(download_image(imageUrl))

    def __canny_detect_edge(
        self,
        image: Image.Image,
        low_threshold: int = 100,
        high_threshold: int = 200,
    ) -> Image.Image:
        """Return a 3-channel image of the Canny edges found in ``image``.

        ``low_threshold`` and ``high_threshold`` are the hysteresis thresholds
        passed to ``cv2.Canny``. The defaults match the previous hard-coded
        values.
        """
        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # cv2.Canny returns a single-channel map; copy it into 3 identical
        # channels so the result is an RGB-shaped image.
        edges = np.repeat(edges[:, :, None], 3, axis=2)
        return Image.fromarray(edges)
|