import os
import json
import torch
from model.attn_processor import AttnProcessor2_0, SkipAttnProcessor


def init_adapter(unet,
                 cross_attn_cls=SkipAttnProcessor,
                 self_attn_cls=None,
                 cross_attn_dim=None,
                 **kwargs):
    """Replace the UNet's attention processors and return them as a ModuleList.

    Cross-attention ("attn2") processors are swapped for `cross_attn_cls`
    (default: `SkipAttnProcessor`); self-attention ("attn1") processors keep
    `AttnProcessor2_0` unless a `self_attn_cls` is provided.
    """
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" processors are self-attention and take no cross-attention dim.
        cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim
        # Infer the hidden size of the block that owns this attention layer.
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            # Self-attention: use the custom class if given, otherwise keep the default processor.
            if self_attn_cls is not None:
                attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
            else:
                attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
        else:
            # Cross-attention: swap in the requested processor class.
            attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)

    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    return adapter_modules
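
# Example (sketch): replace the cross-attention processors of a Stable Diffusion
# UNet and collect the resulting modules for training. The checkpoint path and
# the use of diffusers' UNet2DConditionModel are illustrative, not part of this module.
#
#   from diffusers import UNet2DConditionModel
#   unet = UNet2DConditionModel.from_pretrained(
#       "path/to/stable-diffusion-inpainting", subfolder="unet")
#   adapter_modules = init_adapter(unet, cross_attn_cls=SkipAttnProcessor)
#   print(sum(p.numel() for p in adapter_modules.parameters()))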


def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):
    """Load the text encoder, VAE, tokenizer, and (optionally) a custom UNet."""
    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel, CLIPTokenizer

    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder="tokenizer")
    try:
        # Build the UNet from its local config and weights so a custom `unet_class` can be used.
        unet_folder = os.path.join(diffusion_model_name_or_path, "unet")
        with open(os.path.join(unet_folder, "config.json"), "r") as f:
            unet_configs = json.load(f)
        unet = unet_class(**unet_configs)
        unet.load_state_dict(torch.load(os.path.join(unet_folder, "diffusion_pytorch_model.bin"), map_location="cpu"), strict=True)
    except Exception:
        unet = None
    return text_encoder, vae, tokenizer, unet
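
# Example (sketch): load the frozen components from a local diffusers-style
# checkpoint directory. The path and the UNet2DConditionModel class are
# illustrative; `unet` comes back as None if the unet/ folder cannot be loaded.
#
#   from diffusers import UNet2DConditionModel
#   text_encoder, vae, tokenizer, unet = init_diffusion_model(
#       "path/to/stable-diffusion-inpainting", unet_class=UNet2DConditionModel)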


def attn_of_unet(unet):
    """Collect every self-attention ("attn1") module in the UNet."""
    attn_blocks = torch.nn.ModuleList()
    for name, module in unet.named_modules():
        if "attn1" in name:
            attn_blocks.append(module)
    return attn_blocks
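
# Example (sketch): inspect the self-attention layers, e.g. to report their size.
#
#   attn_blocks = attn_of_unet(unet)
#   n_params = sum(p.numel() for p in attn_blocks.parameters())
#   print(f"self-attention parameters: {n_params / 1e6:.1f}M")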


def get_trainable_module(unet, trainable_module_name):
    """Select the part of the UNet to train: the full model, its transformer blocks, or its self-attention layers."""
    if trainable_module_name == "unet":
        return unet
    elif trainable_module_name == "transformer":
        # Gather every transformer ("attentions") sub-block from the down/mid/up blocks.
        trainable_modules = torch.nn.ModuleList()
        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
            if hasattr(blocks, "attentions"):
                trainable_modules.append(blocks.attentions)
            else:
                for block in blocks:
                    if hasattr(block, "attentions"):
                        trainable_modules.append(block.attentions)
        return trainable_modules
    elif trainable_module_name == "attention":
        # Gather only the self-attention ("attn1") modules.
        attn_blocks = torch.nn.ModuleList()
        for name, module in unet.named_modules():
            if "attn1" in name:
                attn_blocks.append(module)
        return attn_blocks
    else:
        raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")
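
# Example (sketch): train only the self-attention layers while keeping the rest
# of the UNet frozen. The optimizer choice and learning rate are illustrative.
#
#   unet.requires_grad_(False)
#   trainable = get_trainable_module(unet, "attention")
#   trainable.requires_grad_(True)
#   optimizer = torch.optim.AdamW(trainable.parameters(), lr=1e-5)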