@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2,
),
],
key=["chunk_size", "K", "IS_CAUSAL"],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
out_ptr,
cu_chunk_seqlens_ptr,
# Matrix dimensions
    seqlen,  # total packed sequence length (not used in the kernel body)
chunk_size: tl.constexpr,
K: tl.constexpr,
ngroups: tl.constexpr,
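    # Strides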
stride_a_seqlen: tl.int64,
stride_a_head: tl.int64,
stride_ak: tl.constexpr,
stride_b_seqlen: tl.int64,
stride_b_head: tl.int64,
stride_bk: tl.constexpr,
stride_out_chunk: tl.int64,
stride_out_head: tl.int64,
stride_outm: tl.int64,
stride_outn: tl.constexpr,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
dot_dtype: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
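    """Per-chunk batched matmul: out[c, g] = A_c @ B_c^T.

    A and B are packed along a single varlen sequence dimension, and
    cu_chunk_seqlens_ptr holds the cumulative chunk boundaries. Grid axis 0
    enumerates the (M, N) tiles of one (chunk_size, chunk_size) output;
    grid axis 1 enumerates (chunk, head/group) pairs. With IS_CAUSAL, tiles
    entirely above the block diagonal are skipped.
    """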
pid_ch = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_ch // ngroups
pid_h = pid_ch - pid_c * ngroups
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
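    # In the causal case, output tiles that lie entirely in the strict upper
    # triangle are never needed, so exit early.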
if IS_CAUSAL:
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
return
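    # Load this chunk's [start, end) boundaries in the packed sequence and
    # advance A/B to the chunk start for this head/group.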
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head
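    # Offsets of this program's output tile, plus pointers to the first K-tile
    # of A (rows) and of B (columns, read transposed).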
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
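    # Accumulate in fp32 for accuracy; cast back to the output dtype at store.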
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Accumulate out = A @ B^T, tiling the K dimension in BLOCK_SIZE_K steps.
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(
a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
).to(dot_dtype)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
& (offs_n[None, :] < chunk_size_limit),
other=0.0,
).to(dot_dtype)
acc += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
    # offs_m / offs_n were not modified by the K loop; reuse them to index the
    # output tile.
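    # Cast the fp32 accumulator to the output dtype and store the tile.
    # Positions past this chunk's true length hold zeros, since the
    # corresponding loads were masked.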
out = acc.to(out_ptr.dtype.element_ty)
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
tl.store(
out_ptrs,
out,
mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
)
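

# A minimal, hypothetical launch wrapper for the kernel above -- a sketch, not
# part of the original module. It assumes a/b are packed as (seqlen, ngroups,
# K) tensors and that the output is laid out as (nchunks, ngroups, chunk_size,
# chunk_size), matching the kernel's stride arguments. The helper name
# `bmm_chunk_fwd_example` and the dtype-promotion rule are illustrative; the
# `triton` / `tl` imports are the ones the kernel itself already requires.
import torch


def bmm_chunk_fwd_example(a, b, chunk_size, cu_chunk_seqlens, causal=False):
    seqlen, ngroups, K = a.shape
    assert b.shape == a.shape
    nchunks = cu_chunk_seqlens.numel() - 1
    # Promote both inputs to a common compute dtype for tl.dot.
    if torch.bfloat16 in (a.dtype, b.dtype):
        dot_dtype = tl.bfloat16
    elif torch.float16 in (a.dtype, b.dtype):
        dot_dtype = tl.float16
    else:
        dot_dtype = tl.float32
    out = torch.empty(
        (nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=a.dtype
    )
    # Axis 0 covers the (M, N) tiles of one chunk's output; axis 1 covers
    # every (chunk, group) pair, matching the kernel's pid decomposition.
    grid = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
        nchunks * ngroups,
    )
    _bmm_chunk_fwd_kernel[grid](
        a, b, out, cu_chunk_seqlens,
        seqlen, chunk_size, K, ngroups,
        a.stride(0), a.stride(1), a.stride(2),
        b.stride(0), b.stride(1), b.stride(2),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        causal, dot_dtype,
    )
    return out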