import triton
import triton.language as tl


@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2,
),
],
key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
)
@triton.jit
def _chunk_scan_fwd_kernel(
# Pointers to matrices
cb_ptr,
x_ptr,
z_ptr,
out_ptr,
dt_ptr,
dA_cumsum_ptr,
seq_idx_ptr,
C_ptr,
states_ptr,
D_ptr,
initstates_ptr,
cu_chunk_seqlens_ptr,
# Matrix dimensions
chunk_size: tl.constexpr,
hdim: tl.constexpr,
dstate: tl.constexpr,
seqlen,
nheads_ngroups_ratio: tl.constexpr,
# Strides
stride_cb_chunk: tl.int64,
stride_cb_head: tl.int64,
stride_cb_csize_m: tl.int64,
stride_cb_csize_k: tl.constexpr,
stride_x_seqlen: tl.int64,
stride_x_head: tl.int64,
stride_x_hdim: tl.constexpr,
stride_z_seqlen: tl.int64,
stride_z_head: tl.int64,
stride_z_hdim: tl.constexpr,
stride_out_seqlen: tl.int64,
stride_out_head: tl.int64,
stride_out_hdim: tl.constexpr,
stride_dt_chunk: tl.int64,
stride_dt_head: tl.int64,
stride_dt_csize: tl.constexpr,
stride_dA_cs_chunk: tl.int64,
stride_dA_cs_head: tl.int64,
stride_dA_cs_csize: tl.constexpr,
stride_seq_idx_chunk: tl.constexpr,
stride_C_seqlen: tl.int64,
stride_C_head: tl.int64,
stride_C_dstate: tl.constexpr,
stride_states_chunk: tl.int64,
stride_states_head: tl.int64,
stride_states_hdim: tl.int64,
stride_states_dstate: tl.constexpr,
stride_init_states_batch: tl.int64,
stride_init_states_head: tl.int64,
stride_init_states_hdim: tl.int64,
stride_init_states_dstate: tl.constexpr,
stride_D_head: tl.constexpr,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
HAS_Z: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
IS_TRITON_22: tl.constexpr,
HAS_INITSTATES: tl.constexpr,
):
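    """Chunk scan forward kernel: one output tile per program.

    For each (chunk, head) pair this computes

        out = (C @ prev_states) * exp(dA_cs_m)             # inter-chunk
            + ((CB * exp(dA_cs_m - dA_cs_k) * dt) @ x)     # intra-chunk
            [+ D * x]                                      # if HAS_D
            [* z * sigmoid(z)]                             # if HAS_Z

    where prev_states are the final states of the previous chunk, or the
    initial states (HAS_INITSTATES) when this chunk starts a new sequence.
    """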
pid_c = tl.program_id(axis=1).to(tl.int64)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
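    # [start, end) token range of this chunk within the packed batch.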
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += (
chunk_seqlen_start * stride_C_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_C_head
)
    # M-block offsets and previous states
    # - the branch below redirects prev_states_ptr to the initial states
    #   when this chunk starts a new sequence
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
seq_idx_ptr += pid_c * stride_seq_idx_chunk
seq_idx = tl.load(seq_idx_ptr)
seq_idx_prev = tl.load(
seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1
)
if HAS_INITSTATES and (seq_idx != seq_idx_prev):
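        # Chunk starts a new sequence: read its supplied initial states.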
prev_states_ptr = (
initstates_ptr
+ seq_idx * stride_init_states_batch
+ pid_h * stride_init_states_head
)
prev_states_hdim = stride_init_states_hdim
prev_states_dstate = stride_init_states_dstate
else:
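        # Same sequence continues: read the states written by the previous chunk.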
prev_states_ptr = (
states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
)
prev_states_hdim = stride_states_hdim
prev_states_dstate = stride_states_dstate
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
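    # Cumulative dA (log-decay) at each row of this M-tile.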
dA_cs_m = tl.load(
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
).to(tl.float32)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Faster to do a single iteration with a larger BLOCK_SIZE_K (up to 128)
    # than to loop over dstate in smaller steps.
offs_k_dstate = tl.arange(
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
)
C_ptrs = C_ptr + (
offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
)
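    # Per-row decay factor carrying the previous state up to each row.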
scale_m = tl.exp(dA_cs_m)
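    # Inter-chunk contribution: acc = (C @ prev_states) * scale_m.
    # When dstate fits in a single tile (<= 128) one tl.dot suffices;
    # otherwise loop over dstate in BLOCK_SIZE_K-sized steps.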
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(
C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k_dstate[None, :] < dstate),
other=0.0,
)
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
# if no init states AND starting a new sequence, we need zeros
prev_states = tl.zeros(
(BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
)
else:
# otherwise read the previous state
prev_states_ptrs = (
prev_states_ptr
+ offs_n[None, :] * prev_states_hdim
+ offs_k_dstate[:, None] * prev_states_dstate
)
prev_states = tl.load(
prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
other=0.0,
)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc = tl.dot(C, prev_states) * scale_m[:, None]
else:
prev_states_ptrs = (
prev_states_ptr
+ offs_n[None, :] * prev_states_hdim
+ offs_k_dstate[:, None] * prev_states_dstate
)
for k in range(0, dstate, BLOCK_SIZE_K):
C = tl.load(
C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k_dstate[None, :] < dstate - k),
other=0.0,
)
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
                prev_states = tl.zeros(
                    (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
                )
else:
prev_states = tl.load(
prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate - k)
& (offs_n[None, :] < hdim),
other=0.0,
)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc += tl.dot(C, prev_states)
            C_ptrs += BLOCK_SIZE_K * stride_C_dstate
            prev_states_ptrs += BLOCK_SIZE_K * prev_states_dstate
acc *= scale_m[:, None]
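    # Intra-chunk contribution: acc += (CB * exp(dA_cs_m - dA_cs_k) * dt_k) @ x,
    # restricted to the lower triangle when IS_CAUSAL.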
offs_k = tl.arange(0, BLOCK_SIZE_K)
cb_ptrs = cb_ptr + (
offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
)
x_ptrs = x_ptr + (
offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
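    # A causal kernel only needs columns up to the last row of this M-tile.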
K_MAX = (
chunk_size_limit
if not IS_CAUSAL
else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
)
for k in range(0, K_MAX, BLOCK_SIZE_K):
cb = tl.load(
cb_ptrs,
mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
tl.float32
)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
mask = offs_m[:, None] >= k + offs_k[None, :]
cb = tl.where(mask, cb, 0.0)
cb = cb.to(x_ptr.dtype.element_ty)
x = tl.load(
x_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
other=0.0,
)
acc += tl.dot(cb, x)
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
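    # Optional skip connection: acc += D * x, where D is a per-head scalar or,
    # if D_HAS_HDIM, a vector over the head dimension.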
if HAS_D:
if D_HAS_HDIM:
D = tl.load(
D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
).to(tl.float32)
else:
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
x_residual = tl.load(
x_ptr
+ (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
other=0.0,
).to(tl.float32)
acc += x_residual * D
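    # Optional SiLU gating: acc *= z * sigmoid(z).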
if HAS_Z:
z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
z_ptrs = z_ptr + (
stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
)
z = tl.load(
z_ptrs,
mask=(offs_out_m[:, None] < chunk_size_limit)
& (offs_out_n[None, :] < hdim),
other=0.0,
).to(tl.float32)
acc *= z * tl.sigmoid(z)
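    # Write back the output tile, masked to valid rows and columns.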
out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
out_ptrs = out_ptr + (
stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
)
tl.store(
out_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
)
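

# As implied by the program_id axes above, the expected launch grid is
# (cdiv(chunk_size, BLOCK_SIZE_M) * cdiv(hdim, BLOCK_SIZE_N), nchunks, nheads),
# with cu_chunk_seqlens holding nchunks + 1 cumulative token offsets.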