| | import torch |
| | from einops import rearrange |
| | from torch import Tensor |
| | import math |
| | from torchvision.utils import save_image |
| | from torchvision.io import read_image |
| | from PIL import Image |
| | import torchvision.transforms as transforms |
| |
|
| |
|
| | def adaptive_attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, txt_shape: int, img_shape: int, cur_step:int, cur_block:int, info) -> Tensor: |
| | q, k = apply_rope(q, k, pe) |
| |
|
| | |
| | x = scaled_dot_product_attention(q, k, v, txt_shape, img_shape, cur_step, cur_block, info) |
| | x = rearrange(x, "B H L D -> B L (H D)") |
| |
|
| | return x |
| |
|
| |
|
| | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: |
| | q, k = apply_rope(q, k, pe) |
| |
|
| | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) |
| | x = rearrange(x, "B H L D -> B L (H D)") |
| |
|
| | return x |
| |
|
| |
|
| | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: |
| | assert dim % 2 == 0 |
| | scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim |
| | omega = 1.0 / (theta**scale) |
| | out = torch.einsum("...n,d->...nd", pos, omega) |
| | out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) |
| | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) |
| | return out.float() |
| |
|
| |
|
| | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: |
| | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) |
| | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) |
| | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] |
| | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] |
| | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) |
| |
|
| |
|
| | def auto_mask(load_list, mask_accumulator, thre, info, mask_num = 4): |
| |
|
| | mask_list = [] |
| | for img_path in load_list: |
| | load_mask_img = Image.open(img_path).convert('L') |
| | |
| | transform = transforms.PILToTensor() |
| | mask_tensor = transform(load_mask_img) |
| | mask_tensor = mask_tensor.to(device=mask_accumulator.device, dtype=mask_accumulator.dtype) |
| | mask_tensor /= 255.0 |
| | mask_list.append(mask_tensor) |
| |
|
| | |
| | mask_list.sort(key=lambda x: x.sum().item(), reverse=True) |
| | |
| | num_masks = len(mask_list) |
| | if num_masks > mask_num: |
| | |
| | start_block = info['attn_guidance'] |
| | end_block = info['attn_guidance'] + mask_num |
| | if end_block > num_masks - 1: |
| | selected_masks = mask_list[-mask_num: ] |
| | else: |
| | selected_masks = mask_list[start_block: end_block] |
| | else: |
| | selected_masks = mask_list |
| |
|
| | |
| | for mask in selected_masks: |
| | mask_accumulator += mask |
| |
|
| | mask_tensor = (mask_accumulator / len(selected_masks)).to(dtype=mask_accumulator.dtype) |
| | mask_tensor[mask_tensor >= thre] = 1 |
| | mask_tensor[mask_tensor < thre] = 0 |
| |
|
| | return mask_tensor |
| |
|
| |
|
| | |
| | def scaled_dot_product_attention(query, key, value, txt_shape, img_shape, cur_step, cur_block, info, |
| | token_index=2, layer=range(19), attn_mask=None, dropout_p=0.0, coefficient=10, tau=0.5, thre=0.3, |
| | is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: |
| | L, S = query.size(-2), key.size(-2) |
| | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale |
| | attn_bias = torch.zeros(L, S, dtype=query.dtype).cuda() |
| | if is_causal: |
| | assert attn_mask is None |
| | temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) |
| | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) |
| | attn_bias.to(query.dtype) |
| |
|
| | if attn_mask is not None: |
| | if attn_mask.dtype == torch.bool: |
| | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) |
| | else: |
| | attn_bias += attn_mask |
| |
|
| | if enable_gqa: |
| | key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) |
| | value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) |
| |
|
| | attn_weight = query @ key.transpose(-2, -1) * scale_factor |
| | attn_weight += attn_bias |
| | attn_weight = torch.softmax(attn_weight, dim=-1) |
| | attn_weight = torch.dropout(attn_weight, dropout_p, train=True) |
| |
|
| | if not info['inverse']: |
| | |
| | txt_img_cross = attn_weight[:, :, -img_shape:, :txt_shape] |
| | |
| | token_heatmap = txt_img_cross[:, :, :, token_index] |
| | token_heatmap = token_heatmap.mean(dim=1)[0] |
| | min_val, max_val = token_heatmap.min(), token_heatmap.max() |
| | norm_heatmap = (token_heatmap - min_val) / (max_val - min_val) |
| |
|
| | mask_img = torch.sigmoid(coefficient*(norm_heatmap - 0.5)) |
| | |
| | H = W = int(math.sqrt(mask_img.size(0))) |
| | mask_img = mask_img.reshape(H, W) |
| |
|
| | save_path = f'heatmap/step_{cur_step}_layer_{cur_block}_token{token_index}.png' |
| | load_path = [f'heatmap/step_{cur_step-1}_layer_{i}_token{token_index}.png' for i in layer] |
| | save_image(mask_img.unsqueeze(0), save_path) |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | mask_img[mask_img >= thre] = 1 |
| | mask_img[mask_img < thre] = 0 |
| | |
| | |
| | |
| |
|
| | mask_tensor = torch.zeros_like(mask_img) |
| | if cur_step > 3: |
| | mask_accumulator = torch.zeros_like(mask_tensor.unsqueeze(0), dtype=mask_img.dtype) |
| | mask_tensor = auto_mask(load_path, mask_accumulator, thre, info, mask_num=4) |
| | if cur_block == 1: |
| | save_image(mask_tensor, f'heatmap/average_heatmaps/step_{cur_step}_layer_{cur_block}_token{token_index}.png') |
| |
|
| |
|
| | if not torch.all(mask_tensor == 0): |
| | highlight_factor = 2.0 |
| | reduce_factor = 0.8 |
| |
|
| | mask_tensor = mask_tensor.reshape(1, H * W) |
| | mask_tensor = mask_tensor.unsqueeze(1).unsqueeze(-1) |
| | |
| | multiplier = torch.where(mask_tensor.bool(), torch.tensor(highlight_factor), torch.tensor(reduce_factor)) |
| | attn_weight[:, :, -img_shape:, :15] *= multiplier |
| |
|
| | return attn_weight @ value |
| |
|
| | ''' |
| | if cur_step == 14 and (cur_block == 2 or cur_block == 7 or cur_block == 12): |
| | mask_img = torch.zeros_like(mask_img) |
| | for j in range(5): |
| | token_heatmap = txt_img_cross[:, :, :, j] |
| | token_heatmap = token_heatmap.mean(dim=1)[0] |
| | min_val, max_val = token_heatmap.min(), token_heatmap.max() |
| | norm_heatmap = (token_heatmap - min_val) / (max_val - min_val) |
| | |
| | mask_img = torch.sigmoid(coefficient*(norm_heatmap - 0.5)) |
| | |
| | H = W = int(math.sqrt(mask_img.size(0))) |
| | mask_img = mask_img.reshape(H, W) |
| | save_path = f'/home/hfle/personalization/FireFlow-Fast-Inversion-of-Rectified-Flow-for-Image-Semantic-Editing/heatmap/step_{cur_step}_layer_{cur_block}_token{j}.png' |
| | save_image(mask_img.unsqueeze(0), save_path)''' |
| |
|