| | import torch |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from PIL import Image |
| |
|
# Module-level store filled by the forward hooks: hook name -> captured map.
attn_maps = {}


def hook_fn(name):
    """Build a forward hook that stashes a processor's ``attn_map`` under ``name``.

    Args:
        name: Key under which the captured map is stored in ``attn_maps``.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook signature.
        When ``module.processor`` carries an ``attn_map`` attribute, the hook
        copies it into ``attn_maps[name]`` and deletes the attribute from the
        processor; otherwise it does nothing. The hook always returns ``None``.
    """

    def forward_hook(module, inputs, output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return
        attn_maps[name] = processor.attn_map
        # Remove the attribute so the processor no longer holds a reference.
        del processor.attn_map

    return forward_hook
| |
|
def register_cross_attention_hook(unet):
    """Attach an attention-capturing forward hook to every ``attn2*`` submodule.

    A submodule is selected when the last dot-separated segment of its
    qualified name starts with ``"attn2"`` (presumably the cross-attention
    layers of a diffusers UNet -- confirm against the model in use).

    Args:
        unet: Object exposing ``named_modules()``; each selected submodule
            must expose ``register_forward_hook``.

    Returns:
        The same ``unet`` object, to allow call chaining.
    """
    for module_name, submodule in unet.named_modules():
        leaf_name = module_name.split('.')[-1]
        if not leaf_name.startswith('attn2'):
            continue
        # The hook is keyed by the full qualified name, not just the leaf.
        submodule.register_forward_hook(hook_fn(module_name))

    return unet
| |
|
def upscale(attn_map, target_size):
    """Resize a flattened attention map to ``target_size`` and softmax it.

    The leading dimension of ``attn_map`` is averaged away, the remaining two
    dimensions are swapped, and the flattened spatial axis is unfolded into a
    2-D grid. That grid is bilinearly interpolated to ``target_size`` and a
    softmax is applied over the first dimension of the result.

    The grid is found by trying scales 1, 2, 4, 8 and 16 and picking the first
    one where ``(H // scale) * (W // scale)`` equals 64 times the number of
    flattened positions; the grid is then ``(H // (8 * scale), W // (8 * scale))``.

    Args:
        attn_map: 3-D tensor whose second dimension is the flattened spatial
            axis (assumed layout ``(heads, height * width, tokens)`` --
            confirm against the attention processor that produces it).
        target_size: ``(height, width)`` of the output maps.

    Returns:
        A float32 tensor of shape ``(attn_map.shape[-1], height, width)`` whose
        values sum to 1 along dimension 0.

    Raises:
        AssertionError: If none of the tried scales matches the number of
            flattened positions.
    """
    # Collapse the leading dimension, then move the flattened spatial axis last.
    attn_map = torch.mean(attn_map, dim=0).permute(1, 0)
    num_positions = attn_map.shape[1]

    temp_size = None
    for exponent in range(5):
        scale = 2 ** exponent
        if (target_size[0] // scale) * (target_size[1] // scale) == num_positions * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    # Explicit raise instead of `assert`: the check must survive `python -O`.
    # AssertionError is kept so existing callers see the same exception type.
    if temp_size is None:
        raise AssertionError(
            f"cannot infer a 2-D layout for {num_positions} attention positions "
            f"with target_size {tuple(target_size)}"
        )

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    # interpolate expects a batch dimension; add it, resize, then drop it again.
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )[0]

    return torch.softmax(attn_map, dim=0)
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Average every captured attention map after upscaling it to ``image_size``.

    Each tensor stored in the module-level ``attn_maps`` dict is split into
    ``batch_size`` chunks along its first dimension; one chunk is selected,
    squeezed, passed through ``upscale`` and the results are averaged.

    Args:
        image_size: ``(height, width)`` forwarded to ``upscale``.
        batch_size: Number of chunks each stored tensor is split into.
        instance_or_negative: Selects chunk 0 when truthy, chunk 1 otherwise.
        detach: When truthy, each stored tensor is moved to the CPU with
            ``.cpu()`` before processing. NOTE(review): despite the name, no
            ``.detach()`` call is made here.

    Returns:
        The element-wise mean of all upscaled maps.
    """
    if instance_or_negative:
        chunk_index = 0
    else:
        chunk_index = 1

    upscaled_maps = []
    for stored_map in attn_maps.values():
        if detach:
            stored_map = stored_map.cpu()
        # squeeze() drops every size-1 dimension, not only the chunk dimension.
        selected = torch.chunk(stored_map, batch_size)[chunk_index].squeeze()
        upscaled_maps.append(upscale(selected, image_size))

    stacked = torch.stack(upscaled_maps, dim=0)
    return torch.mean(stacked, dim=0)
| |
|
def attnmaps2images(net_attn_maps):
    """Convert attention maps into 8-bit PIL images.

    Each map is min-max normalised to the 0-255 range independently of the
    others, truncated to ``uint8`` and wrapped with ``Image.fromarray``.

    Args:
        net_attn_maps: Iterable of tensors, one per output image (assumed to
            be 2-D ``(height, width)`` maps -- confirm against the caller).

    Returns:
        A list of PIL images in the same order as the input maps; an empty
        list when the input is empty.
    """
    images = []

    for attn_map in net_attn_maps:
        # detach() lets maps that still require grad be converted to NumPy.
        values = attn_map.detach().cpu().numpy()

        lowest = np.min(values)
        value_range = np.max(values) - lowest

        if value_range > 0:
            normalized = (values - lowest) / value_range * 255
        else:
            # A constant map has no contrast; emit black rather than dividing
            # by zero and casting NaN to uint8.
            normalized = np.zeros_like(values)

        images.append(Image.fromarray(normalized.astype(np.uint8)))

    return images
def is_torch2_available():
    """Return True when ``torch.nn.functional`` exposes ``scaled_dot_product_attention``."""
    sdpa_attribute = "scaled_dot_product_attention"
    return hasattr(F, sdpa_attribute)
| |
|
def get_generator(seed, device):
    """Create seeded ``torch.Generator`` object(s) on ``device``.

    Args:
        seed: ``None``, a single seed, or a ``list`` of seeds.
        device: Device passed to ``torch.Generator``.

    Returns:
        ``None`` when ``seed`` is ``None``; a list with one generator per
        entry when ``seed`` is a list; otherwise a single generator.
    """
    if seed is None:
        return None

    if not isinstance(seed, list):
        return torch.Generator(device).manual_seed(seed)

    generators = []
    for seed_item in seed:
        generators.append(torch.Generator(device).manual_seed(seed_item))
    return generators