Build uploaded using `kernels`.
build/torch-cuda/_ops.py
CHANGED

```diff
@@ -1,8 +1,8 @@
 import torch
-ops = torch.ops.
+ops = torch.ops._flash_attn4_474fc55
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_flash_attn4_474fc55::{op_name}"
```
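The change above repoints the op namespace at this build's suffix, `_flash_attn4_474fc55` (the previous suffix is truncated in the diff view). The suffix appears to be a per-build hash, so each uploaded build registers its ops under a unique `torch.ops` namespace and different builds can coexist in one process. A minimal sketch of how such a namespaced op is registered and resolved — `demo_add` is a made-up op, not part of this package:

```python
import torch

namespace = "_flash_attn4_474fc55"

def add_op_namespace_prefix(op_name: str) -> str:
    """Prefix op by namespace (mirrors the helper in _ops.py)."""
    return f"{namespace}::{op_name}"

# Register a toy op under the build's namespace (requires PyTorch >= 2.4).
@torch.library.custom_op(add_op_namespace_prefix("demo_add"), mutates_args=())
def demo_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

# Resolve the namespace the same way _ops.py does, then call the op.
ops = getattr(torch.ops, namespace)
print(ops.demo_add(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])
```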
build/torch-cuda/flash_bwd_sm100.py
CHANGED

```diff
@@ -1544,6 +1544,7 @@ class FlashAttentionBackwardSm100:
         )
         # Dealloc the tensor memory buffer
         tmem.relinquish_alloc_permit()
+        tmem_alloc_barrier.arrive_and_wait()
         tmem.free(tmem_ptr)
 
         # Compute
@@ -1595,6 +1596,7 @@ class FlashAttentionBackwardSm100:
             fastdiv_mods,
             blocksparse_tensors,
         )
+        tmem_alloc_barrier.arrive()
 
         # Reduce
         # (0, 1, 2, 3) - dQ
@@ -1615,6 +1617,7 @@ class FlashAttentionBackwardSm100:
             mdQ_semaphore,
             blocksparse_tensors,
         )
+        tmem_alloc_barrier.arrive()
 
         return
 
```
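In the backward kernel, the compute and reduce paths now `arrive()` on a `tmem_alloc_barrier` once they are done with tensor memory, and the deallocating path does `arrive_and_wait()` before `tmem.free(tmem_ptr)`. This looks like a fix for a free-while-in-use race: previously nothing stopped the freeing warp from releasing the buffer while other warps were still reading it. As a CPU-side analogy only (the kernel uses CuTe DSL barriers; note that the GPU `arrive()` is non-blocking, whereas Python's `Barrier.wait()` blocks every party):

```python
import threading

NUM_CONSUMERS = 2  # stands in for the compute and reduce paths
# Consumers plus the one thread that frees the buffer.
barrier = threading.Barrier(NUM_CONSUMERS + 1)

def consumer(name: str) -> None:
    print(f"{name}: done reading the buffer")
    barrier.wait()  # ~ tmem_alloc_barrier.arrive()

def deallocator() -> None:
    barrier.wait()  # ~ tmem_alloc_barrier.arrive_and_wait()
    print("deallocator: every consumer arrived; safe to free")  # ~ tmem.free(tmem_ptr)

threads = [threading.Thread(target=consumer, args=(f"consumer-{i}",))
           for i in range(NUM_CONSUMERS)]
threads.append(threading.Thread(target=deallocator))
for t in threads:
    t.start()
for t in threads:
    t.join()
```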
build/torch-cuda/flash_fwd_sm100.py
CHANGED

```diff
@@ -1090,6 +1090,7 @@ class FlashAttentionForwardSm100:
         )
         # Dealloc the tensor memory buffer
         tmem.relinquish_alloc_permit()
+        tmem_alloc_barrier.arrive_and_wait()
         tmem.free(tmem_ptr)
 
         # ///////////////////////////////////////////////////////////////////////////////
@@ -1157,6 +1158,8 @@ class FlashAttentionForwardSm100:
         if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]:
             softmax_loop(stage=1, tStS=tStS)
 
+            tmem_alloc_barrier.arrive()
+
         # ///////////////////////////////////////////////////////////////////////////////
         # Correction
         # ///////////////////////////////////////////////////////////////////////////////
@@ -1189,6 +1192,7 @@ class FlashAttentionForwardSm100:
             TileSchedulerCls,
             blocksparse_tensors,
         )
+        tmem_alloc_barrier.arrive()
 
         return
 
```
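The forward kernel gets the same treatment: the softmax stage-1 and correction paths each `arrive()` when they finish with tensor memory, and the epilogue's `arrive_and_wait()` before `tmem.free(tmem_ptr)` presumably closes the same free-while-in-use race as in the backward pass.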