File size: 5,879 Bytes
d066167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
import torch.nn.functional as F

ATTN_PRECISION = torch.float16

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
    FLASH_ATTN_AVAILABLE = False

except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False
    try:
        import flash_attn
        FLASH_ATTN_AVAILABLE = True
    except ModuleNotFoundError:
        FLASH_ATTN_AVAILABLE = False

# Optional xformers backend; absence merely disables its code paths.
try:
    import xformers.ops
    # NOTE: misspelled name ("AVAILBLE") kept as-is — callers depend on it.
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Narrowed from a bare `except:`, which would also have swallowed
    # KeyboardInterrupt/SystemExit raised during import. `Exception` is kept
    # (rather than ImportError) because a broken xformers install can fail
    # with other error types at import time.
    XFORMERS_IS_AVAILBLE = False


def half(x):
    """Return *x* cast to ``ATTN_PRECISION`` unless it is already half precision.

    Tensors that are already float16 or bfloat16 pass through unchanged;
    any other dtype (typically float32) is converted.
    """
    return x if x.dtype in (torch.float16, torch.bfloat16) else x.to(ATTN_PRECISION)

def attn_processor(q, k, v, attn_mask = None, *args, **kwargs):
    """Run attention on the fastest backend available at import time.

    Masked attention uses xformers if present, otherwise PyTorch SDPA.
    Unmasked attention prefers FlashAttention-3, then FlashAttention-2,
    then xformers, then SDPA.

    Inputs are presumably laid out (batch, seq, heads, head_dim) as the
    flash-attn/xformers APIs expect — TODO confirm against callers; the
    SDPA paths transpose to (batch, heads, seq, head_dim) and back.
    Flash paths cast q/k/v to half precision and cast the output back to
    v's original dtype.
    """
    if attn_mask is not None:
        if XFORMERS_IS_AVAILBLE:
            return xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=attn_mask, *args, **kwargs
            )
        qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
        return F.scaled_dot_product_attention(
            qt, kt, vt, attn_mask=attn_mask, *args, **kwargs
        ).transpose(1, 2)

    if FLASH_ATTN_3_AVAILABLE:
        out_dtype = v.dtype
        q, k, v = (half(t) for t in (q, k, v))
        # FA3 returns a tuple; the attention output is its first element.
        return flash_attn_interface.flash_attn_func(q, k, v, *args, **kwargs)[0].to(out_dtype)
    if FLASH_ATTN_AVAILABLE:
        out_dtype = v.dtype
        q, k, v = (half(t) for t in (q, k, v))
        return flash_attn.flash_attn_func(q, k, v, *args, **kwargs).to(out_dtype)
    if XFORMERS_IS_AVAILBLE:
        return xformers.ops.memory_efficient_attention(q, k, v, *args, **kwargs)
    qt, kt, vt = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(qt, kt, vt, *args, **kwargs).transpose(1, 2)


def flash_attn_varlen_func(q, k, v, **kwargs):
    """Variable-length flash attention, preferring the FA3 interface.

    FA3's varlen function returns a tuple whose first element is the
    attention output; FA2's returns the output tensor directly.
    """
    if not FLASH_ATTN_3_AVAILABLE:
        return flash_attn.flash_attn_varlen_func(q, k, v, **kwargs)
    return flash_attn_interface.flash_attn_varlen_func(q, k, v, **kwargs)[0]


def split_tensor_by_mask(tensor: torch.Tensor, mask: torch.Tensor, threshold: float = 0.5):
    """Reorder each sequence so masked-in ("foreground") positions come first.

    The output is zero-padded per region: foreground occupies columns
    ``[0, max_fg_len)`` and background ``[max_fg_len, max_fg_len + max_bg_len)``,
    where the maxima are taken over the batch.

    Args:
        tensor: Input of shape (batch, seq_len, *dims).
        mask: Mask of shape (batch, seq_len) or (batch, seq_len, 1); values
            strictly greater than ``threshold`` count as foreground.
        threshold: Binarization cutoff for ``mask``.

    Returns:
        A 4-tuple of (reordered tensor, per-batch foreground index tensors,
        per-batch background index tensors, original shape) — the last three
        are what ``restore_tensor_from_split`` needs to undo the reordering.
    """
    batch_size, seq_len, *feat_dims = tensor.shape

    # Normalize mask to (batch, seq_len) bool.
    if mask.dim() == 2:
        mask = mask.unsqueeze(-1)
    keep = (mask > threshold).squeeze(-1)

    # Per-batch index tensors, needed later for restoration.
    fg_indices = [keep[b].nonzero(as_tuple=True)[0] for b in range(batch_size)]
    bg_indices = [(~keep[b]).nonzero(as_tuple=True)[0] for b in range(batch_size)]

    # Widths of the two output regions (batch-wide maxima).
    max_fg_len = int(keep.sum(dim=1).max().item())
    max_bg_len = int((~keep).sum(dim=1).max().item())

    # Degenerate case: zero-length sequences.
    if max_fg_len == 0 and max_bg_len == 0:
        empty = tensor.new_zeros(batch_size, 0, *feat_dims)
        return empty, fg_indices, bg_indices, tensor.shape

    out = tensor.new_zeros(batch_size, max_fg_len + max_bg_len, *feat_dims)
    for b, (fg, bg) in enumerate(zip(fg_indices, bg_indices)):
        if fg.numel():
            out[b, :fg.numel()] = tensor[b, fg]
        if bg.numel():
            out[b, max_fg_len:max_fg_len + bg.numel()] = tensor[b, bg]

    return out, fg_indices, bg_indices, tensor.shape


def restore_tensor_from_split(split_tensor: torch.Tensor, fg_indices: list, bg_indices: list, 
                            original_shape: torch.Size):
    """Undo ``split_tensor_by_mask``: scatter the regions back to their positions.

    Args:
        split_tensor: Reordered tensor (foreground region first, then background).
        fg_indices: Per-batch index tensors of foreground positions.
        bg_indices: Per-batch index tensors of background positions.
        original_shape: Shape of the tensor before splitting.

    Returns:
        A tensor of ``original_shape`` with every element back at its
        original sequence position (positions never written stay zero).
    """
    batch_size, seq_len = original_shape[:2]
    restored = split_tensor.new_zeros(batch_size, seq_len, *original_shape[2:])

    # Nothing to scatter when the split tensor has no sequence dimension.
    if split_tensor.shape[1] == 0:
        return restored

    # Boundary between the foreground and background regions — by
    # construction this is the batch-wide maximum foreground count.
    fg_width = max((len(ix) for ix in fg_indices), default=0)
    has_bg_region = split_tensor.shape[1] > fg_width

    for b in range(batch_size):
        fg = fg_indices[b]
        if fg_width > 0 and len(fg) > 0:
            restored[b, fg] = split_tensor[b, :len(fg)]
        bg = bg_indices[b]
        if has_bg_region and len(bg) > 0:
            restored[b, bg] = split_tensor[b, fg_width:fg_width + len(bg)]

    return restored