from typing import Optional, Union

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ..initializer import normal_, zeros_
|
|
|
|
class CrossAttention(nn.Layer):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to `False`):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """
|
    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        self.scale = dim_head**-0.5
        self.num_heads = heads
        self.head_dim = inner_dim // heads

        # for slice_size > 0, the attention score computation is split across the
        # batch axis to save memory; configure it with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, epsilon=1e-5)
        else:
            self.group_norm = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

        self.to_out = nn.LayerList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

        # default to the standard (non-sliced, non-LoRA) attention processor
        processor = processor if processor is not None else CrossAttnProcessor()
        self.set_processor(processor)
|
|
    def set_attention_slice(self, slice_size):
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller than or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = CrossAttnAddedKVProcessor()
        else:
            processor = CrossAttnProcessor()

        self.set_processor(processor)
|
|
    def set_processor(self, processor: "AttnProcessor"):
        self.processor = processor
|
|
    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # `CrossAttention` delegates the actual computation to the configured attention
        # processor; any extra `cross_attention_kwargs` are passed through unchanged
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
|
|
    def batch_to_head_dim(self, tensor):
        # [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads * head_dim]
        tensor = tensor.transpose([0, 2, 1, 3])
        tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
        return tensor

    def head_to_batch_dim(self, tensor):
        # [batch, seq_len, heads * head_dim] -> [batch, heads, seq_len, head_dim]
        tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
        tensor = tensor.transpose([0, 2, 1, 3])
        return tensor
|
|
    def get_attention_scores(self, query, key, attention_mask=None):
        dtype = query.dtype
        if self.upcast_attention:
            query = query.cast("float32")
            key = key.cast("float32")

        attention_scores = paddle.matmul(query, key, transpose_y=True) * self.scale

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        if self.upcast_softmax:
            attention_scores = attention_scores.cast("float32")

        attention_probs = F.softmax(attention_scores, axis=-1)
        # cast back to the input dtype in case the scores were upcast to float32
        attention_probs = attention_probs.cast(dtype)

        return attention_probs
|
|
    def prepare_attention_mask(self, attention_mask, target_length):
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0, data_format="NCL")
        # one copy of the mask per attention head
        attention_mask = attention_mask.repeat_interleave(self.num_heads, axis=0)
        return attention_mask
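

# Usage sketch (illustrative only; the layer sizes, batch size, and sequence lengths
# below are assumptions, not values taken from this module):
#
#     attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
#     hidden_states = paddle.randn([2, 4096, 320])        # e.g. image tokens
#     encoder_hidden_states = paddle.randn([2, 77, 768])  # e.g. text tokens
#     out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
#     # out.shape == [2, 4096, 320]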
|
|
|
|
class CrossAttnProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
|
|
|
|
class LoRALinearLayer(nn.Layer):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias_attr=False)
        self.up = nn.Linear(rank, out_features, bias_attr=False)
        self.scale = 1.0

        # the up projection starts at zero, so the LoRA branch is a no-op until trained
        normal_(self.down.weight, std=1 / rank)
        zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.cast(dtype))
        up_hidden_states = self.up(down_hidden_states)

        return up_hidden_states.cast(orig_dtype)
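

# The LoRA branch computes up(down(x)), a rank-`rank` update that the processors below
# add to the frozen base projections as `base(x) + scale * lora(x)`. Quick sanity check
# (a sketch; the sizes are arbitrary assumptions):
#
#     lora = LoRALinearLayer(320, 320, rank=4)
#     x = paddle.randn([2, 77, 320])
#     assert lora(x).shape == [2, 77, 320]  # all zeros before training, since `up` is zero-initialized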
|
|
|
|
class LoRACrossAttnProcessor(nn.Layer):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
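

# To fine-tune with LoRA, swap the default processor for the LoRA-aware one. A minimal
# sketch (`attn` and the sizes are assumptions, not defined in this module):
#
#     attn.set_processor(LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4))
#     out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, scale=1.0)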
|
|
|
|
class CrossAttnAddedKVProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
            [0, 2, 1]
        )
        batch_size, sequence_length, _ = hidden_states.shape
        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
        value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
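

# The "added KV" processors self-attend over the spatial tokens and additionally attend
# to extra key/value projections of `encoder_hidden_states` (computed by `add_k_proj` /
# `add_v_proj` and concatenated along the sequence axis). This path is only wired up when
# `CrossAttention` is constructed with `added_kv_proj_dim` (and, here, `norm_num_groups`,
# since `group_norm` is applied unconditionally).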
|
|
|
|
class SlicedAttnProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # merge the batch and head axes so the attention can be computed slice by slice
        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)

        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # split the batch and head axes back apart
        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
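

# Sliced attention trades speed for memory: the merged (batch * heads) axis is processed
# in chunks of `slice_size`, so only one chunk's attention matrix is materialized at a
# time. A sketch of enabling it (the size 4 is an arbitrary assumption):
#
#     attn.set_attention_slice(4)  # must be <= attn.sliceable_head_dim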
|
|
|
|
class SlicedAttnAddedKVProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
            [0, 2, 1]
        )
        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
        value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)

        # merge the batch and head axes so the attention can be computed slice by slice
        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)
        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # split the batch and head axes back apart
        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
|
|
|
|
AttnProcessor = Union[
    CrossAttnProcessor,
    SlicedAttnProcessor,
    CrossAttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    LoRACrossAttnProcessor,
]
|
|