jboth committed on
Commit 8484f0c · verified · 1 Parent(s): 7046e4a

Upload flash_attn_stub/flash_attn/__init__.py with huggingface_hub
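For reference, an upload like this is typically done with huggingface_hub's upload_file; a minimal sketch follows (the repo id and repo type below are placeholders, not taken from this commit, and the token is assumed to come from a prior huggingface-cli login):

from huggingface_hub import HfApi

api = HfApi()  # picks up the stored token unless one is passed explicitly
api.upload_file(
    path_or_fileobj="flash_attn_stub/flash_attn/__init__.py",
    path_in_repo="flash_attn_stub/flash_attn/__init__.py",
    repo_id="user/space-name",   # placeholder repo id
    repo_type="space",           # placeholder; could equally be a model or dataset repo
    commit_message="Upload flash_attn_stub/flash_attn/__init__.py with huggingface_hub",
)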

Files changed (1)
  1. flash_attn_stub/flash_attn/__init__.py +191 -13
flash_attn_stub/flash_attn/__init__.py CHANGED
@@ -1,13 +1,191 @@
- """flash_attn stub – not available on ZeroGPU."""
- def flash_attn_varlen_qkvpacked_func(*a, **kw):
-     raise NotImplementedError("flash_attn stub")
- def flash_attn_varlen_kvpacked_func(*a, **kw):
-     raise NotImplementedError("flash_attn stub")
- def flash_attn_varlen_func(*a, **kw):
-     raise NotImplementedError("flash_attn stub")
- def flash_attn_qkvpacked_func(*a, **kw):
-     raise NotImplementedError("flash_attn stub")
- def flash_attn_kvpacked_func(*a, **kw):
-     raise NotImplementedError("flash_attn stub")
- def flash_attn_func(*a, **kw):
-     raise NotImplementedError("flash_attn stub")
+ """flash_attn stub – implements flash attention API using torch SDPA.
+
+ This replaces the real flash_attn package on systems where it cannot be compiled
+ (e.g. ZeroGPU with PyTorch 2.10+cu128 and no matching wheel).
+ All functions accept the same signatures as flash_attn 2.x and delegate to
+ torch.nn.functional.scaled_dot_product_attention.
+ """
+ import torch
+ import torch.nn.functional as F
+
+
+ def _sdpa(q, k, v, causal=False, softmax_scale=None):
+     """Apply SDPA. q/k/v are (B, H, L, D)."""
+     return F.scaled_dot_product_attention(
+         q, k, v,
+         is_causal=causal,
+         scale=softmax_scale,
+     )
+
+
+ # ---------- non-varlen ----------
+
+ def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
+                     window_size=(-1, -1), softcap=0.0, alibi_slopes=None,
+                     deterministic=False, return_attn_probs=False):
+     """q/k/v: (B, L, H, D) -> out: (B, L, H, D)"""
+     # Permute to (B, H, L, D) for SDPA
+     q2 = q.transpose(1, 2)
+     k2 = k.transpose(1, 2)
+     v2 = v.transpose(1, 2)
+     out = _sdpa(q2, k2, v2, causal=causal, softmax_scale=softmax_scale)
+     out = out.transpose(1, 2) # back to (B, L, H, D)
+     return out
+
+
+ def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
+                               window_size=(-1, -1), softcap=0.0, alibi_slopes=None,
+                               deterministic=False, return_attn_probs=False):
+     """qkv: (B, L, 3, H, D) -> out: (B, L, H, D)"""
+     q, k, v = qkv.unbind(dim=2)
+     return flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
+                            causal=causal)
+
+
+ def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False,
+                              window_size=(-1, -1), softcap=0.0, alibi_slopes=None,
+                              deterministic=False, return_attn_probs=False):
+     """q: (B, Lq, H, D), kv: (B, Lk, 2, H, D) -> out: (B, Lq, H, D)"""
+     k, v = kv.unbind(dim=2)
+     return flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
+                            causal=causal)
+
+
+ # ---------- varlen ----------
+
+ def _varlen_sdpa(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+                  causal=False, softmax_scale=None):
+     """
+     q: (total_q, H, D), k: (total_k, H, D), v: (total_k, H, D)
+     cu_seqlens_q/k: (batch+1,) int32
+     Returns: (total_q, H, D)
+     """
+     batch = cu_seqlens_q.shape[0] - 1
+     H = q.shape[1]
+     D = q.shape[2]
+
+     # Fast path: all seqlens are equal (common case)
+     cu_q = cu_seqlens_q.tolist()
+     cu_k = cu_seqlens_k.tolist()
+
+     all_equal = True
+     sq0 = cu_q[1] - cu_q[0]
+     sk0 = cu_k[1] - cu_k[0]
+     for i in range(1, batch):
+         if cu_q[i + 1] - cu_q[i] != sq0 or cu_k[i + 1] - cu_k[i] != sk0:
+             all_equal = False
+             break
+
+     if all_equal and sq0 == max_seqlen_q and sk0 == max_seqlen_k:
+         # Reshape directly – no padding needed
+         q2 = q.reshape(batch, sq0, H, D).transpose(1, 2) # (B, H, Lq, D)
+         k2 = k.reshape(batch, sk0, H, D).transpose(1, 2)
+         v2 = v.reshape(batch, sk0, H, D).transpose(1, 2)
+         out = _sdpa(q2, k2, v2, causal=causal, softmax_scale=softmax_scale)
+         return out.transpose(1, 2).reshape(-1, H, D)
+
+     # Slow path: unequal lengths – pad, compute, then gather
+     q_padded = q.new_zeros(batch, max_seqlen_q, H, D)
+     k_padded = k.new_zeros(batch, max_seqlen_k, H, D)
+     v_padded = v.new_zeros(batch, max_seqlen_k, H, D)
+
+     for i in range(batch):
+         sq = cu_q[i + 1] - cu_q[i]
+         sk = cu_k[i + 1] - cu_k[i]
+         q_padded[i, :sq] = q[cu_q[i]:cu_q[i + 1]]
+         k_padded[i, :sk] = k[cu_k[i]:cu_k[i + 1]]
+         v_padded[i, :sk] = v[cu_k[i]:cu_k[i + 1]]
+
+     # Create attention mask for padding
+     q_mask = torch.arange(max_seqlen_q, device=q.device).unsqueeze(0) # (1, Lq)
+     k_mask = torch.arange(max_seqlen_k, device=k.device).unsqueeze(0) # (1, Lk)
+     q_lens = torch.tensor([cu_q[i + 1] - cu_q[i] for i in range(batch)],
+                           device=q.device).unsqueeze(1) # (B, 1)
+     k_lens = torch.tensor([cu_k[i + 1] - cu_k[i] for i in range(batch)],
+                           device=k.device).unsqueeze(1) # (B, 1)
+     # (B, 1, 1, Lk) – True where valid
+     attn_mask = (k_mask < k_lens).unsqueeze(1).unsqueeze(2)
+     # Also mask out query positions that are padding (their output is ignored anyway)
+     # Use float mask: -inf for invalid positions
+     attn_bias = torch.zeros(batch, 1, max_seqlen_q, max_seqlen_k,
+                             device=q.device, dtype=q.dtype)
+     attn_bias.masked_fill_(~attn_mask, float('-inf'))
+
+     if causal:
+         causal_mask = torch.triu(
+             torch.ones(max_seqlen_q, max_seqlen_k, device=q.device, dtype=torch.bool),
+             diagonal=1
+         )
+         attn_bias.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
+
+     q2 = q_padded.transpose(1, 2) # (B, H, Lq, D)
+     k2 = k_padded.transpose(1, 2)
+     v2 = v_padded.transpose(1, 2)
+
+     out = F.scaled_dot_product_attention(q2, k2, v2, attn_mask=attn_bias,
+                                          scale=softmax_scale)
+     out = out.transpose(1, 2) # (B, Lq, H, D)
+
+     # Gather results back to packed format
+     parts = []
+     for i in range(batch):
+         sq = cu_q[i + 1] - cu_q[i]
+         parts.append(out[i, :sq]) # (sq, H, D)
+     return torch.cat(parts, dim=0)
+
+
+ def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
+                            max_seqlen_q, max_seqlen_k,
+                            dropout_p=0.0, softmax_scale=None, causal=False,
+                            window_size=(-1, -1), softcap=0.0, alibi_slopes=None,
+                            deterministic=False, return_attn_probs=False,
+                            block_table=None):
+     """q/k/v: (total, H, D) -> out: (total_q, H, D)"""
+     return _varlen_sdpa(q, k, v, cu_seqlens_q, cu_seqlens_k,
+                         max_seqlen_q, max_seqlen_k,
+                         causal=causal, softmax_scale=softmax_scale)
+
+
+ def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen,
+                                      dropout_p=0.0, softmax_scale=None, causal=False,
+                                      window_size=(-1, -1), softcap=0.0,
+                                      alibi_slopes=None, deterministic=False,
+                                      return_attn_probs=False):
+     """qkv: (total, 3, H, D) -> out: (total, H, D)"""
+     q, k, v = qkv.unbind(dim=1)
+     return _varlen_sdpa(q, k, v, cu_seqlens, cu_seqlens,
+                         max_seqlen, max_seqlen,
+                         causal=causal, softmax_scale=softmax_scale)
+
+
+ def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k,
+                                     max_seqlen_q, max_seqlen_k,
+                                     dropout_p=0.0, softmax_scale=None, causal=False,
+                                     window_size=(-1, -1), softcap=0.0,
+                                     alibi_slopes=None, deterministic=False,
+                                     return_attn_probs=False):
+     """q: (total_q, H, D), kv: (total_k, 2, H, D) -> out: (total_q, H, D)"""
+     k, v = kv.unbind(dim=1)
+     return _varlen_sdpa(q, k, v, cu_seqlens_q, cu_seqlens_k,
+                         max_seqlen_q, max_seqlen_k,
+                         causal=causal, softmax_scale=softmax_scale)
+
+
+ # ---------- with_kvcache (used by some SAM2 code paths) ----------
+
+ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None,
+                             rotary_cos=None, rotary_sin=None,
+                             cache_seqlens=None, cache_batch_idx=None,
+                             block_table=None, softmax_scale=None, causal=False,
+                             window_size=(-1, -1), softcap=0.0,
+                             rotary_interleaved=True, alibi_slopes=None,
+                             num_splits=0, return_softmax_lse=False):
+     """Simplified kv-cache attention fallback."""
+     # Combine current k/v with cache if provided
+     if k is not None:
+         k_full = torch.cat([k_cache, k], dim=1)
+         v_full = torch.cat([v_cache, v], dim=1)
+     else:
+         k_full = k_cache
+         v_full = v_cache
+     return flash_attn_func(q, k_full, v_full, softmax_scale=softmax_scale, causal=causal)
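For context, the module docstring above promises flash_attn 2.x signatures backed by torch SDPA; a minimal smoke-test sketch of the two main entry points follows. It assumes the flash_attn_stub directory shadows the real flash_attn package on sys.path, and the tensor shapes and sequence lengths are arbitrary illustrative values:

import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func  # resolves to the stub

B, L, H, D = 2, 16, 4, 64
q = torch.randn(B, L, H, D)
k = torch.randn(B, L, H, D)
v = torch.randn(B, L, H, D)

# Dense path: flash_attn layout is (B, L, H, D); the stub transposes to SDPA's (B, H, L, D) internally.
out = flash_attn_func(q, k, v, causal=True)
assert out.shape == (B, L, H, D)

# Varlen path: sequences are packed along dim 0 and delimited by cumulative lengths (cu_seqlens).
seqlens = [5, 11]                                        # two sequences of different length
cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32)  # prefix sums of seqlens
qp = torch.randn(sum(seqlens), H, D)
kp = torch.randn(sum(seqlens), H, D)
vp = torch.randn(sum(seqlens), H, D)
out_packed = flash_attn_varlen_func(qp, kp, vp, cu_seqlens, cu_seqlens,
                                    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
                                    causal=True)
assert out_packed.shape == (sum(seqlens), H, D)

Because the fallback is plain SDPA, this sketch also runs in float32 on CPU, unlike the real flash_attn kernels, which makes it convenient for sanity-checking outside the ZeroGPU environment.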