lhallee committed on
Commit
8b892c5
·
verified ·
1 Parent(s): 097dfab

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +13 -4
modeling_esm_plusplus.py CHANGED
@@ -471,7 +471,17 @@ def _try_get_kernels_flash():
471
  return flash_kernel, flash_kernel_variant
472
 
473
 
474
- FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
 
 
 
 
 
 
 
 
 
 
475
 
476
 
477
  def _kernels_flash_forward(
@@ -646,6 +656,8 @@ def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
646
  assert requested_backend in VALID_ATTENTION_BACKENDS, (
647
  f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
648
  )
 
 
649
  if requested_backend == AttentionBackend.AUTO.value:
650
  if FLASH_KERNEL is not None:
651
  resolved = AttentionBackend.KERNELS_FLASH
@@ -1098,9 +1110,6 @@ class MultiHeadAttention(nn.Module):
1098
  flex_block_mask: "BlockMask | None" = None,
1099
  ) -> tuple[torch.Tensor, None]:
1100
  assert flex_attention is not None, "Flex attention is not available in this environment."
1101
- assert query_BHLD.dtype in (torch.float16, torch.bfloat16), (
1102
- f"Flex attention requires float16 or bfloat16, got {query_BHLD.dtype}."
1103
- )
1104
  fn = _get_flex_attention_fn()
1105
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale)
1106
  return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
 
471
  return flash_kernel, flash_kernel_variant
472
 
473
 
474
+ _FLASH_KERNELS_LOADED = False
475
+ FLASH_KERNEL = None
476
+ FLASH_KERNEL_VARIANT = None
477
+
478
+
479
+ def _ensure_flash_kernels_loaded():
480
+ global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
481
+ if _FLASH_KERNELS_LOADED:
482
+ return
483
+ _FLASH_KERNELS_LOADED = True
484
+ FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
485
 
486
 
487
  def _kernels_flash_forward(
 
656
  assert requested_backend in VALID_ATTENTION_BACKENDS, (
657
  f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
658
  )
659
+ if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
660
+ _ensure_flash_kernels_loaded()
661
  if requested_backend == AttentionBackend.AUTO.value:
662
  if FLASH_KERNEL is not None:
663
  resolved = AttentionBackend.KERNELS_FLASH
 
1110
  flex_block_mask: "BlockMask | None" = None,
1111
  ) -> tuple[torch.Tensor, None]:
1112
  assert flex_attention is not None, "Flex attention is not available in this environment."
 
 
 
1113
  fn = _get_flex_attention_fn()
1114
  context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale)
1115
  return rearrange(context_BHLD, "b h s d -> b s (h d)"), None