from copy import deepcopy
from functools import wraps

import torch
from diffusers import (
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
)
from transformers import CLIPVisionModelWithProjection
| |
|
# Switch read by the caching decorators in this module: while False, decorated
# loaders call straight through and nothing is stored in ``cached_models``.
ENABLE_CPU_CACHE = False
# Default value of the ``base_model`` argument of the pipeline loaders below.
DEFAULT_BASE_MODEL = "benjamin-paine/stable-diffusion-v1-5"
| |
|
# Process-wide store of loaded models, keyed by loader name plus its arguments.
cached_models = {}


def cache_model(func):
    """Memoize ``func`` in ``cached_models`` while ``ENABLE_CPU_CACHE`` is true.

    The cache key is the function name followed by the string form of the
    positional arguments and of the keyword arguments sorted by name, so the
    same call spelled with a different keyword order hits the same entry.
    Every cache hit returns the *same* stored object (no copy is made).

    When ``ENABLE_CPU_CACHE`` is false the call is forwarded unchanged and
    nothing is stored.

    Args:
        func: The loader to wrap.

    Returns:
        A wrapper with the same call signature and metadata as ``func``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The flag is read on every call so it can be toggled at runtime.
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        # Sorting makes the key independent of call-site keyword order.
        model_name = func.__name__ + str(args) + str(dict(sorted(kwargs.items())))
        if model_name not in cached_models:
            cached_models[model_name] = func(*args, **kwargs)
        return cached_models[model_name]

    return wrapper
| |
|
def copied_cache_model(func):
    """Memoize ``func`` in ``cached_models`` and hand out deep copies.

    Behaves like a plain memoizer while ``ENABLE_CPU_CACHE`` is true, except
    that every call returns ``deepcopy`` of the stored value, so callers can
    mutate what they receive without affecting the cached original. The cache
    key is the function name followed by the string form of the positional
    arguments and of the keyword arguments sorted by name.

    When ``ENABLE_CPU_CACHE`` is false the call is forwarded unchanged and
    nothing is stored or copied.

    Args:
        func: The loader to wrap.

    Returns:
        A wrapper with the same call signature and metadata as ``func``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The flag is read on every call so it can be toggled at runtime.
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        # Sorting makes the key independent of call-site keyword order.
        model_name = func.__name__ + str(args) + str(dict(sorted(kwargs.items())))
        if model_name not in cached_models:
            cached_models[model_name] = func(*args, **kwargs)
        return deepcopy(cached_models[model_name])

    return wrapper
| |
|
def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
    """Instantiate ``model_cls`` from a single-file checkpoint or a pretrained id.

    Args:
        ckpt_or_pretrained: String identifying the weights. A value ending in
            ``".safetensors"`` is loaded with ``model_cls.from_single_file``;
            anything else is handed to ``model_cls.from_pretrained``.
        model_cls: Class exposing ``from_single_file`` and ``from_pretrained``.
        original_config_file: Forwarded only on the single-file path.
        torch_dtype: Forwarded to whichever loader is used.
        **kwargs: Extra keyword arguments forwarded to the loader.

    Returns:
        Whatever the selected loader returns.
    """
    # Guard clause: everything that is not a .safetensors file goes to from_pretrained.
    if not ckpt_or_pretrained.endswith(".safetensors"):
        return model_cls.from_pretrained(
            ckpt_or_pretrained,
            torch_dtype=torch_dtype,
            **kwargs,
        )
    return model_cls.from_single_file(
        ckpt_or_pretrained,
        original_config_file=original_config_file,
        torch_dtype=torch_dtype,
        **kwargs,
    )
| |
|
@copied_cache_model
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
    """Load a ``StableDiffusionPipeline`` and return its ``components`` mapping.

    The pipeline is built through ``model_from_ckpt_or_pretrained`` with the
    safety checker disabled (``safety_checker=None``,
    ``requires_safety_checker=False``) and is moved to ``"cpu"`` before its
    components are returned.

    Args:
        base_model: Checkpoint path or pretrained identifier to load.
        torch_dtype: Dtype forwarded to the loader.

    Returns:
        The ``components`` attribute of the loaded pipeline.
    """
    base_pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        StableDiffusionPipeline,
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    base_pipe.to("cpu")
    return base_pipe.components
