| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| | from typing import List, Optional, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from mmengine.model import BaseModule |
| | from mmengine.runner import CheckpointLoader, load_checkpoint |
| |
|
| | from mmseg.registry import MODELS |
| | from mmseg.utils import ConfigType, OptConfigType |
| |
|
| | try: |
| | from ldm.modules.diffusionmodules.util import timestep_embedding |
| | from ldm.util import instantiate_from_config |
| | has_ldm = True |
| | except ImportError: |
| | has_ldm = False |
| |
|
| |
|
def register_attention_control(model, controller):
    """Register ``controller`` as the attention hook of every CrossAttention
    layer inside ``model.diffusion_model``.

    Each CrossAttention layer's ``forward`` is replaced by a wrapper that
    computes attention exactly as before, but additionally reports the
    head-averaged attention map to ``controller``. After registration,
    ``controller.num_att_layers`` holds the number of patched layers.

    Args:
        model: The model to which attention is to be registered. Must expose
            ``diffusion_model`` with ``input_blocks`` / ``middle_block`` /
            ``output_blocks`` children (a Stable-Diffusion style UNet).
        controller: The control function responsible for managing attention,
            called as ``controller(attn, is_cross, place_in_unet)``.
    """

    def ca_forward(self, place_in_unet):
        """Build a replacement forward method for one CrossAttention layer.

        Args:
            self: The CrossAttention module being patched.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The modified forward method.
        """

        def forward(x, context=None, mask=None):
            h = self.heads
            is_cross = context is not None
            # BUG FIX: ``context or x`` invoked ``bool()`` on the context
            # tensor, which raises for multi-element tensors. Test for None
            # explicitly instead.
            context = x if context is None else context

            q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
            # Split the channel dim across heads: (b, n, d) -> (b*h, n, d/h).
            q, k, v = (
                tensor.view(tensor.shape[0] * h, tensor.shape[1],
                            tensor.shape[2] // h) for tensor in [q, k, v])

            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if mask is not None:
                mask = mask.flatten(1).unsqueeze(1).repeat(h, 1, 1)
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # Average over heads before reporting to the controller.
            attn_mean = attn.view(h, attn.shape[0] // h,
                                  *attn.shape[1:]).mean(0)
            controller(attn_mean, is_cross, place_in_unet)

            out = torch.matmul(attn, v)
            out = out.view(out.shape[0] // h, out.shape[1], out.shape[2] * h)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Recursive function to register the custom forward method to all
        CrossAttention layers.

        Args:
            net_: The network layer currently being processed.
            count: The current count of layers processed.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The updated count of layers processed.
        """
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            return sum(
                register_recr(child, 0, place_in_unet)
                for child in net_.children())
        return count

    # Map each top-level UNet child to its place and register recursively.
    # Match priority (input, then output, then middle) mirrors the original.
    place_by_name = [('input_blocks', 'down'), ('output_blocks', 'up'),
                     ('middle_block', 'mid')]
    cross_att_count = 0
    for name, child in model.diffusion_model.named_children():
        for key, place in place_by_name:
            if key in name:
                # BUG FIX: pass the child itself; the previous ``net[1]``
                # indexed *into* the child module, so for a ModuleList only
                # its second element was patched (TypeError for others).
                cross_att_count += register_recr(child, 0, place)
                break

    controller.num_att_layers = cross_att_count
| |
|
| |
|
class AttentionStore:
    """A class for storing attention information in the UNet model.

    Attention maps reported by the patched CrossAttention layers are
    accumulated per step in ``step_store`` and merged into
    ``attention_store`` by :meth:`between_steps`.

    Attributes:
        base_size (int): Base size for storing attention information.
        max_size (int): Maximum size for storing attention information.
    """

    def __init__(self, base_size=64, max_size=None):
        """Initialize AttentionStore with default or custom sizes."""
        self.reset()
        self.base_size = base_size
        # A falsy max_size (None/0) falls back to half the base size.
        self.max_size = max_size if max_size else base_size // 2
        self.num_att_layers = -1

    @staticmethod
    def get_empty_store():
        """Returns an empty store for holding attention values."""
        empty = {}
        for kind in ('cross', 'self'):
            for place in ('down', 'mid', 'up'):
                empty[f'{place}_{kind}'] = []
        return empty

    def reset(self):
        """Resets the step and attention stores to their initial states."""
        self.cur_step = 0
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Processes a single forward step, storing the attention.

        Args:
            attn: The attention tensor.
            is_cross (bool): Whether it's cross attention.
            place_in_unet (str): The location in UNet (down/mid/up).

        Returns:
            The unmodified attention tensor.
        """
        suffix = 'cross' if is_cross else 'self'
        key = f'{place_in_unet}_{suffix}'
        # Only keep maps whose token count fits within max_size**2;
        # larger maps are dropped to bound memory use.
        if attn.shape[1] <= self.max_size**2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Processes and stores attention information between steps."""
        if self.attention_store:
            for key, kept in self.attention_store.items():
                merged = [
                    prev + cur
                    for prev, cur in zip(kept, self.step_store[key])
                ]
                self.attention_store[key] = merged
        else:
            # First step: adopt the step store wholesale.
            self.attention_store = self.step_store
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Calculates and returns the average attention across all steps."""
        return {
            key: list(values)
            for key, values in self.step_store.items()
        }

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Allows the class instance to be callable."""
        return self.forward(attn, is_cross, place_in_unet)

    @property
    def num_uncond_att_layers(self):
        """Returns the number of unconditional attention layers (default is
        0)."""
        return 0

    def step_callback(self, x_t):
        """A placeholder for a step callback.

        Returns the input unchanged.
        """
        return x_t
| |
|
| |
|
class UNetWrapper(nn.Module):
    """A wrapper for UNet with optional attention mechanisms.

    When ``use_attn`` is True, the wrapped UNet's CrossAttention layers are
    patched (via :func:`register_attention_control`) so their attention maps
    are collected in an :class:`AttentionStore` during each forward pass and
    concatenated onto the returned multi-scale features.

    Args:
        unet (nn.Module): The UNet model to wrap
        use_attn (bool): Whether to use attention. Defaults to True
        base_size (int): Base size for the attention store. Defaults to 512
        max_attn_size (int, optional): Maximum size for the attention store.
            Defaults to None
        attn_selector (str): The types of attention to use.
            Defaults to 'up_cross+down_cross'
    """

    def __init__(self,
                 unet,
                 use_attn=True,
                 base_size=512,
                 max_attn_size=None,
                 attn_selector='up_cross+down_cross'):
        super().__init__()

        # ``timestep_embedding`` below comes from the optional ldm package.
        assert has_ldm, 'To use UNetWrapper, please install required ' \
            'packages via `pip install -r requirements/optional.txt`.'

        self.unet = unet
        # Attention operates in latent space, i.e. at 1/8 of the image size.
        self.attention_store = AttentionStore(
            base_size=base_size // 8, max_size=max_attn_size)
        # e.g. 'up_cross+down_cross' -> ['up_cross', 'down_cross'].
        self.attn_selector = attn_selector.split('+')
        self.use_attn = use_attn
        self.init_sizes(base_size)
        if self.use_attn:
            register_attention_control(unet, self.attention_store)

    def init_sizes(self, base_size):
        """Initialize sizes based on the base size."""
        # Spatial sizes of the collected features at 1/32, 1/16 and 1/8.
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Forward pass through the model.

        Args:
            x: Latent input tensor fed to the diffusion UNet.
            timesteps: Diffusion timesteps for the time embedding.
            context: Cross-attention context (e.g. text embeddings).
            y: Accepted for interface compatibility; unused by
                ``_unet_forward``.

        Returns:
            The collected features in reversed order (finest scale first);
            when ``use_attn`` is on, attention maps are concatenated along
            the channel dim of the matching scales first.
        """
        diffusion_model = self.unet.diffusion_model
        if self.use_attn:
            # Drop attention collected during any previous forward pass.
            self.attention_store.reset()
        hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
                                               diffusion_model)
        if self.use_attn:
            self._append_attn_to_output(out_list)
        return out_list[::-1]

    def _unet_forward(self, x, timesteps, context, y, diffusion_model):
        """Run the UNet, collecting intermediate output-block features.

        Returns:
            Tuple ``(hs, emb, out_list)``: remaining skip features, the
            time embedding, and the collected features (coarse to fine,
            final UNet output last).
        """
        hs = []
        t_emb = timestep_embedding(
            timesteps, diffusion_model.model_channels, repeat_only=False)
        emb = diffusion_model.time_embed(t_emb)
        h = x.type(diffusion_model.dtype)
        for module in diffusion_model.input_blocks:
            h = module(h, emb, context)
            # Save encoder features as skip connections.
            hs.append(h)
        h = diffusion_model.middle_block(h, emb, context)
        out_list = []
        for i_out, module in enumerate(diffusion_model.output_blocks):
            # Consume skip connections in reverse (LIFO) order.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                # NOTE(review): indices 1/4/7 presumably mark the scale
                # transitions of a Stable-Diffusion UNet — confirm against
                # the configured diffusion model.
                out_list.append(h)
        h = h.type(x.dtype)
        out_list.append(h)
        return hs, emb, out_list

    def _append_attn_to_output(self, out_list):
        """Concatenate stored attention maps onto the output features.

        Groups the selected attention maps by spatial size (16/32/64),
        averages each group, and appends them along the channel dim of
        ``out_list[1]`` / ``out_list[2]`` (and ``out_list[3]`` when
        64-sized maps exist).
        """
        avg_attn = self.attention_store.get_average_attention()
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                # (b, n_q, n_k) -> (b, n_k, size, size), size**2 == n_q.
                size = int(math.sqrt(up_attn.shape[1]))
                up_attn = up_attn.transpose(-1, -2).reshape(
                    *up_attn.shape[:2], size, -1)
                attns[size].append(up_attn)
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        # 64-sized maps may be absent: the store filters by max_size.
        attn64 = torch.stack(attns[self.size64]).mean(0) if len(
            attns[self.size64]) > 0 else None
        out_list[1] = torch.cat([out_list[1], attn16], dim=1)
        out_list[2] = torch.cat([out_list[2], attn32], dim=1)
        if attn64 is not None:
            out_list[3] = torch.cat([out_list[3], attn64], dim=1)
| |
|
| |
|
class TextAdapter(nn.Module):
    """A PyTorch Module that serves as a text adapter.

    This module takes text embeddings and adjusts them based on a scaling
    factor gamma: ``out = texts + gamma * MLP(texts)``.
    """

    def __init__(self, text_dim=768):
        super().__init__()
        # Two-layer residual MLP with GELU in between.
        layers = [
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, texts, gamma):
        """Return ``texts`` plus ``gamma`` times the adapted embeddings."""
        residual = self.fc(texts)
        return texts + gamma * residual
| |
|
| |
|
@MODELS.register_module()
class VPD(BaseModule):
    """VPD (Visual Perception Diffusion) model.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        diffusion_cfg (dict): Configuration for diffusion model. The
            optional ``checkpoint`` key is popped (the dict is mutated) and
            used to load pretrained diffusion weights.
        class_embed_path (str): Path for class embeddings.
        unet_cfg (dict, optional): Configuration for U-Net.
        gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
        class_embed_select (bool, optional): If True, enables class embedding
            selection. Defaults to False.
        pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
            Defaults to None.
        pad_val (Union[int, List[int]], optional): Padding value.
            Defaults to 0.
        init_cfg (dict, optional): Configuration for network initialization.
    """

    def __init__(self,
                 diffusion_cfg: ConfigType,
                 class_embed_path: str,
                 unet_cfg: OptConfigType = None,
                 gamma: float = 1e-4,
                 class_embed_select=False,
                 pad_shape: Optional[Union[int, List[int]]] = None,
                 pad_val: Union[int, List[int]] = 0,
                 init_cfg: OptConfigType = None):

        super().__init__(init_cfg=init_cfg)

        assert has_ldm, 'To use VPD model, please install required packages' \
            ' via `pip install -r requirements/optional.txt`.'

        # FIX: the previous default ``unet_cfg=dict()`` was a mutable
        # default argument shared across instances.
        unet_cfg = dict() if unet_cfg is None else unet_cfg

        # Normalize a scalar pad_shape to (h, w).
        if pad_shape is not None:
            if not isinstance(pad_shape, (list, tuple)):
                pad_shape = (pad_shape, pad_shape)
        self.pad_shape = pad_shape
        self.pad_val = pad_val

        # Build the stable-diffusion backbone and optionally load weights.
        # NOTE: ``pop`` mutates the caller's diffusion_cfg.
        diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
        sd_model = instantiate_from_config(diffusion_cfg)
        if diffusion_checkpoint is not None:
            load_checkpoint(sd_model, diffusion_checkpoint, strict=False)

        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, **unet_cfg)

        # Class embeddings used as the cross-attention context.
        class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
        text_dim = class_embeddings.size(-1)
        self.text_adapter = TextAdapter(text_dim=text_dim)
        self.class_embed_select = class_embed_select
        if class_embed_select:
            # Append the mean embedding; index -1 selects it when no class
            # id is provided in forward().
            # (``keepdim`` is the canonical torch spelling, valid on all
            # torch versions; ``keepdims`` is only a numpy-compat alias.)
            class_embeddings = torch.cat(
                (class_embeddings, class_embeddings.mean(dim=0,
                                                         keepdim=True)),
                dim=0)
        self.register_buffer('class_embeddings', class_embeddings)
        # Learnable per-channel scale for the text adapter residual.
        self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)

    def forward(self, x):
        """Extract features from images.

        Args:
            x: Image tensor of shape (B, C, H, W), or, when
                ``class_embed_select`` is on, a (images, class_ids) pair.

        Returns:
            Multi-scale features produced by the wrapped UNet.
        """
        # Resolve the cross-attention context from the class embeddings.
        if self.class_embed_select:
            if isinstance(x, (tuple, list)):
                x, class_ids = x[:2]
                class_ids = class_ids.tolist()
            else:
                # -1 picks the appended mean embedding (no class given).
                class_ids = [-1] * x.size(0)
            class_embeddings = self.class_embeddings[class_ids]
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(1)
        else:
            class_embeddings = self.class_embeddings
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            # Share the full embedding table across the batch.
            c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)

        # Pad right/bottom up to the configured shape if necessary.
        if self.pad_shape is not None:
            pad_width = max(0, self.pad_shape[1] - x.shape[-1])
            pad_height = max(0, self.pad_shape[0] - x.shape[-2])
            x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)

        # Only the frozen VQ encoder runs without gradients; the UNet call
        # stays differentiable so it can be fine-tuned.
        with torch.no_grad():
            latents = self.encoder_vq.encode(x).mode().detach()
        t = torch.ones((x.shape[0], ), device=x.device).long()
        outs = self.unet(latents, t, context=c_crossattn)

        return outs
| |
|