File size: 5,855 Bytes
4189926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import logging
from PIL import Image
from typing import Any
from diffusers import (
    ControlNetModel,
    AutoPipelineForText2Image,
    LCMScheduler,
    AutoencoderKL,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
)
from backend.models.lcmdiffusion_setting import (
    DiffusionTask,
    ControlNetSetting,
)


# Prepares ControlNet adapters for use with FastSD CPU
#
# This function loads the ControlNet adapters defined by the
# _lcm_diffusion_setting.controlnet_ object and returns a dictionary
# with the pipeline arguments required to use the loaded adapters
def load_controlnet_adapters(lcm_diffusion_setting) -> dict:
    controlnet_args = {}
    if (
        lcm_diffusion_setting.controlnet is None
        or not lcm_diffusion_setting.controlnet.enabled
    ):
        return controlnet_args

    logging.info("Loading ControlNet adapter")
    controlnet_adapter = ControlNetModel.from_single_file(
        lcm_diffusion_setting.controlnet.adapter_path,
        # local_files_only=True,
        use_safetensors=True,
    )
    controlnet_args["controlnet"] = controlnet_adapter
    return controlnet_args


# Updates the ControlNet pipeline arguments to use for image generation
#
# This function uses the contents of the _lcm_diffusion_setting.controlnet_
# object to generate a dictionary with the corresponding pipeline arguments
# to be used for image generation; in particular, it sets the ControlNet control
# image and conditioning scale
def update_controlnet_arguments(lcm_diffusion_setting) -> dict:
    controlnet_args = {}
    if (
        lcm_diffusion_setting.controlnet is None
        or not lcm_diffusion_setting.controlnet.enabled
    ):
        return controlnet_args

    controlnet_args["controlnet_conditioning_scale"] = (
        lcm_diffusion_setting.controlnet.conditioning_scale
    )
    if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
        controlnet_args["image"] = lcm_diffusion_setting.controlnet._control_image
    elif lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
        controlnet_args["control_image"] = (
            lcm_diffusion_setting.controlnet._control_image
        )
    return controlnet_args


# Helper function to adjust ControlNet settings from a dictionary
def controlnet_settings_from_dict(
    lcm_diffusion_setting,
    dictionary,
) -> None:
    if lcm_diffusion_setting is None or dictionary is None:
        logging.error("Invalid arguments!")
        return
    if (
        "controlnet" not in dictionary
        or dictionary["controlnet"] is None
        or len(dictionary["controlnet"]) == 0
    ):
        logging.warning("ControlNet settings not found, ControlNet will be disabled")
        lcm_diffusion_setting.controlnet = None
        return

    controlnet = ControlNetSetting()
    controlnet.enabled = dictionary["controlnet"][0]["enabled"]
    controlnet.conditioning_scale = dictionary["controlnet"][0]["conditioning_scale"]
    controlnet.adapter_path = dictionary["controlnet"][0]["adapter_path"]
    controlnet._control_image = None
    image_path = dictionary["controlnet"][0]["control_image"]
    if controlnet.enabled:
        try:
            controlnet._control_image = Image.open(image_path)
        except (AttributeError, FileNotFoundError) as err:
            print(err)
        if controlnet._control_image is None:
            logging.error("Wrong ControlNet control image! Disabling ControlNet")
            controlnet.enabled = False
    lcm_diffusion_setting.controlnet = controlnet


def get_controlnet_pipeline(
    pipeline: Any, lcm_diffusion_setting, diffusion_task: DiffusionTask
) -> Any:
    """Creates a ControlNet pipeline from the base txt2img _pipeline_"""
    if (
        lcm_diffusion_setting.controlnet is None
        or not lcm_diffusion_setting.controlnet.enabled
    ):
        return None
    components = pipeline.components
    pipeline_class = pipeline.__class__.__name__
    controlnet_args = load_controlnet_adapters(lcm_diffusion_setting)
    if diffusion_task == DiffusionTask.text_to_image.value:
        if (
            pipeline_class == "LatentConsistencyModelPipeline"
            or pipeline_class == "StableDiffusionPipeline"
        ):
            controlnet_pipeline = StableDiffusionControlNetPipeline.from_pipe(
                pipeline,
                vae=None,
                **controlnet_args,
            )
            controlnet_pipeline.vae = pipeline.vae
            return controlnet_pipeline
        elif pipeline_class == "StableDiffusionXLPipeline":
            controlnet_pipeline = StableDiffusionXLControlNetPipeline.from_pipe(
                pipeline,
                vae=None,
                **controlnet_args,
            )
            controlnet_pipeline.vae = pipeline.vae
            return controlnet_pipeline
    elif diffusion_task == DiffusionTask.image_to_image.value:
        if (
            pipeline_class == "LatentConsistencyModelPipeline"
            or pipeline_class == "StableDiffusionPipeline"
        ):
            controlnet_pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pipe(
                pipeline,
                vae=None,
                **controlnet_args,
            )
            controlnet_pipeline.vae = pipeline.vae
            return controlnet_pipeline
        elif pipeline_class == "StableDiffusionXLPipeline":
            controlnet_pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pipe(
                pipeline,
                vae=None,
                **controlnet_args,
            )
            controlnet_pipeline.vae = pipeline.vae
            return controlnet_pipeline