Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +13 -4
modeling_esm_plusplus.py
CHANGED
|
@@ -471,7 +471,17 @@ def _try_get_kernels_flash():
|
|
| 471 |
return flash_kernel, flash_kernel_variant
|
| 472 |
|
| 473 |
|
| 474 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
|
| 477 |
def _kernels_flash_forward(
|
|
@@ -646,6 +656,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
|
|
| 646 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 647 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 648 |
)
|
|
|
|
|
|
|
| 649 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 650 |
if FLASH_KERNEL is not None:
|
| 651 |
resolved = AttentionBackend.KERNELS_FLASH
|
|
@@ -1098,9 +1110,6 @@ class MultiHeadAttention(nn.Module):
|
|
| 1098 |
flex_block_mask: "BlockMask | None" = None,
|
| 1099 |
) -> tuple[torch.Tensor, None]:
|
| 1100 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 1101 |
-
assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
|
| 1102 |
-
f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
|
| 1103 |
-
)
|
| 1104 |
fn = _get_flex_attention_fn()
|
| 1105 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale)
|
| 1106 |
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|
|
|
|
| 471 |
return flash_kernel, flash_kernel_variant
|
| 472 |
|
| 473 |
|
| 474 |
+
_FLASH_KERNELS_LOADED = False
|
| 475 |
+
FLASH_KERNEL = None
|
| 476 |
+
FLASH_KERNEL_VARIANT = None
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def _ensure_flash_kernels_loaded():
|
| 480 |
+
global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
|
| 481 |
+
if _FLASH_KERNELS_LOADED:
|
| 482 |
+
return
|
| 483 |
+
_FLASH_KERNELS_LOADED = True
|
| 484 |
+
FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
|
| 485 |
|
| 486 |
|
| 487 |
def _kernels_flash_forward(
|
|
|
|
| 656 |
assert requested_backend in VALID_ATTENTION_BACKENDS, (
|
| 657 |
f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
|
| 658 |
)
|
| 659 |
+
if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
|
| 660 |
+
_ensure_flash_kernels_loaded()
|
| 661 |
if requested_backend == AttentionBackend.AUTO.value:
|
| 662 |
if FLASH_KERNEL is not None:
|
| 663 |
resolved = AttentionBackend.KERNELS_FLASH
|
|
|
|
| 1110 |
flex_block_mask: "BlockMask | None" = None,
|
| 1111 |
) -> tuple[torch.Tensor, None]:
|
| 1112 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
|
|
|
|
|
|
|
|
|
| 1113 |
fn = _get_flex_attention_fn()
|
| 1114 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale)
|
| 1115 |
return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
|