kernels-bot committed on
Commit
eca38cd
·
verified ·
1 Parent(s): 6337869

Uploaded using `kernel-builder`.

Browse files
build/torch-cuda/_ops.py CHANGED
@@ -1,8 +1,38 @@
1
  import torch
2
- ops = torch.ops._flash_attn4_efe2479
3
 
4
- def add_op_namespace_prefix(op_name: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
  Prefix op by namespace.
7
  """
8
- return f"_flash_attn4_efe2479::{op_name}"
 
1
  import torch
 
2
 
3
+ def get_backend() -> str:
4
+ """Detect the backend by inspecting torch."""
5
+ import torch
6
+
7
+ if hasattr(torch, "neuron"):
8
+ # Needs to be checked before specific Torch builds, since the Neuron
9
+ # extension can be loaded into e.g. CUDA Torch builds.
10
+ return "neuron"
11
+ elif torch.version.cuda is not None:
12
+ return "cuda"
13
+ elif torch.version.hip is not None:
14
+ return "rocm"
15
+ elif torch.backends.mps.is_available():
16
+ return "metal"
17
+ elif hasattr(torch.version, "xpu") and torch.version.xpu is not None:
18
+ return "xpu"
19
+ else:
20
+ return "cpu"
21
+
22
+
23
+ def _find_ops_name() -> str:
24
+ kernel_name = "flash_attn4"
25
+ unique_id = "86f75d9"
26
+ backend = get_backend()
27
+ return f"_{kernel_name}_{backend}_{unique_id}"
28
+
29
+
30
+ _OPS_NAME = _find_ops_name()
31
+
32
+ ops = getattr(torch.ops, _OPS_NAME)
33
+
34
+ def add_op_namespace_prefix(op_name: str) -> str:
35
  """
36
  Prefix op by namespace.
37
  """
38
+ return f"{_OPS_NAME}::{op_name}"
build/torch-cuda/metadata.json CHANGED
@@ -1,4 +1,6 @@
1
  {
 
 
2
  "version": 0,
3
  "license": "BSD-3-Clause",
4
  "python-depends": [
 
1
  {
2
+ "name": "flash-attn4",
3
+ "id": "_flash_attn4_cuda_86f75d9",
4
  "version": 0,
5
  "license": "BSD-3-Clause",
6
  "python-depends": [