import torch
import triton
import triton.language as tl

# Assumption: the module defines the FP8 numeric range used by the kernel's
# default FP8_MIN / FP8_MAX arguments. E4M3 is assumed here; adjust if the
# cache/output uses a different FP8 format.
float8_info = torch.finfo(torch.float8_e4m3fn)


@triton.jit
def _fwd_kernel(
Q,
K,
V,
K_cache,
V_cache,
sink_ptr,
B_Loc,
sm_scale,
k_scale,
v_scale,
out_scale_inv,
B_Start_Loc,
B_Seqlen,
x: tl.constexpr,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl: tl.constexpr,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
num_queries_per_kv: tl.constexpr,
IN_PRECISION: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DMODEL_PADDED: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_N: tl.constexpr,
SLIDING_WINDOW: tl.constexpr,
num_unroll_cache: tl.constexpr,
num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
USE_SINKS: tl.constexpr,
USE_FP8: tl.constexpr,
MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
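    # Each program handles one (sequence, query head, BLOCK_M query rows) tile;
    # the launch grid is assumed to be
    # (num_seqs, num_query_heads, cdiv(max_query_len, BLOCK_M)).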
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
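    # ctx_len counts tokens already resident in the paged KV cache;
    # query_len counts the new tokens being processed for this sequence.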
if SKIP_DECODE and cur_batch_query_len == 1:
return
    # start position inside the query
    # by convention, N indexes KV positions while M indexes query positions
block_start_loc = BLOCK_M * start_m
# initialize offsets
# [BLOCK_SIZE]; starts at 0
offs_bs_n = tl.arange(0, BLOCK_SIZE)
# [N]; starts at 0
offs_n = tl.arange(0, BLOCK_N)
# [D]; starts at 0
offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
# [M]; starts at current position in query
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# [M,D]
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :] * stride_qd
)
dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(
tl.int1
) # [D]
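    # dim_mask disables the padding lanes of the head dimension; the caller is
    # assumed to set BLOCK_DMODEL_PADDED to BLOCK_DMODEL rounded up to the next
    # power of two.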
q = tl.load(
Q + off_q,
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len),
other=0.0,
) # [M,D]
# initialize pointer to m and l
if not USE_SINKS:
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
m_i = tl.load(
sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
mask=(offs_m < cur_batch_query_len),
other=float("-inf"),
).to(dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
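    # With sinks, m_i starts at the per-head sink logit and l_i at 1.0 (the
    # sink's own exp(sink - m_i) = 1 term in the denominator). Without sinks,
    # m_i = -inf makes the first alpha = exp(-inf - m_ij) = 0, which wipes the
    # initial l_i anyway.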
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
# compute query against context (no causal mask here)
for start_n in tl.range(
0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache
):
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
# -- compute qk ----
        bn = tl.load(
            B_Loc
            + cur_batch * stride_b_loc_b
            + ((start_n + offs_bs_n) // BLOCK_SIZE) * stride_b_loc_s
        ).to(tl.int64)
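        # bn is a [BLOCK_SIZE] vector of the physical cache block index for this
        # tile (one block per loop step, broadcast so it combines with the
        # per-element offsets below). K_cache is paged and x-packed: the head
        # dim is split into (BLOCK_DMODEL // x, x) chunks, so a key element is
        # addressed by (block, head, d // x, position-in-block, d % x).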
# [D,BLOCK_SIZE]
off_k = (
bn[None, :] * stride_k_cache_bs
+ cur_kv_head * stride_k_cache_h
+ (offs_d[:, None] // x) * stride_k_cache_d
+ ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl
+ (offs_d[:, None] % x) * stride_k_cache_x
)
# [BLOCK_SIZE,D]
off_v = (
bn[:, None] * stride_v_cache_bs
+ cur_kv_head * stride_v_cache_h
+ offs_d[None, :] * stride_v_cache_d
+ offs_bs_n[:, None] * stride_v_cache_bl
)
if (
start_n + BLOCK_SIZE > cur_batch_ctx_len
or BLOCK_DMODEL != BLOCK_DMODEL_PADDED
):
k_load = tl.load(
K_cache + off_k,
mask=dim_mask[:, None]
& ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len),
other=0.0,
) # [D,N]
else:
k_load = tl.load(K_cache + off_k)
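        # FP8 KV caches are dequantized with the per-tensor k_scale / v_scale
        # before entering the dot products.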
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N]
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk = tl.where(
(start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
)
qk *= sm_scale
if SLIDING_WINDOW > 0:
# (cur_batch_ctx_len + offs_m[:, None]) are the positions of
# Q entries in sequence
# (start_n + offs_bs_n[None, :]) are the positions of
# KV entries in sequence
# So the condition makes sure each entry in Q only attends
# to KV entries not more than SLIDING_WINDOW away.
#
# We can't use -inf here, because the
# sliding window may lead to the entire row being masked.
# This then makes m_ij contain -inf, which causes NaNs in
# exp().
qk = tl.where(
(cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :])
< SLIDING_WINDOW,
qk,
-10000,
)
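        # flash-attention style online softmax: fold this tile into the running
        # max m_i, rescale the existing accumulator by alpha = exp(m_i - m_ij),
        # and add the new partial sums.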
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# update acc
if (
start_n + BLOCK_SIZE > cur_batch_ctx_len
or BLOCK_DMODEL != BLOCK_DMODEL_PADDED
):
v_load = tl.load(
V_cache + off_v,
mask=dim_mask[None, :]
& ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len),
other=0.0,
) # [N,D]
else:
v_load = tl.load(V_cache + off_v)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
        # update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
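    # Phase 2: attend to the new tokens of this request, whose K/V live in the
    # flat K/V tensors starting at cur_batch_in_all_start_index.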
off_k = (
offs_n[None, :] * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None] * stride_kd
)
off_v = (
offs_n[:, None] * stride_vbs
+ cur_kv_head * stride_vh
+ offs_d[None, :] * stride_vd
)
k_ptrs = K + off_k
v_ptrs = V + off_v
# block_mask is 0 when we're already past the current query length
block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
# compute query against itself (with causal mask)
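    # The loop bound (start_m + 1) * BLOCK_M only visits KV positions at or
    # before the current query block; block_mask zeroes the bound (skipping the
    # loop entirely) when this query block lies past the query length.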
for start_n in tl.range(
0,
block_mask * (start_m + 1) * BLOCK_M,
BLOCK_N,
loop_unroll_factor=num_unroll_request,
):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=dim_mask[:, None]
& ((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION)
qk *= sm_scale
# apply causal mask
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
if SLIDING_WINDOW > 0:
qk = tl.where(
offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
qk,
-10000,
)
# compute running maximum
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha[:, None]
# update acc
v = tl.load(
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=dim_mask[None, :]
& ((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0,
)
p = p.to(v.dtype)
acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION)
# update m_i and l_i
l_i = l_i * alpha + l_ij
m_i = m_ij
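    # normalize by the softmax denominator accumulated over both phases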
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ cur_head * stride_oh
+ offs_d[None, :] * stride_od
)
out_ptrs = Out + off_o
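    # For an FP8 output buffer, apply the inverse output scale and clamp into
    # the representable FP8 range before the store.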
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
tl.store(
out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)
)
return
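

# A minimal launch sketch (assumption: the host-side wrapper is not shown in
# this section; names such as `block_table` and `max_query_len` are
# illustrative, not part of the kernel's API):
#
#   grid = (num_seqs, num_query_heads, triton.cdiv(max_query_len, BLOCK_M))
#   _fwd_kernel[grid](
#       q, k, v, k_cache, v_cache, sinks, block_table, sm_scale, k_scale,
#       v_scale, out_scale_inv, b_start_loc, b_seqlen, x, out,
#       ...,  # stride and constexpr arguments follow the signature above
#   )
#
# B_Loc (block_table) maps (sequence, logical block) -> physical cache block;
# B_Start_Loc / B_Seqlen hold the cumulative query offsets and the total
# (context + query) length per sequence.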