| |
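"""
Utility functions for neural-network modules: module cloning, weight/bias initialisation helpers, an
inverse-sigmoid helper and a pure-PyTorch multi-scale deformable attention.
"""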
|
|
| import copy |
| import math |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.init import uniform_ |
|
|
# Names exported by ``from <this module> import *``.
__all__ = (
    'multi_scale_deformable_attn_pytorch',
    'inverse_sigmoid',
)
|
|
|
|
| def _get_clones(module, n): |
| return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) |
|
|
|
|
def bias_init_with_prob(prior_prob=0.01):
    """
    Return the bias value whose sigmoid equals ``prior_prob``.

    Used to initialise a conv/fc bias according to a given probability value. The result is the logit
    ``log(p / (1 - p))``, evaluated as ``-log((1 - p) / p)``.

    Args:
        prior_prob (float): Target probability, meaningful in the open interval (0, 1). A value of 0
            raises ``ZeroDivisionError``; a value of 1 yields ``inf``.

    Returns:
        float: The bias value as a plain Python float.
    """
    odds_against = (1 - prior_prob) / prior_prob
    return float(-np.log(odds_against))
|
|
|
|
def linear_init_(module):
    """
    Re-initialise ``module.weight`` (and ``module.bias`` when present) in place from U(-b, b).

    The bound is ``b = 1 / sqrt(module.weight.shape[0])``. The weight is filled first, then the bias.
    The bias is skipped when the module has no ``bias`` attribute or it is ``None``.
    """
    # NOTE(review): for an ``nn.Linear`` the weight is (out_features, in_features), so shape[0] is the
    # fan-out; PyTorch's own default init bounds by fan-in. Kept as-is — confirm this is intended.
    limit = 1 / math.sqrt(module.weight.shape[0])
    uniform_(module.weight, -limit, limit)

    bias = getattr(module, 'bias', None)
    if bias is not None:
        uniform_(bias, -limit, limit)
|
|
|
|
def inverse_sigmoid(x, eps=1e-5):
    """
    Element-wise inverse of the sigmoid (logit), kept finite at the boundaries.

    ``x`` is first clamped into [0, 1]. The numerator ``x`` and denominator ``1 - x`` are each clamped
    to at least ``eps`` before taking ``log(x / (1 - x))``. As a result, 0 maps to ``log(eps)`` and
    1 maps to ``log(1 / eps)`` instead of to ±inf.

    Args:
        x (torch.Tensor): Input probabilities; values outside [0, 1] are clamped.
        eps (float): Lower clamp applied to the numerator and the denominator.

    Returns:
        torch.Tensor: Logits with the same shape as ``x``.
    """
    prob = torch.clamp(x, min=0, max=1)
    numerator = torch.clamp(prob, min=eps)
    denominator = torch.clamp(1 - prob, min=eps)
    ratio = numerator / denominator
    return torch.log(ratio)
|
|
|
|
def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
                                        sampling_locations: torch.Tensor,
                                        attention_weights: torch.Tensor) -> torch.Tensor:
    """
    Multi-scale deformable attention.
    https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py

    Pure-PyTorch implementation. For every query and attention head, ``num_points`` locations are
    sampled on each of the ``num_levels`` feature maps. Sampling is bilinear via ``F.grid_sample``,
    with zero padding and ``align_corners=False``. The samples are then combined with
    ``attention_weights``.

    Args:
        value (torch.Tensor): Flattened multi-level features of shape
            ``(bs, sum(H_l * W_l), num_heads, embed_dims)``, with the levels concatenated along dim 1.
        value_spatial_shapes (torch.Tensor): One ``(H_l, W_l)`` pair per level, in the same order as the
            levels in ``value``. It is only iterated as pairs, so a list of tuples also works.
        sampling_locations (torch.Tensor): Shape ``(bs, num_queries, num_heads, num_levels, num_points, 2)``.
            Coordinates are expected in ``[0, 1]``; they are mapped onto ``grid_sample``'s ``[-1, 1]``
            range, and out-of-range samples are zero-padded.
        attention_weights (torch.Tensor): Weights laid out as
            ``(bs, num_queries, num_heads, num_levels, num_points)``.

    Returns:
        torch.Tensor: Contiguous tensor of shape ``(bs, num_queries, num_heads * embed_dims)``.
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape

    # Undo the flattening: one (bs, H_l * W_l, num_heads, embed_dims) chunk per level.
    per_level_values = value.split([h * w for h, w in value_spatial_shapes], dim=1)
    # grid_sample expects normalised coordinates in [-1, 1]; map [0, 1] onto that range.
    grids = sampling_locations * 2 - 1

    def sample_level(level, h, w):
        # (bs, h*w, num_heads, embed_dims) -> (bs*num_heads, embed_dims, h, w)
        feature_map = per_level_values[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, h, w)
        # (bs, num_queries, num_heads, num_points, 2) -> (bs*num_heads, num_queries, num_points, 2)
        grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # -> (bs*num_heads, embed_dims, num_queries, num_points)
        return F.grid_sample(feature_map, grid, mode='bilinear', padding_mode='zeros', align_corners=False)

    samples = [sample_level(level, h, w) for level, (h, w) in enumerate(value_spatial_shapes)]

    # (bs, num_queries, num_heads, ...) -> (bs*num_heads, 1, num_queries, num_levels*num_points)
    weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries, num_levels * num_points)
    # (bs*num_heads, embed_dims, num_queries, num_levels*num_points)
    stacked = torch.stack(samples, dim=-2).flatten(-2)
    # Weighted sum over every (level, point) sample, then fold the heads into the channel dimension.
    merged = (stacked * weights).sum(-1).view(bs, num_heads * embed_dims, num_queries)
    return merged.transpose(1, 2).contiguous()
|
|