Spaces:
Running on Zero
Running on Zero
| from calendar import c | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from refnet.util import exists, default, checkpoint_wrapper | |
| from .layers import RMSNorm | |
| from .attn_utils import * | |
| def create_masked_attention_bias( | |
| mask: torch.Tensor, | |
| threshold: float, | |
| num_heads: int, | |
| context_len: int | |
| ): | |
| b, seq_len, _ = mask.shape | |
| half_len = context_len // 2 | |
| if context_len % 8 != 0: | |
| padded_context_len = ((context_len + 7) // 8) * 8 | |
| else: | |
| padded_context_len = context_len | |
| fg_bias = torch.zeros(b, seq_len, padded_context_len, device=mask.device, dtype=mask.dtype) | |
| bg_bias = torch.zeros(b, seq_len, padded_context_len, device=mask.device, dtype=mask.dtype) | |
| fg_bias[:, :, half_len:] = -float('inf') | |
| bg_bias[:, :, :half_len] = -float('inf') | |
| attn_bias = torch.where(mask > threshold, fg_bias, bg_bias) | |
| return attn_bias.unsqueeze(1).repeat_interleave(num_heads, dim=1) | |
| class Identity(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| def forward(self, x, *args, **kwargs): | |
| return x | |
| # Rotary Positional Embeddings implementation | |
| class RotaryPositionalEmbeddings(nn.Module): | |
| def __init__(self, dim, max_seq_len=1024, theta=10000.0): | |
| super().__init__() | |
| assert dim % 2 == 0, "Dimension must be divisible by 2" | |
| dim = dim // 2 | |
| self.max_seq_len = max_seq_len | |
| freqs = torch.outer( | |
| torch.arange(self.max_seq_len), | |
| 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)) | |
| ) | |
| freqs = torch.polar(torch.ones_like(freqs), freqs) | |
| self.register_buffer("freq_h", freqs, persistent=False) | |
| self.register_buffer("freq_w", freqs, persistent=False) | |
| def forward(self, x, grid_size): | |
| bs, seq_len, heads = x.shape[:3] | |
| h, w = grid_size | |
| x_complex = torch.view_as_complex( | |
| x.float().reshape(bs, seq_len, heads, -1, 2) | |
| ) | |
| freqs = torch.cat([ | |
| self.freq_h[:h].view(1, h, 1, -1).expand(bs, h, w, -1), | |
| self.freq_w[:w].view(1, 1, w, -1).expand(bs, h, w, -1) | |
| ], dim=-1).reshape(bs, seq_len, 1, -1) | |
| x_out = x_complex * freqs | |
| x_out = torch.view_as_real(x_out).flatten(3) | |
| return x_out.type_as(x) | |
| class MemoryEfficientAttention(nn.Module): | |
| # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 | |
| def __init__( | |
| self, | |
| query_dim, | |
| context_dim = None, | |
| heads = None, | |
| dim_head = 64, | |
| dropout = 0.0, | |
| log = False, | |
| causal = False, | |
| rope = False, | |
| max_seq_len = 1024, | |
| qk_norm = False, | |
| **kwargs | |
| ): | |
| super().__init__() | |
| if log: | |
| print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " | |
| f"{heads} heads.") | |
| heads = heads or query_dim // dim_head | |
| inner_dim = dim_head * heads | |
| context_dim = default(context_dim, query_dim) | |
| self.heads = heads | |
| self.dim_head = dim_head | |
| self.to_q = nn.Linear(query_dim, inner_dim, bias=False) | |
| self.to_k = nn.Linear(context_dim, inner_dim, bias=False) | |
| self.to_v = nn.Linear(context_dim, inner_dim, bias=False) | |
| self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) | |
| self.q_norm = RMSNorm(inner_dim) if qk_norm else Identity() | |
| self.k_norm = RMSNorm(inner_dim) if qk_norm else Identity() | |
| self.rope = RotaryPositionalEmbeddings(dim_head, max_seq_len=max_seq_len) if rope else Identity() | |
| self.attn_ops = causal_ops if causal else {} | |
| # default setting for split cross-attention | |
| self.bg_scale = 1. | |
| self.fg_scale = 1. | |
| self.merge_scale = 0. | |
| self.mask_threshold = 0.05 | |
| def forward( | |
| self, | |
| x, | |
| context=None, | |
| mask=None, | |
| scale=1., | |
| scale_factor=None, | |
| grid_size=None, | |
| **kwargs, | |
| ): | |
| context = default(context, x) | |
| if exists(mask): | |
| out = self.masked_forward(x, context, mask, scale, scale_factor) | |
| else: | |
| q = self.to_q(x) | |
| k = self.to_k(context) | |
| v = self.to_v(context) | |
| out = self.attn_forward(q, k, v, scale, grid_size) | |
| return self.to_out(out) | |
| def attn_forward(self, q, k, v, scale=1., grid_size=None, mask=None): | |
| q, k = map( | |
| lambda t: | |
| self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size), | |
| (self.q_norm(q), self.k_norm(k)) | |
| ) | |
| v = rearrange(v, "b n (h c) -> b n h c", h=self.heads) | |
| out = attn_processor(q, k, v, attn_mask=mask, **self.attn_ops) * scale | |
| out = rearrange(out, "b n h c -> b n (h c)") | |
| return out | |
| def masked_forward(self, x, context, mask, scale=1., scale_factor=None): | |
| # split cross-attention function | |
| def qkv_forward(x, context): | |
| q = self.to_q(x) | |
| k = self.to_k(context) | |
| v = self.to_v(context) | |
| return q, k, v | |
| assert exists(scale_factor), "Scale factor must be assigned before masked attention" | |
| mask = rearrange( | |
| F.interpolate(mask, scale_factor=scale_factor, mode="bicubic"), | |
| "b c h w -> b (h w) c" | |
| ).contiguous() | |
| if self.merge_scale > 0: | |
| # split cross-attention with merging scale, need two times forward | |
| c1, c2 = context.chunk(2, dim=1) | |
| # Background region cross-attention | |
| q2, k2, v2 = qkv_forward(x, c2) | |
| bg_out = self.attn_forward(q2, k2, v2, scale) * self.bg_scale | |
| # Foreground region cross-attention | |
| q1, k1, v1 = qkv_forward(x, c1) | |
| fg_out = self.attn_forward(q1, k1, v1, scale) * self.fg_scale | |
| fg_out = fg_out * (1 - self.merge_scale) + bg_out * self.merge_scale | |
| return torch.where(mask < self.mask_threshold, bg_out, fg_out) | |
| else: | |
| attn_mask = create_masked_attention_bias( | |
| mask, self.mask_threshold, self.heads, context.size(1) | |
| ) | |
| q, k, v = qkv_forward(x, context) | |
| return self.attn_forward(q, k, v, mask=attn_mask) * scale | |
| class MultiModalAttention(MemoryEfficientAttention): | |
| def __init__(self, query_dim, context_dim_2, heads=8, dim_head=64, qk_norm=False, *args, **kwargs): | |
| super().__init__(query_dim, heads=heads, dim_head=dim_head, qk_norm=qk_norm, *args, **kwargs) | |
| inner_dim = dim_head * heads | |
| self.to_k_2 = nn.Linear(context_dim_2, inner_dim, bias=False) | |
| self.to_v_2 = nn.Linear(context_dim_2, inner_dim, bias=False) | |
| self.k2_norm = RMSNorm(inner_dim) if qk_norm else Identity() | |
| def forward(self, x, context=None, mask=None, scale=1., grid_size=None): | |
| if not isinstance(scale, list) and not isinstance(scale, tuple): | |
| scale = (scale, scale) | |
| assert len(context.shape) == 4, "Multi-modal attention requires different context inputs to be (b, m, n c)" | |
| context, context2 = context.chunk(2, dim=1) | |
| q = self.to_q(x) | |
| k = self.to_k(context) | |
| v = self.to_v(context) | |
| k2 = self.to_k_2(context2) | |
| v2 = self.to_k_2(context2) | |
| b, _, _ = q.shape | |
| q, k, k2 = map( | |
| lambda t: self.rope(rearrange(t, "b n (h c) -> b n h c", h=self.heads), grid_size), | |
| (self.q_norm(q), self.k_norm(k), self.k2_norm(k2)) | |
| ) | |
| v, v2 = map(lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads), (v, v2)) | |
| out = (attn_processor(q, k, v, **self.attn_ops) * scale[0] + | |
| attn_processor(q, k2, v2, **self.attn_ops) * scale[1]) | |
| if exists(mask): | |
| raise NotImplementedError | |
| out = rearrange(out, "b n h c -> b n (h c)") | |
| return self.to_out(out) | |
| class MultiScaleCausalAttention(MemoryEfficientAttention): | |
| def forward( | |
| self, | |
| x, | |
| context=None, | |
| mask=None, | |
| scale=1., | |
| scale_factor=None, | |
| grid_size=None, | |
| token_lens=None | |
| ): | |
| context = default(context, x) | |
| q = self.to_q(x) | |
| k = self.to_k(context) | |
| v = self.to_v(context) | |
| out = self.attn_forward(q, k, v, scale, grid_size=grid_size, token_lens=token_lens) | |
| return self.to_out(out) | |
| def attn_forward(self, q, k, v, scale = 1., grid_size = None, token_lens = None): | |
| q, k, v = map( | |
| lambda t: rearrange(t, "b n (h c) -> b n h c", h=self.heads), | |
| (self.q_norm(q), self.k_norm(k), v) | |
| ) | |
| attn_output = [] | |
| prev_idx = 0 | |
| for idx, (grid, length) in enumerate(zip(grid_size, token_lens)): | |
| end_idx = prev_idx + length + (idx == 0) | |
| rope_prev_idx = prev_idx + (idx == 0) | |
| rope_slice = slice(rope_prev_idx, end_idx) | |
| q[:, rope_slice] = self.rope(q[:, rope_slice], grid) | |
| k[:, rope_slice] = self.rope(k[:, rope_slice], grid) | |
| qs = q[:, prev_idx: end_idx] | |
| ks, vs = map(lambda t: t[:, :end_idx], (k, v)) | |
| attn_output.append(attn_processor(qs.clone(), ks.clone(), vs.clone()) * scale) | |
| prev_idx = end_idx | |
| attn_output = rearrange(torch.cat(attn_output, 1), "b n h c -> b n (h c)") | |
| return attn_output | |
| # if FLASH_ATTN_3_AVAILABLE or FLASH_ATTN_AVAILABLE: | |
| # k_chunks = [] | |
| # v_chunks = [] | |
| # kv_token_lens = [] | |
| # prev_idx = 0 | |
| # for idx, (grid, length) in enumerate(zip(grid_size, token_lens)): | |
| # end_idx = prev_idx + length + (idx == 0) | |
| # rope_prev_idx = prev_idx + (idx == 0) | |
| # rope_slice = slice(rope_prev_idx, end_idx) | |
| # q[:, rope_slice], k[:, rope_slice], v[:, rope_slice] = map( | |
| # lambda t: self.rope(t[:, rope_slice], grid), | |
| # (q, k, v) | |
| # ) | |
| # kv_token_lens.append(end_idx+1) | |
| # k_chunks.append(k[:, :end_idx]) | |
| # v_chunks.append(v[:, :end_idx]) | |
| # prev_idx = end_idx | |
| # k = torch.cat(k_chunks, 1) | |
| # v = torch.cat(v_chunks, 1) | |
| # B, N, H, C = q.shape | |
| # token_lens = torch.tensor(token_lens, device=q.device, dtype=torch.int32) | |
| # kv_token_lens = torch.tensor(kv_token_lens, device=q.device, dtype=torch.int32) | |
| # token_lens[0] = token_lens[0] + 1 | |
| # | |
| # cu_seqlens_q, cu_seqlens_kv = map(lambda t: | |
| # torch.cat([t.new_zeros([1]), t]).cumsum(0, dtype=torch.int32), | |
| # (token_lens, kv_token_lens) | |
| # ) | |
| # max_seqlen_q, max_seqlen_kv = map(lambda t: int(t.max()), (token_lens, kv_token_lens)) | |
| # | |
| # q_flat = q.reshape(-1, H, C).contiguous() | |
| # k_flat = k.reshape(-1, H, C).contiguous() | |
| # v_flat = v.reshape(-1, H, C).contiguous() | |
| # out_flat = flash_attn_varlen_func( | |
| # q=q_flat, k=k_flat, v=v_flat, | |
| # cu_seqlens_q=cu_seqlens_q, | |
| # cu_seqlens_k=cu_seqlens_kv, | |
| # max_seqlen_q=max_seqlen_q, | |
| # max_seqlen_k=max_seqlen_kv, | |
| # causal=True, | |
| # ) | |
| # | |
| # out = rearrange(out_flat, "(b n) h c -> b n (h c)", b=B, n=N) | |
| # return out * scale | |