| |
|
@cache_model
def load_controlnet(controlnet_path, torch_dtype=torch.float16):
    """Load a ``ControlNetModel`` with ``from_pretrained``.

    Args:
        controlnet_path: Location passed to ``ControlNetModel.from_pretrained``.
        torch_dtype: Dtype forwarded to ``from_pretrained``.

    Returns:
        The loaded ``ControlNetModel``.
    """
    return ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
| |
|
@cache_model
def load_image_encoder(torch_dtype=torch.float16):
    """Load the CLIP vision encoder stored in the ``h94/IP-Adapter`` repository.

    Args:
        torch_dtype: Dtype forwarded to ``from_pretrained``. Defaults to
            ``torch.float16``, so calling with no arguments behaves as before.

    Returns:
        The ``CLIPVisionModelWithProjection`` loaded from the
        ``models/image_encoder`` subfolder.
    """
    return CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch_dtype,
    )
| |
|
def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="balanced", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
    """Build a Stable Diffusion 1.5 pipeline with optional ControlNet and IP-Adapter.

    Args:
        base_model: Checkpoint path or pretrained identifier of the base model.
        device: Not read anywhere in this function; kept so existing callers
            that pass it keep working.
        controlnet: ``None``, a single ControlNet location, or a list/tuple of
            locations. Each one is loaded through ``load_controlnet``.
        ip_adapter: When true, attach the image encoder and load IP-Adapter
            weights from ``h94/IP-Adapter``; otherwise ``unload_ip_adapter()``
            is called on the pipeline.
        plus_model: Selects ``ip-adapter-plus_sd15.safetensors`` (true) or
            ``ip-adapter_sd15.safetensors`` (false).
        torch_dtype: Dtype forwarded to the component, ControlNet and
            pipeline loaders.
        model_cpu_offload_seq: Offload sequence assigned to the pipeline. When
            ``None``, a built-in sequence is assigned for the two ControlNet
            pipeline classes and other pipelines keep their own value.
        enable_sequential_cpu_offload: Use sequential CPU offload instead of
            model CPU offload.
        vae_slicing: Call ``enable_vae_slicing()`` on the pipeline.
        pipeline_class: Pipeline class to instantiate. When ``None``,
            ``StableDiffusionControlNetPipeline`` is used if ``controlnet`` was
            given and ``StableDiffusionPipeline`` otherwise.
        **kwargs: Extra loader keyword arguments; they override the base
            model components of the same name.

    Returns:
        The configured pipeline.
    """
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    # Reuse the base components, then let caller overrides win.
    model_kwargs.update(load_base_model_components(base_model=base_model, torch_dtype=torch_dtype))
    model_kwargs.update(kwargs)

    if controlnet is not None:
        # A tuple of locations is treated like a list rather than being
        # passed whole to the single-model loader.
        if isinstance(controlnet, (list, tuple)):
            controlnet = [load_controlnet(path, torch_dtype=torch_dtype) for path in controlnet]
        else:
            controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype)
        model_kwargs.update(controlnet=controlnet)

    if pipeline_class is None:
        pipeline_class = StableDiffusionPipeline if controlnet is None else StableDiffusionControlNetPipeline

    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        pipeline_class,
        **model_kwargs
    )

    if ip_adapter:
        pipe.image_encoder = load_image_encoder()
        weight_name = "ip-adapter-plus_sd15.safetensors" if plus_model else "ip-adapter_sd15.safetensors"
        pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=weight_name)
        pipe.set_ip_adapter_scale(1.0)
    else:
        pipe.unload_ip_adapter()

    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    if model_cpu_offload_seq is not None:
        pipe.model_cpu_offload_seq = model_cpu_offload_seq
    elif isinstance(pipe, StableDiffusionControlNetPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
    elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"

    # NOTE(review): the two offload modes are treated as mutually exclusive
    # alternatives here -- confirm this matches the intended behavior.
    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()
    if vae_slicing:
        pipe.enable_vae_slicing()

    # Collect garbage left over from loading before handing the pipeline back.
    import gc
    gc.collect()
    return pipe
| |
|
| |
|