| |
|
| | import math |
| | import torch |
| | import triton |
| | import triton.language as tl |
| |
|
# Forward kernel: for every chunk of CHUNK_SIZE consecutive key/value rows it
# writes one softmax-weighted average of the keys (Out_RFA_K) and one of the
# values (Out_RFA_V).
#   * key weights   : softmax over the chunk of  k . param_mu
#   * value weights : softmax over the chunk of
#                     softmax_scale * (k . param_phi - 0.5 * |k|^2)
# Scores are multiplied by log2(e) so that exp2 can be used in place of exp.
#
# Launch grid: axis 0 indexes blocks of BLOCK_N rows, axis 1 indexes the fused
# (batch, head) pair.  EVEN_N / EVEN_HEADDIM are compile-time flags produced by
# the heuristics below; they select unmasked loads/stores when no bounds
# checking is needed.  All pointers are indexed with an implicit unit stride
# along the head dimension.
@triton.heuristics(
    {
        "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_prep_kv_kernel(
    K,
    V,
    PARAM_MU,
    PARAM_PHI,
    Mask,
    Out_RFA_K,
    Out_RFA_V,
    softmax_scale,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_mu_h,
    stride_phi_h,
    stride_mb, stride_mn,
    stride_ok_b, stride_ok_h, stride_ok_c,
    stride_ov_b, stride_ov_h, stride_ov_c,
    nheads,
    seqlen,
    nchunks,
    headdim,
    CHUNKS_PER_BLOCK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_n = tl.program_id(0)
    offs_bh = tl.program_id(1)
    offs_h = offs_bh % nheads
    offs_b = offs_bh // nheads

    offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
    offs_m = tl.arange(0, CHUNK_SIZE)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Absolute sequence position of every (chunk, row-in-chunk) slot handled
    # by this program, plus the bounds masks derived from it.
    offs_n = start_n * BLOCK_N + offs_c[:, None] * CHUNK_SIZE + offs_m[None, :]
    offs_n_3d = (
        start_n * BLOCK_N +
        offs_c[:, None, None] * CHUNK_SIZE +
        offs_m[None, :, None]
    )
    in_seq = offs_n < seqlen
    in_seq_3d = offs_n_3d < seqlen
    in_dim = offs_d[None, :] < headdim
    in_dim_3d = offs_d[None, None, :] < headdim

    k_ptrs = (
        K +
        offs_b * stride_kb +
        offs_h * stride_kh +
        (offs_n_3d * stride_kn + offs_d[None, None, :])
    )
    v_ptrs = (
        V +
        offs_b * stride_vb +
        offs_h * stride_vh +
        (offs_n_3d * stride_vn + offs_d[None, None, :])
    )
    param_mu_ptrs = (
        PARAM_MU +
        offs_h * stride_mu_h +
        offs_d[None, None, :]
    )
    param_phi_ptrs = (
        PARAM_PHI +
        offs_h * stride_phi_h +
        offs_d[None, None, :]
    )
    log2e = 1.4426950408889634

    if EVEN_N:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
        else:
            k = tl.load(k_ptrs, mask=in_dim_3d, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=in_seq_3d, other=0.0)
        else:
            k = tl.load(k_ptrs, mask=in_seq_3d & in_dim_3d, other=0.0)

    # ---------------- chunk-level key summary ----------------
    # The parameter vector holds only `headdim` entries per head, so the
    # padded lanes must be masked: an unmasked read would run past the row
    # and any NaN/Inf found there would poison the dot product (0 * NaN).
    if EVEN_HEADDIM:
        param_mu = tl.load(param_mu_ptrs).to(k.dtype)
    else:
        param_mu = tl.load(param_mu_ptrs, mask=in_dim_3d, other=0.0).to(k.dtype)
    rfa_k_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_k_c_w += tl.sum(k * param_mu, axis=-1)
    rfa_k_c_w *= log2e
    if MASK_TYPE == 1:
        m_ptrs = (
            Mask +
            offs_b * stride_mb +
            offs_n * stride_mn
        )
        if EVEN_N:
            mask = tl.load(m_ptrs)
        else:
            # Out-of-range rows are treated as masked-out (other=1).
            mask = tl.load(m_ptrs, mask=in_seq, other=1)
        # Non-zero mask entries are excluded from the softmax.
        rfa_k_c_w = tl.where(mask, float("-inf"), rfa_k_c_w)

    # Softmax over the rows of each chunk.  A chunk whose rows are all
    # masked has max == -inf; subtract 0 instead to avoid (-inf) - (-inf),
    # and replace the resulting zero denominator by 1 so the output is 0.
    m_rfa_k_c_w = tl.max(rfa_k_c_w, axis=-1)
    masked_out_rows_rfa_k = (m_rfa_k_c_w == float("-inf"))
    m_rfa_k_c_w_masked = tl.where(masked_out_rows_rfa_k, 0, m_rfa_k_c_w)
    rfa_k_c_w = tl.exp2(rfa_k_c_w - m_rfa_k_c_w_masked[:, None])
    denom_k = tl.sum(rfa_k_c_w, axis=-1)
    denom_k = tl.where(denom_k == 0.0, 1.0, denom_k)
    rfa_k_c_w = rfa_k_c_w / denom_k[:, None]
    rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)

    offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
    chunk_in_range = offs_out_c[:, None] < nchunks
    out_rfa_k_ptrs = (
        Out_RFA_K +
        offs_b * stride_ok_b +
        offs_h * stride_ok_h +
        (offs_out_c[:, None] * stride_ok_c + offs_d[None, :])
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(out_rfa_k_ptrs, rfa_k_c)
        else:
            tl.store(out_rfa_k_ptrs, rfa_k_c, mask=in_dim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_rfa_k_ptrs, rfa_k_c, mask=chunk_in_range)
        else:
            tl.store(out_rfa_k_ptrs, rfa_k_c, mask=chunk_in_range & in_dim)

    # ---------------- chunk-level value summary ----------------
    if EVEN_HEADDIM:
        param_phi = tl.load(param_phi_ptrs).to(k.dtype)
    else:
        param_phi = tl.load(param_phi_ptrs, mask=in_dim_3d, other=0.0).to(k.dtype)
    rfa_v_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_v_c_w += tl.sum(k * param_phi, axis=-1)
    rfa_v_c_w -= (0.5 * tl.sum(k * k, axis=-1))
    rfa_v_c_w *= log2e * softmax_scale
    if not EVEN_N:
        # Rows past the end of the sequence get zero weight.
        rfa_v_c_w += tl.where(in_seq, 0, float("-inf"))

    if MASK_TYPE == 1:
        rfa_v_c_w = tl.where(mask, float("-inf"), rfa_v_c_w)

    if EVEN_N:
        if EVEN_HEADDIM:
            v = tl.load(v_ptrs)
        else:
            v = tl.load(v_ptrs, mask=in_dim_3d, other=0.0)
    else:
        if EVEN_HEADDIM:
            v = tl.load(v_ptrs, mask=in_seq_3d, other=0.0)
        else:
            v = tl.load(v_ptrs, mask=in_seq_3d & in_dim_3d, other=0.0)

    m_rfa_v_c_w = tl.max(rfa_v_c_w, axis=-1)
    masked_out_rows_rfa_v = (m_rfa_v_c_w == float("-inf"))
    m_rfa_v_c_w_masked = tl.where(masked_out_rows_rfa_v, 0, m_rfa_v_c_w)
    rfa_v_c_w = tl.exp2(rfa_v_c_w - m_rfa_v_c_w_masked[:, None])
    denom_v = tl.sum(rfa_v_c_w, axis=-1)
    denom_v = tl.where(denom_v == 0.0, 1.0, denom_v)
    rfa_v_c_w = rfa_v_c_w / denom_v[:, None]
    rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)

    out_rfa_v_ptrs = (
        Out_RFA_V +
        offs_b * stride_ov_b +
        offs_h * stride_ov_h +
        (offs_out_c[:, None] * stride_ov_c + offs_d[None, :])
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(out_rfa_v_ptrs, rfa_v_c)
        else:
            tl.store(out_rfa_v_ptrs, rfa_v_c, mask=in_dim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_rfa_v_ptrs, rfa_v_c, mask=chunk_in_range)
        else:
            tl.store(out_rfa_v_ptrs, rfa_v_c, mask=chunk_in_range & in_dim)
| |
|
| |
|
| |
|
# Backward kernel matching _fwd_eva_prep_kv_kernel.  It recomputes the
# per-chunk softmax weights from K / PARAM_MU / PARAM_PHI, then writes
#   * D_K, D_V                : gradients for every key / value row,
#   * D_PARAM_MU_PARTIAL,
#     D_PARAM_PHI_PARTIAL     : one partial parameter gradient per program
#                               (slot `start_n` along the group stride); the
#                               caller is expected to reduce them.
# RFA_K / RFA_V are the forward outputs and D_RFA_K / D_RFA_V their incoming
# gradients.
#
# Launch grid: axis 0 indexes blocks of BLOCK_N rows, axis 1 indexes the fused
# (batch, head) pair.  All pointers are indexed with an implicit unit stride
# along the head dimension.
@triton.heuristics(
    {
        "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_prep_kv_kernel(
    RFA_K,
    RFA_V,
    K,
    V,
    PARAM_MU,
    PARAM_PHI,
    Mask,
    D_RFA_K,
    D_RFA_V,
    D_K,
    D_V,
    D_PARAM_MU_PARTIAL,
    D_PARAM_PHI_PARTIAL,
    softmax_scale,
    stride_rfa_k_b, stride_rfa_k_h, stride_rfa_k_c,
    stride_rfa_v_b, stride_rfa_v_h, stride_rfa_v_c,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_mu_h,
    stride_phi_h,
    stride_mb, stride_mn,
    stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
    stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
    stride_d_k_b, stride_d_k_h, stride_d_k_n,
    stride_d_v_b, stride_d_v_h, stride_d_v_n,
    stride_d_mu_b, stride_d_mu_h, stride_d_mu_g,
    stride_d_phi_b, stride_d_phi_h, stride_d_phi_g,
    nheads,
    seqlen,
    nchunks,
    headdim,
    CHUNKS_PER_BLOCK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_n = tl.program_id(0)
    offs_bh = tl.program_id(1)
    offs_h = offs_bh % nheads
    offs_b = offs_bh // nheads

    offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
    offs_m = tl.arange(0, CHUNK_SIZE)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Global chunk index of each chunk handled by this program.
    offs_rfa_c = start_n * CHUNKS_PER_BLOCK + offs_c

    # Absolute sequence position of every (chunk, row-in-chunk) slot, plus
    # the bounds masks derived from it.
    offs_n = start_n * BLOCK_N + offs_c[:, None] * CHUNK_SIZE + offs_m[None, :]
    offs_n_3d = (
        start_n * BLOCK_N +
        offs_c[:, None, None] * CHUNK_SIZE +
        offs_m[None, :, None]
    )
    in_seq = offs_n < seqlen
    in_seq_3d = offs_n_3d < seqlen
    in_dim_1d = offs_d < headdim
    in_dim = offs_d[None, :] < headdim
    in_dim_3d = offs_d[None, None, :] < headdim
    chunk_in_range = offs_rfa_c[:, None] < nchunks

    k_ptrs = (
        K +
        offs_b * stride_kb +
        offs_h * stride_kh +
        (offs_n_3d * stride_kn + offs_d[None, None, :])
    )
    rfa_k_ptrs = (
        RFA_K +
        offs_b * stride_rfa_k_b +
        offs_h * stride_rfa_k_h +
        (offs_rfa_c[:, None] * stride_rfa_k_c + offs_d[None, :])
    )
    rfa_v_ptrs = (
        RFA_V +
        offs_b * stride_rfa_v_b +
        offs_h * stride_rfa_v_h +
        (offs_rfa_c[:, None] * stride_rfa_v_c + offs_d[None, :])
    )
    d_rfa_k_ptrs = (
        D_RFA_K +
        offs_b * stride_d_rfa_k_b +
        offs_h * stride_d_rfa_k_h +
        (offs_rfa_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
    )
    d_rfa_v_ptrs = (
        D_RFA_V +
        offs_b * stride_d_rfa_v_b +
        offs_h * stride_d_rfa_v_h +
        (offs_rfa_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
    )
    param_mu_ptrs = (
        PARAM_MU +
        offs_h * stride_mu_h +
        offs_d[None, None, :]
    )
    param_phi_ptrs = (
        PARAM_PHI +
        offs_h * stride_phi_h +
        offs_d[None, None, :]
    )

    log2e = 1.4426950408889634

    if EVEN_N:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
        else:
            k = tl.load(k_ptrs, mask=in_dim_3d, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=in_seq_3d, other=0.0)
        else:
            k = tl.load(k_ptrs, mask=in_seq_3d & in_dim_3d, other=0.0)

    if EVEN_N:
        if EVEN_HEADDIM:
            rfa_k = tl.load(rfa_k_ptrs)
        else:
            rfa_k = tl.load(rfa_k_ptrs, mask=in_dim, other=0.0)
    else:
        if EVEN_HEADDIM:
            rfa_k = tl.load(rfa_k_ptrs, mask=chunk_in_range, other=0.0)
        else:
            rfa_k = tl.load(rfa_k_ptrs, mask=chunk_in_range & in_dim, other=0.0)

    if EVEN_N:
        if EVEN_HEADDIM:
            d_rfa_k = tl.load(d_rfa_k_ptrs)
        else:
            d_rfa_k = tl.load(d_rfa_k_ptrs, mask=in_dim, other=0.0)
    else:
        if EVEN_HEADDIM:
            d_rfa_k = tl.load(d_rfa_k_ptrs, mask=chunk_in_range, other=0.0)
        else:
            d_rfa_k = tl.load(d_rfa_k_ptrs, mask=chunk_in_range & in_dim, other=0.0)

    # ---------------- key path: recompute softmax(k . mu) ----------------
    # The parameter vector holds only `headdim` entries per head, so the
    # padded lanes must be masked: an unmasked read would run past the row
    # and any NaN/Inf found there would poison the dot product (0 * NaN).
    if EVEN_HEADDIM:
        param_mu = tl.load(param_mu_ptrs).to(k.dtype)
    else:
        param_mu = tl.load(param_mu_ptrs, mask=in_dim_3d, other=0.0).to(k.dtype)
    mu_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    mu_c_w += tl.sum(k * param_mu, axis=-1)
    mu_c_w *= log2e

    if not EVEN_N:
        # Rows past the end of the sequence get zero weight.
        mu_c_w += tl.where(in_seq, 0, float("-inf"))

    if MASK_TYPE == 1:
        m_ptrs = (
            Mask +
            offs_b * stride_mb +
            offs_n * stride_mn
        )
        if EVEN_N:
            mask = tl.load(m_ptrs)
        else:
            # Out-of-range rows are treated as masked-out (other=1).
            mask = tl.load(m_ptrs, mask=in_seq, other=1)
        # Non-zero mask entries are excluded from the softmax.
        mu_c_w = tl.where(mask, float("-inf"), mu_c_w)

    # Fully-masked chunks have max == -inf; subtract 0 instead and replace
    # the zero denominator by 1 so their weights come out as exactly 0.
    m_mu_c_w = tl.max(mu_c_w, axis=-1)
    masked_out_rows_mu = (m_mu_c_w == float("-inf"))
    m_mu_c_w_masked = tl.where(masked_out_rows_mu, 0, m_mu_c_w)
    mu_c_w = tl.exp2(mu_c_w - m_mu_c_w_masked[:, None])
    denom_mu = tl.sum(mu_c_w, axis=-1)
    denom_mu = tl.where(denom_mu == 0.0, 1.0, denom_mu)
    mu_tilde_c_w = mu_c_w / denom_mu[:, None]
    mu_tilde_c_w = mu_tilde_c_w.to(k.dtype)

    # Gradient w.r.t. the normalised weights, then through the softmax:
    # d_score = (d_weight - <d_out, out>) * weight.
    d_mu_tilde_c_w = tl.sum(d_rfa_k[:, None, :] * k, axis=-1)
    d_out_rfa_k_t_rfa_k = tl.sum(d_rfa_k * rfa_k, axis=-1)[:, None]
    d_mu_c_w = (d_mu_tilde_c_w - d_out_rfa_k_t_rfa_k) * mu_tilde_c_w

    # Reduce over chunks and rows to get this program's partial d(param_mu).
    d_param_mu = tl.sum(tl.sum(d_mu_c_w[:, :, None] * k, axis=0), axis=0)

    # d_k from the key path: through the weighted sum and through the score.
    d_k = mu_tilde_c_w[:, :, None] * d_rfa_k[:, None, :] + d_mu_c_w[:, :, None] * param_mu

    d_param_mu_partial_ptrs = (
        D_PARAM_MU_PARTIAL +
        offs_b * stride_d_mu_b +
        offs_h * stride_d_mu_h +
        start_n * stride_d_mu_g +
        offs_d
    )
    if EVEN_HEADDIM:
        tl.store(d_param_mu_partial_ptrs, d_param_mu)
    else:
        tl.store(d_param_mu_partial_ptrs, d_param_mu, mask=in_dim_1d)

    # ---------------- value path ----------------
    v_ptrs = (
        V +
        offs_b * stride_vb +
        offs_h * stride_vh +
        (offs_n_3d * stride_vn + offs_d[None, None, :])
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            v = tl.load(v_ptrs)
        else:
            v = tl.load(v_ptrs, mask=in_dim_3d, other=0.0)
    else:
        if EVEN_HEADDIM:
            v = tl.load(v_ptrs, mask=in_seq_3d, other=0.0)
        else:
            v = tl.load(v_ptrs, mask=in_seq_3d & in_dim_3d, other=0.0)

    if EVEN_N:
        if EVEN_HEADDIM:
            rfa_v = tl.load(rfa_v_ptrs)
        else:
            rfa_v = tl.load(rfa_v_ptrs, mask=in_dim, other=0.0)
    else:
        if EVEN_HEADDIM:
            rfa_v = tl.load(rfa_v_ptrs, mask=chunk_in_range, other=0.0)
        else:
            rfa_v = tl.load(rfa_v_ptrs, mask=chunk_in_range & in_dim, other=0.0)

    if EVEN_N:
        if EVEN_HEADDIM:
            d_rfa_v = tl.load(d_rfa_v_ptrs)
        else:
            d_rfa_v = tl.load(d_rfa_v_ptrs, mask=in_dim, other=0.0)
    else:
        if EVEN_HEADDIM:
            d_rfa_v = tl.load(d_rfa_v_ptrs, mask=chunk_in_range, other=0.0)
        else:
            d_rfa_v = tl.load(d_rfa_v_ptrs, mask=chunk_in_range & in_dim, other=0.0)

    # Masked for the same reason as the param_mu load above.
    if EVEN_HEADDIM:
        param_phi = tl.load(param_phi_ptrs).to(k.dtype)
    else:
        param_phi = tl.load(param_phi_ptrs, mask=in_dim_3d, other=0.0).to(k.dtype)
    phi_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    phi_c_w += tl.sum(k * param_phi, axis=-1)
    phi_c_w -= (0.5 * tl.sum(k * k, axis=-1))
    phi_c_w *= log2e * softmax_scale
    if not EVEN_N:
        phi_c_w += tl.where(in_seq, 0, float("-inf"))

    if MASK_TYPE == 1:
        phi_c_w = tl.where(mask, float("-inf"), phi_c_w)

    m_phi_c_w = tl.max(phi_c_w, axis=-1)
    masked_out_rows_phi = (m_phi_c_w == float("-inf"))
    m_phi_c_w_masked = tl.where(masked_out_rows_phi, 0, m_phi_c_w)
    phi_c_w = tl.exp2(phi_c_w - m_phi_c_w_masked[:, None])
    denom_phi = tl.sum(phi_c_w, axis=-1)
    denom_phi = tl.where(denom_phi == 0.0, 1.0, denom_phi)
    phi_tilde_c_w = phi_c_w / denom_phi[:, None]

    phi_tilde_c_w = phi_tilde_c_w.to(k.dtype)
    d_phi_tilde_c_w = tl.sum(d_rfa_v[:, None, :] * v, axis=-1)
    d_out_rfa_v_t_rfa_v = tl.sum(d_rfa_v * rfa_v, axis=-1)[:, None]
    d_phi_c_w = (d_phi_tilde_c_w.to(tl.float32) - d_out_rfa_v_t_rfa_v.to(tl.float32)) * phi_tilde_c_w

    d_param_phi = tl.sum(tl.sum(d_phi_c_w[:, :, None] * k * softmax_scale, axis=0), axis=0)
    d_v = phi_tilde_c_w[:, :, None] * d_rfa_v[:, None, :]

    # The value-path score is softmax_scale * (k . phi - 0.5 * |k|^2), whose
    # derivative w.r.t. k is softmax_scale * (phi - k).
    d_k = d_k + softmax_scale * d_phi_c_w[:, :, None] * (param_phi - k)

    d_k_ptrs = (
        D_K +
        offs_b * stride_d_k_b +
        offs_h * stride_d_k_h +
        (offs_n_3d * stride_d_k_n + offs_d[None, None, :])
    )
    d_v_ptrs = (
        D_V +
        offs_b * stride_d_v_b +
        offs_h * stride_d_v_h +
        (offs_n_3d * stride_d_v_n + offs_d[None, None, :])
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(d_k_ptrs, d_k)
            tl.store(d_v_ptrs, d_v)
        else:
            tl.store(d_k_ptrs, d_k, mask=in_dim_3d)
            tl.store(d_v_ptrs, d_v, mask=in_dim_3d)
    else:
        if EVEN_HEADDIM:
            tl.store(d_k_ptrs, d_k, mask=in_seq_3d)
            tl.store(d_v_ptrs, d_v, mask=in_seq_3d)
        else:
            tl.store(d_k_ptrs, d_k, mask=in_seq_3d & in_dim_3d)
            tl.store(d_v_ptrs, d_v, mask=in_seq_3d & in_dim_3d)

    d_param_phi_partial_ptrs = (
        D_PARAM_PHI_PARTIAL +
        offs_b * stride_d_phi_b +
        offs_h * stride_d_phi_h +
        start_n * stride_d_phi_g +
        offs_d
    )
    if EVEN_HEADDIM:
        tl.store(d_param_phi_partial_ptrs, d_param_phi)
    else:
        tl.store(d_param_phi_partial_ptrs, d_param_phi, mask=in_dim_1d)
| |
|
def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, mask, softmax_scale, chunksize):
    """Launch the forward kernel that summarises keys/values chunk by chunk.

    Args:
        k, v: CUDA tensors of shape ``(batch, nheads, seqlen, head_dim)``,
            bf16 or fp32, both of the same dtype.
        param_mu, param_phi: tensors of shape ``(1, nheads, 1, 1, head_dim)``
            with the same dtype as ``k``.
        mask: ``None`` or a boolean CUDA tensor of shape
            ``(batch, 1, seqlen, 1)``; ``True`` entries are excluded from the
            per-chunk softmax.
        softmax_scale: scale applied to the value-path scores; any falsy value
            (``None`` or ``0``) selects ``1 / sqrt(head_dim)``.
        chunksize: number of consecutive sequence positions per chunk. Must
            divide both ``seqlen`` and the internal block size (128).

    Returns:
        ``(out_rfa_k, out_rfa_v)``, each of shape
        ``(batch, nheads, seqlen // chunksize, head_dim)``.

    Raises:
        AssertionError: if any shape, dtype, device or chunk-size requirement
            above is violated.
    """
    # The kernel indexes the head dimension with an implicit stride of 1.
    k, v, param_mu, param_phi = [
        x if x.stride(-1) == 1 else x.contiguous()
        for x in [k, v, param_mu, param_phi]
    ]

    batch, nheads, seqlen, head_dim = k.shape
    assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
    nchunks = seqlen // chunksize
    assert v.shape == (batch, nheads, seqlen, head_dim)
    assert param_mu.shape == (1, nheads, 1, 1, head_dim)
    assert param_phi.shape == (1, nheads, 1, 1, head_dim)
    assert head_dim <= 128, "We only test head dimensions up to 128"
    assert k.dtype == v.dtype == param_mu.dtype == param_phi.dtype, "All tensors must have the same type"
    assert k.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert k.is_cuda and v.is_cuda
    # NOTE: `or` also replaces an explicit 0.0; kept because the backward
    # launcher resolves the default the same way and both must agree.
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if mask is not None:
        mask_type = 1
        assert mask.dtype == torch.bool
        assert mask.is_cuda
        assert mask.dim() == 4
        assert mask.shape == (batch, 1, seqlen, 1)
        if mask.stride(-1) != 1:
            mask = mask.contiguous()
    mask_strides = (
        (mask.stride(0), mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )
    out_rfa_k = torch.empty((batch, nheads, nchunks, head_dim), dtype=k.dtype, device=k.device)
    out_rfa_v = torch.empty((batch, nheads, nchunks, head_dim), dtype=v.dtype, device=v.device)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8

    # The previous form `(BLOCK > chunksize) & (BLOCK % chunksize) == 0`
    # parsed as `((BLOCK > chunksize) & (BLOCK % chunksize)) == 0` and let
    # through chunk sizes that do not divide BLOCK (e.g. 48) or exceed it
    # (e.g. 256, giving zero chunks per block).  chunksize == BLOCK was
    # accepted before and stays accepted.
    assert BLOCK >= chunksize and BLOCK % chunksize == 0, "BLOCK must be divisible by chunksize"
    chunks_per_block = BLOCK // chunksize

    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_N"]), batch * nheads)
    _fwd_eva_prep_kv_kernel[grid](
        k,
        v,
        param_mu,
        param_phi,
        mask,
        out_rfa_k,
        out_rfa_v,
        softmax_scale,
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        param_mu.stride(1),
        param_phi.stride(1),
        mask_strides[0], mask_strides[1],
        out_rfa_k.stride(0), out_rfa_k.stride(1), out_rfa_k.stride(2),
        out_rfa_v.stride(0), out_rfa_v.stride(1), out_rfa_v.stride(2),
        nheads,
        seqlen,
        nchunks,
        head_dim,
        chunks_per_block,
        chunksize,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return out_rfa_k, out_rfa_v
| |
|
def triton_eva_prep_kv_bwd(
    d_rfa_k, d_rfa_v,
    k, v, param_mu, param_phi,
    mask,
    rfa_k, rfa_v,
    d_k, d_v, d_param_mu, d_param_phi,
    softmax_scale,
    mask_type,
    chunksize
):
    """Launch the backward kernel and fill the gradient buffers in place.

    Args:
        d_rfa_k, d_rfa_v: gradients of the forward outputs.
        k, v, param_mu, param_phi: the forward inputs; ``k`` has shape
            ``(batch, nheads, seqlen, head_dim)``.
        mask: the forward mask; only read when ``mask_type == 1``.
        rfa_k, rfa_v: the forward outputs.
        d_k, d_v: pre-allocated buffers receiving the gradients of ``k``/``v``.
        d_param_mu, d_param_phi: pre-allocated buffers receiving the parameter
            gradients, reduced over batch and over sequence blocks.
        softmax_scale: scale applied to the value-path scores; any falsy value
            (``None`` or ``0``) selects ``1 / sqrt(head_dim)``.
        mask_type: ``1`` when ``mask`` is to be applied, ``0`` otherwise.
        chunksize: number of consecutive sequence positions per chunk. Must
            divide both ``seqlen`` and the internal block size (128).

    Returns:
        None. Results are written to ``d_k``, ``d_v``, ``d_param_mu`` and
        ``d_param_phi``.

    Raises:
        AssertionError: if ``chunksize`` does not divide ``seqlen`` or the
            internal block size.
    """
    # The kernel indexes the head dimension with an implicit stride of 1, so
    # every tensor it reads must satisfy that -- not only the incoming
    # gradients.  The forward launcher normalises its own local copies only.
    d_rfa_k, d_rfa_v, k, v, param_mu, param_phi, rfa_k, rfa_v = [
        x if x.stride(-1) == 1 else x.contiguous()
        for x in [d_rfa_k, d_rfa_v, k, v, param_mu, param_phi, rfa_k, rfa_v]
    ]
    # Same requirement for the tensors the kernel writes: stage through a
    # contiguous buffer when the caller's buffer has a non-unit last stride.
    d_k_buf = (
        d_k if d_k.stride(-1) == 1
        else torch.empty_like(d_k, memory_format=torch.contiguous_format)
    )
    d_v_buf = (
        d_v if d_v.stride(-1) == 1
        else torch.empty_like(d_v, memory_format=torch.contiguous_format)
    )

    batch, nheads, seqlen, head_dim = k.shape
    assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
    nchunks = seqlen // chunksize
    # NOTE: `or` also replaces an explicit 0.0; kept because the forward
    # launcher resolves the default the same way and both must agree.
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_strides = (
        (mask.stride(0), mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8

    # The previous form `(BLOCK > chunksize) & (BLOCK % chunksize) == 0`
    # parsed as `((BLOCK > chunksize) & (BLOCK % chunksize)) == 0` and let
    # through chunk sizes that do not divide BLOCK or exceed it.
    # chunksize == BLOCK was accepted before and stays accepted.
    assert BLOCK >= chunksize and BLOCK % chunksize == 0, "BLOCK must be divisible by chunksize"
    chunks_per_block = BLOCK // chunksize

    # One fp32 partial parameter gradient per (batch, head, sequence block);
    # reduced below.
    partial_groups = triton.cdiv(seqlen, BLOCK)
    d_param_mu_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
    d_param_phi_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
    grid = lambda META: (partial_groups, batch * nheads)
    _bwd_eva_prep_kv_kernel[grid](
        rfa_k,
        rfa_v,
        k,
        v,
        param_mu,
        param_phi,
        mask,
        d_rfa_k,
        d_rfa_v,
        d_k_buf,
        d_v_buf,
        d_param_mu_partial,
        d_param_phi_partial,
        softmax_scale,
        rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2),
        rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        param_mu.stride(1),
        param_phi.stride(1),
        mask_strides[0], mask_strides[1],
        d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2),
        d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2),
        d_k_buf.stride(0), d_k_buf.stride(1), d_k_buf.stride(2),
        d_v_buf.stride(0), d_v_buf.stride(1), d_v_buf.stride(2),
        d_param_mu_partial.stride(0), d_param_mu_partial.stride(1), d_param_mu_partial.stride(2),
        d_param_phi_partial.stride(0), d_param_phi_partial.stride(1), d_param_phi_partial.stride(2),
        nheads,
        seqlen,
        nchunks,
        head_dim,
        chunks_per_block,
        chunksize,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    if d_k_buf is not d_k:
        d_k.copy_(d_k_buf)
    if d_v_buf is not d_v:
        d_v.copy_(d_v_buf)
    # Sum over batch (dim 0) and sequence blocks (dim -2):
    # (batch, nheads, groups, head_dim) -> (1, nheads, 1, head_dim), then
    # insert one more singleton axis to get (1, nheads, 1, 1, head_dim).
    d_param_mu.copy_(d_param_mu_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_mu.dtype))
    d_param_phi.copy_(d_param_phi_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_phi.dtype))
| |
|
| |
|
| |
|
class EvaPrepKVFunc(torch.autograd.Function):
    """Autograd wrapper around the Triton forward/backward launchers.

    ``forward`` returns ``(rfa_k, rfa_v)`` from ``triton_eva_prep_kv_fwd``;
    ``backward`` produces gradients for ``k``, ``v``, ``param_mu`` and
    ``param_phi`` via ``triton_eva_prep_kv_bwd``. ``mask``, ``softmax_scale``
    and ``chunksize`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, k, v, param_mu, param_phi, mask, softmax_scale=None, chunksize=None):
        # 1 tells the kernels to apply `mask`; 0 tells them to ignore it.
        mask_type = 0 if mask is None else 1
        rfa_k, rfa_v = triton_eva_prep_kv_fwd(
            k, v, param_mu, param_phi, mask, softmax_scale, chunksize
        )
        # The backward kernel recomputes the softmax weights from the inputs
        # and also needs the forward outputs.
        ctx.save_for_backward(k, v, param_mu, param_phi, mask, rfa_k, rfa_v)
        ctx.softmax_scale = softmax_scale
        ctx.chunksize = chunksize
        ctx.mask_type = mask_type
        return rfa_k, rfa_v

    @staticmethod
    def backward(ctx, d_rfa_k, d_rfa_v):
        k, v, param_mu, param_phi, mask, rfa_k, rfa_v = ctx.saved_tensors
        # Plain empty_like() copies the strides of a non-contiguous input,
        # but the kernel writes with an implicit unit stride on the last
        # dimension, so request a contiguous layout explicitly.
        d_k = torch.empty_like(k, memory_format=torch.contiguous_format)
        d_v = torch.empty_like(v, memory_format=torch.contiguous_format)
        d_param_mu = torch.empty_like(param_mu, memory_format=torch.contiguous_format)
        d_param_phi = torch.empty_like(param_phi, memory_format=torch.contiguous_format)
        triton_eva_prep_kv_bwd(
            d_rfa_k, d_rfa_v,
            k, v, param_mu, param_phi,
            mask,
            rfa_k, rfa_v,
            d_k, d_v, d_param_mu, d_param_phi,
            ctx.softmax_scale,
            ctx.mask_type,
            ctx.chunksize
        )
        # One entry per forward() argument after ctx.
        return d_k, d_v, d_param_mu, d_param_phi, None, None, None
| |
|
def eva_prep_kv_func_triton(
    k, v,
    param_mu, param_phi,
    mask,
    softmax_scale=None, chunksize=None
):
    """Functional entry point: run ``EvaPrepKVFunc`` on the given inputs.

    All arguments are forwarded positionally, in order, to
    ``EvaPrepKVFunc.apply``; its result is returned unchanged.
    """
    forwarded = (k, v, param_mu, param_phi, mask, softmax_scale, chunksize)
    return EvaPrepKVFunc.apply(*forwarded)
| |
|