Skip to content

vllm.v1.attention.backends.mamba2_attn ¶

Mamba2AttentionBackend ¶

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mamba2_attn.py
class Mamba2AttentionBackend(AttentionBackend):
    """Attention backend for Mamba2; exposes its metadata builder class."""

    @staticmethod
    def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return Mamba2AttentionMetadataBuilder

get_builder_cls staticmethod ¶

get_builder_cls() -> type[Mamba2AttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/mamba2_attn.py
@staticmethod
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
    """Return the metadata builder class paired with this backend."""
    return Mamba2AttentionMetadataBuilder

Mamba2AttentionMetadata dataclass ¶

Source code in vllm/v1/attention/backends/mamba2_attn.py
@dataclass
class Mamba2AttentionMetadata:
    """Per-batch metadata assembled by ``Mamba2AttentionMetadataBuilder.build``.

    Fields with a ``_p`` suffix cover only the prefill requests of the batch.
    Several tensor fields are ``None`` for batches without prefill requests
    or when prefix caching is disabled; they are annotated ``Optional`` so
    the annotations match what the builder actually passes.
    """

    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    # Query start offsets of the prefill requests, shifted down by the number
    # of decode tokens. None if the batch has no prefill request.
    query_start_loc_p: Optional[torch.Tensor]
    seq_lens: torch.Tensor

    # True if at least one prefill request already has computed tokens.
    prep_initial_states: bool
    chunk_size: int

    # The following tensors only contain prefill requests and will be None if
    # the batch has no prefill request.
    has_initial_states_p: Optional[torch.Tensor]
    seq_idx_p: Optional[torch.Tensor]

    # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
    # each chunk, its offsets into the varlen sequence dimension. It is defined
    # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
    # cu_chunk_seqlen_p[i+1].
    cu_chunk_seqlen_p: Optional[torch.Tensor]

    # last_chunk_indices_p is a tensor of shape (batch,) that contains the
    # index of the last chunk for every sequence in the (prefill) batch.
    last_chunk_indices_p: Optional[torch.Tensor]

    # shape: [batch,] without prefix caching (first block-table column);
    # with prefix caching it is the block table, (#requests, #max blocks).
    state_indices_tensor: torch.Tensor
    # The two tensors below are None when prefix caching is disabled.
    block_idx_last_scheduled_token: Optional[torch.Tensor]  # shape: [batch,]
    # The "_p" tensors below are None unless prefix caching is enabled and
    # the batch contains prefill requests.
    block_idx_first_scheduled_token_p: Optional[torch.Tensor]  # shape: [batch,]
    block_idx_last_computed_token: Optional[torch.Tensor]  # shape: [batch,]
    num_computed_tokens_p: Optional[torch.Tensor]  # shape: [batch,]

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None

batch_ptr class-attribute instance-attribute ¶

batch_ptr: Optional[Tensor] = None

block_idx_first_scheduled_token_p instance-attribute ¶

block_idx_first_scheduled_token_p: Tensor

block_idx_last_computed_token instance-attribute ¶

block_idx_last_computed_token: Tensor

block_idx_last_scheduled_token instance-attribute ¶

block_idx_last_scheduled_token: Tensor

chunk_size instance-attribute ¶

chunk_size: int

cu_chunk_seqlen_p instance-attribute ¶

cu_chunk_seqlen_p: Optional[Tensor]

has_initial_states_p instance-attribute ¶

has_initial_states_p: Optional[Tensor]

last_chunk_indices_p instance-attribute ¶

last_chunk_indices_p: Optional[Tensor]

num_computed_tokens_p instance-attribute ¶

num_computed_tokens_p: Tensor

num_decode_tokens instance-attribute ¶

num_decode_tokens: int

num_decodes instance-attribute ¶

num_decodes: int

num_prefill_tokens instance-attribute ¶

num_prefill_tokens: int

num_prefills instance-attribute ¶

num_prefills: int

nums_dict class-attribute instance-attribute ¶

nums_dict: Optional[dict] = None

prep_initial_states instance-attribute ¶

prep_initial_states: bool

query_start_loc_p instance-attribute ¶

query_start_loc_p: Tensor

seq_idx_p instance-attribute ¶

seq_idx_p: Optional[Tensor]

seq_lens instance-attribute ¶

seq_lens: Tensor

state_indices_tensor instance-attribute ¶

state_indices_tensor: Tensor

token_chunk_offset_ptr class-attribute instance-attribute ¶

token_chunk_offset_ptr: Optional[Tensor] = None

__init__ ¶

__init__(
    num_prefills: int,
    num_prefill_tokens: int,
    num_decodes: int,
    num_decode_tokens: int,
    query_start_loc_p: Tensor,
    seq_lens: Tensor,
    prep_initial_states: bool,
    chunk_size: int,
    has_initial_states_p: Optional[Tensor],
    seq_idx_p: Optional[Tensor],
    cu_chunk_seqlen_p: Optional[Tensor],
    last_chunk_indices_p: Optional[Tensor],
    state_indices_tensor: Tensor,
    block_idx_last_scheduled_token: Tensor,
    block_idx_first_scheduled_token_p: Tensor,
    block_idx_last_computed_token: Tensor,
    num_computed_tokens_p: Tensor,
    nums_dict: Optional[dict] = None,
    batch_ptr: Optional[Tensor] = None,
    token_chunk_offset_ptr: Optional[Tensor] = None,
) -> None

Mamba2AttentionMetadataBuilder ¶

Bases: BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]

Source code in vllm/v1/attention/backends/mamba2_attn.py
class Mamba2AttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
    """Builds ``Mamba2AttentionMetadata`` from the common attention metadata."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Read the Mamba chunk size and allocate prefix-caching buffers.

        Args:
            kv_cache_spec: Cache spec; its ``block_size`` sizes the block
                dimension of the persistent state-index buffer.
            layer_names: Forwarded unchanged to the base class.
            vllm_config: Source of the Mamba chunk size, ``max_model_len``
                and the prefix-caching flag.
            device: Device on which the persistent buffers are allocated.

        Raises:
            AssertionError: If the model config reports no Mamba chunk size.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
        assert self.chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models"
        )
        if self.vllm_config.cache_config.enable_prefix_caching:
            # Persistent buffers that build() fills and slices when padding a
            # decode-only batch for full CUDA graphs.
            self.state_indices_tensor = torch.empty(
                (
                    self.decode_cudagraph_max_bs,
                    cdiv(
                        vllm_config.model_config.max_model_len, kv_cache_spec.block_size
                    ),
                ),
                dtype=torch.int32,
                device=device,
            )
            self.block_idx_last_scheduled_token = torch.empty(
                (self.decode_cudagraph_max_bs,),
                dtype=torch.int32,
                device=device,
            )
            self.block_idx_last_computed_token = torch.empty(
                (self.decode_cudagraph_max_bs,),
                dtype=torch.int32,
                device=device,
            )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> Mamba2AttentionMetadata:
        """Build the Mamba2 metadata for one batch.

        Prefill requests are read from the tail of the batch (the last
        ``num_prefills`` requests). For them, per-chunk metadata is built so
        that every chunk holds tokens of a single request and chunk
        boundaries fall on multiples of ``chunk_size`` in each request's
        absolute token position. For decode-only batches that fit the CUDA
        graph batch size, the state tensors are copied into persistent
        buffers and padded.

        Args:
            common_prefix_len: Not used by this implementation.
            common_attn_metadata: Batch-level metadata (request count,
                sequence lengths, query start offsets, block table, number
                of already computed tokens).
            fast_build: Not used by this implementation.

        Returns:
            The populated ``Mamba2AttentionMetadata``. Prefill-only fields
            are ``None`` when the batch has no prefill request, and the
            block-index fields are ``None`` when prefix caching is disabled.
        """
        num_reqs = common_attn_metadata.num_reqs
        seq_lens = common_attn_metadata.seq_lens

        query_start_loc_p = None
        seq_idx_p = None
        cu_chunk_seqlen_p = None
        last_chunk_indices_p = None

        # Need flags to indicate if there are initial states
        has_initial_states_p = None
        prep_initial_states = False

        # for causal_conv1d
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        num_computed_tokens, num_computed_tokens_p = None, None
        block_idx_first_scheduled_token = None
        block_idx_first_scheduled_token_p = None

        if self.vllm_config.cache_config.enable_prefix_caching:
            # Return a tensor of shape (#requests, #max blocks)
            state_indices_tensor = common_attn_metadata.block_table_tensor
            # Additional cache-related variables:
            mamba_block_size = self.kv_cache_spec.block_size
            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
                self.device
            )
            # Block index of the last computed token
            block_idx_last_computed_token = (
                cdiv(num_computed_tokens, mamba_block_size) - 1
            )
            # which is <= block index for the first scheduled token
            block_idx_first_scheduled_token = (
                cdiv(num_computed_tokens + 1, mamba_block_size) - 1
            )
            # which is <= block index of the last scheduled token
            block_idx_last_scheduled_token = (
                cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
            )
            # A request with no computed tokens yields -1 above; clamp to 0
            # so that later indexing with this tensor stays valid.
            block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
        else:
            # Always return just a single block per each request:
            state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
            # Additional cache-related variables:
            block_idx_last_scheduled_token = None
            block_idx_last_computed_token = None

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
            )
        )

        # Compute seq_idx for prefill only
        if num_prefills > 0:
            # [batch,]
            has_initial_states_cpu = (
                common_attn_metadata.num_computed_tokens_cpu[
                    num_reqs - num_prefills : num_reqs
                ]
                > 0
            )
            prep_initial_states = torch.any(has_initial_states_cpu).item()
            has_initial_states_p = has_initial_states_cpu.to(
                common_attn_metadata.query_start_loc.device
            )

            # Offsets of the prefill requests, rebased so that the first
            # prefill token sits at position 0.
            query_start_loc_p = (
                common_attn_metadata.query_start_loc[-num_prefills - 1 :]
                - num_decode_tokens
            )

            if self.vllm_config.cache_config.enable_prefix_caching:
                assert num_computed_tokens is not None
                num_computed_tokens_p = num_computed_tokens[
                    num_reqs - num_prefills : num_reqs
                ]
                assert block_idx_first_scheduled_token is not None
                block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
                    num_reqs - num_prefills : num_reqs
                ]
            num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
                num_reqs - num_prefills : num_reqs
            ]
            query_start_loc_p_cpu = (
                common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
                - num_decode_tokens
            )

            # Convert the CPU tensors to Python ints once, instead of calling
            # .item() several times per request inside the loop below.
            num_computed_list = num_computed_tokens_p_cpu.tolist()
            query_start_list = query_start_loc_p_cpu.tolist()

            # The code below carefully constructs the chunks such that:
            # 1. Chunks contain tokens from a *single* sequence only.
            # 2. For every sequence, we are guaranteed that we can
            #    retrieve the mamba state *every* chunk_size tokens.
            # Constraint (1) dramatically simplifies the mamba2 kernels.
            # Constraint (2) dramatically simplifies the implementation
            # of prefix caching for mamba2 (wip). We need to take care
            # of the interaction with chunked prefill in order to
            # satisfy constraint (2).
            # TODO (tdoublep): This code could probably be optimized.
            cu_chunk_seqlen = []
            seq_idx = []
            last_chunk_indices = []
            seqlen_pos = 0
            for req_idx in range(num_prefills):
                this_num_computed = num_computed_list[req_idx]
                this_new_tokens = (
                    query_start_list[req_idx + 1] - query_start_list[req_idx]
                )

                # if computed tokens are not chunk-aligned, use the first
                # chunk to finish it off
                if this_num_computed % self.chunk_size != 0:
                    seq_idx.append(req_idx)
                    cu_chunk_seqlen.append(seqlen_pos)
                    # how many tokens to finish the chunk?
                    chunk_len = (
                        cdiv(this_num_computed, self.chunk_size) * self.chunk_size
                        - this_num_computed
                    )
                    # we can only use at most this_new_tokens
                    chunk_len = min(chunk_len, this_new_tokens)
                    seqlen_pos += chunk_len
                    this_new_tokens -= chunk_len

                # Remaining tokens go into full chunks plus one trailing
                # partial chunk, if any.
                n_chunks = cdiv(this_new_tokens, self.chunk_size)
                for _ in range(n_chunks):
                    seq_idx.append(req_idx)
                    cu_chunk_seqlen.append(seqlen_pos)
                    chunk_len = min(self.chunk_size, this_new_tokens)
                    seqlen_pos += chunk_len
                    this_new_tokens -= chunk_len

                assert this_new_tokens == 0
                last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

            # Closing boundary: cu_chunk_seqlen has nchunks + 1 entries.
            cu_chunk_seqlen.append(seqlen_pos)

            seq_idx_p = torch.as_tensor(
                seq_idx, device=query_start_loc_p.device, dtype=torch.int32
            )
            cu_chunk_seqlen_p = torch.as_tensor(
                cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
            )
            last_chunk_indices_p = torch.as_tensor(
                last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
            )

            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(query_start_loc_p)
            )

        elif (
            num_decodes <= self.decode_cudagraph_max_bs
            and self.compilation_config.full_cuda_graph
        ):
            # Pad state tensor for CUDA graph: copy the live entries into the
            # persistent buffer, then mark the padding rows.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(
                state_indices_tensor, non_blocking=True
            )
            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

            if self.vllm_config.cache_config.enable_prefix_caching:
                self.block_idx_last_scheduled_token[:num_decodes].copy_(
                    block_idx_last_scheduled_token, non_blocking=True
                )
                block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
                    :num_input_tokens
                ]
                block_idx_last_scheduled_token[num_decodes:] = 0

                self.block_idx_last_computed_token[:num_decodes].copy_(
                    block_idx_last_computed_token, non_blocking=True
                )
                block_idx_last_computed_token = self.block_idx_last_computed_token[
                    :num_input_tokens
                ]
                block_idx_last_computed_token[num_decodes:] = 0

        attn_metadata = Mamba2AttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc_p=query_start_loc_p,
            seq_lens=seq_lens,
            prep_initial_states=prep_initial_states,
            chunk_size=self.chunk_size,
            has_initial_states_p=has_initial_states_p,
            seq_idx_p=seq_idx_p,
            state_indices_tensor=state_indices_tensor,
            cu_chunk_seqlen_p=cu_chunk_seqlen_p,
            last_chunk_indices_p=last_chunk_indices_p,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token,
            block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
            block_idx_last_computed_token=block_idx_last_computed_token,
            num_computed_tokens_p=num_computed_tokens_p,
        )
        return attn_metadata

block_idx_last_computed_token instance-attribute ¶

block_idx_last_computed_token = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

block_idx_last_scheduled_token instance-attribute ¶

block_idx_last_scheduled_token = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

chunk_size instance-attribute ¶

chunk_size = get_mamba_chunk_size()

state_indices_tensor instance-attribute ¶

state_indices_tensor = empty(
    (
        decode_cudagraph_max_bs,
        cdiv(max_model_len, block_size),
    ),
    dtype=int32,
    device=device,
)

__init__ ¶

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/mamba2_attn.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Read the Mamba chunk size and allocate prefix-caching buffers.

    With prefix caching enabled, three persistent int32 buffers are created
    on ``device``, each with ``decode_cudagraph_max_bs`` rows. Without it,
    nothing beyond ``chunk_size`` is set here.

    Raises:
        AssertionError: If the model config reports no Mamba chunk size.
    """
    super().__init__(kv_cache_spec, layer_names, vllm_config, device)

    model_config = vllm_config.model_config
    self.chunk_size = model_config.get_mamba_chunk_size()
    assert self.chunk_size is not None, (
        "chunk_size needs to be set in the model config for Mamba2 models"
    )

    # Nothing else to set up unless prefix caching is on.
    if not self.vllm_config.cache_config.enable_prefix_caching:
        return

    max_decode_bs = self.decode_cudagraph_max_bs
    max_blocks_per_req = cdiv(model_config.max_model_len, kv_cache_spec.block_size)
    buffer_opts = {"dtype": torch.int32, "device": device}

    # One row per request, one column per possible block of that request.
    self.state_indices_tensor = torch.empty(
        (max_decode_bs, max_blocks_per_req), **buffer_opts
    )
    # One entry per request.
    self.block_idx_last_scheduled_token = torch.empty((max_decode_bs,), **buffer_opts)
    self.block_idx_last_computed_token = torch.empty((max_decode_bs,), **buffer_opts)

build ¶

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> Mamba2AttentionMetadata
Source code in vllm/v1/attention/backends/mamba2_attn.py
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> Mamba2AttentionMetadata:
    """Build the Mamba2 metadata for one batch.

    Prefill requests are read from the tail of the batch (the last
    ``num_prefills`` requests). For them, per-chunk metadata is built so
    that every chunk holds tokens of a single request and chunk boundaries
    fall on multiples of ``chunk_size`` in each request's absolute token
    position. For decode-only batches that fit the CUDA graph batch size,
    the state tensors are copied into persistent buffers and padded.

    Args:
        common_prefix_len: Not used by this implementation.
        common_attn_metadata: Batch-level metadata (request count, sequence
            lengths, query start offsets, block table, number of already
            computed tokens).
        fast_build: Not used by this implementation.

    Returns:
        The populated ``Mamba2AttentionMetadata``. Prefill-only fields are
        ``None`` when the batch has no prefill request, and the block-index
        fields are ``None`` when prefix caching is disabled.
    """
    num_reqs = common_attn_metadata.num_reqs
    seq_lens = common_attn_metadata.seq_lens

    query_start_loc_p = None
    seq_idx_p = None
    cu_chunk_seqlen_p = None
    last_chunk_indices_p = None

    # Need flags to indicate if there are initial states
    has_initial_states_p = None
    prep_initial_states = False

    # for causal_conv1d
    nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

    num_computed_tokens, num_computed_tokens_p = None, None
    block_idx_first_scheduled_token = None
    block_idx_first_scheduled_token_p = None

    if self.vllm_config.cache_config.enable_prefix_caching:
        # Return a tensor of shape (#requests, #max blocks)
        state_indices_tensor = common_attn_metadata.block_table_tensor
        # Additional cache-related variables:
        mamba_block_size = self.kv_cache_spec.block_size
        num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
            self.device
        )
        # Block index of the last computed token
        block_idx_last_computed_token = (
            cdiv(num_computed_tokens, mamba_block_size) - 1
        )
        # which is <= block index for the first scheduled token
        block_idx_first_scheduled_token = (
            cdiv(num_computed_tokens + 1, mamba_block_size) - 1
        )
        # which is <= block index of the last scheduled token
        block_idx_last_scheduled_token = (
            cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
        )
        # A request with no computed tokens yields -1 above; clamp to 0 so
        # that later indexing with this tensor stays valid.
        block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
    else:
        # Always return just a single block per each request:
        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
        # Additional cache-related variables:
        block_idx_last_scheduled_token = None
        block_idx_last_computed_token = None

    num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
        split_decodes_and_prefills(
            common_attn_metadata, decode_threshold=self.reorder_batch_threshold
        )
    )

    # Compute seq_idx for prefill only
    if num_prefills > 0:
        # [batch,]
        has_initial_states_cpu = (
            common_attn_metadata.num_computed_tokens_cpu[
                num_reqs - num_prefills : num_reqs
            ]
            > 0
        )
        prep_initial_states = torch.any(has_initial_states_cpu).item()
        has_initial_states_p = has_initial_states_cpu.to(
            common_attn_metadata.query_start_loc.device
        )

        # Offsets of the prefill requests, rebased so that the first prefill
        # token sits at position 0.
        query_start_loc_p = (
            common_attn_metadata.query_start_loc[-num_prefills - 1 :]
            - num_decode_tokens
        )

        if self.vllm_config.cache_config.enable_prefix_caching:
            assert num_computed_tokens is not None
            num_computed_tokens_p = num_computed_tokens[
                num_reqs - num_prefills : num_reqs
            ]
            assert block_idx_first_scheduled_token is not None
            block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
                num_reqs - num_prefills : num_reqs
            ]
        num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
            num_reqs - num_prefills : num_reqs
        ]
        query_start_loc_p_cpu = (
            common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
            - num_decode_tokens
        )

        # Convert the CPU tensors to Python ints once, instead of calling
        # .item() several times per request inside the loop below.
        num_computed_list = num_computed_tokens_p_cpu.tolist()
        query_start_list = query_start_loc_p_cpu.tolist()

        # The code below carefully constructs the chunks such that:
        # 1. Chunks contain tokens from a *single* sequence only.
        # 2. For every sequence, we are guaranteed that we can
        #    retrieve the mamba state *every* chunk_size tokens.
        # Constraint (1) dramatically simplifies the mamba2 kernels.
        # Constraint (2) dramatically simplifies the implementation
        # of prefix caching for mamba2 (wip). We need to take care
        # of the interaction with chunked prefill in order to
        # satisfy constraint (2).
        # TODO (tdoublep): This code could probably be optimized.
        cu_chunk_seqlen = []
        seq_idx = []
        last_chunk_indices = []
        seqlen_pos = 0
        for req_idx in range(num_prefills):
            this_num_computed = num_computed_list[req_idx]
            this_new_tokens = (
                query_start_list[req_idx + 1] - query_start_list[req_idx]
            )

            # if computed tokens are not chunk-aligned, use the first
            # chunk to finish it off
            if this_num_computed % self.chunk_size != 0:
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                # how many tokens to finish the chunk?
                chunk_len = (
                    cdiv(this_num_computed, self.chunk_size) * self.chunk_size
                    - this_num_computed
                )
                # we can only use at most this_new_tokens
                chunk_len = min(chunk_len, this_new_tokens)
                seqlen_pos += chunk_len
                this_new_tokens -= chunk_len

            # Remaining tokens go into full chunks plus one trailing partial
            # chunk, if any.
            n_chunks = cdiv(this_new_tokens, self.chunk_size)
            for _ in range(n_chunks):
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                chunk_len = min(self.chunk_size, this_new_tokens)
                seqlen_pos += chunk_len
                this_new_tokens -= chunk_len

            assert this_new_tokens == 0
            last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

        # Closing boundary: cu_chunk_seqlen has nchunks + 1 entries.
        cu_chunk_seqlen.append(seqlen_pos)

        seq_idx_p = torch.as_tensor(
            seq_idx, device=query_start_loc_p.device, dtype=torch.int32
        )
        cu_chunk_seqlen_p = torch.as_tensor(
            cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
        )
        last_chunk_indices_p = torch.as_tensor(
            last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
        )

        nums_dict, batch_ptr, token_chunk_offset_ptr = (
            compute_causal_conv1d_metadata(query_start_loc_p)
        )

    elif (
        num_decodes <= self.decode_cudagraph_max_bs
        and self.compilation_config.full_cuda_graph
    ):
        # Pad state tensor for CUDA graph: copy the live entries into the
        # persistent buffer, then mark the padding rows.
        num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
        self.state_indices_tensor[:num_decodes].copy_(
            state_indices_tensor, non_blocking=True
        )
        state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
        state_indices_tensor[num_decodes:] = PAD_SLOT_ID

        if self.vllm_config.cache_config.enable_prefix_caching:
            self.block_idx_last_scheduled_token[:num_decodes].copy_(
                block_idx_last_scheduled_token, non_blocking=True
            )
            block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
                :num_input_tokens
            ]
            block_idx_last_scheduled_token[num_decodes:] = 0

            self.block_idx_last_computed_token[:num_decodes].copy_(
                block_idx_last_computed_token, non_blocking=True
            )
            block_idx_last_computed_token = self.block_idx_last_computed_token[
                :num_input_tokens
            ]
            block_idx_last_computed_token[num_decodes:] = 0

    attn_metadata = Mamba2AttentionMetadata(
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        query_start_loc_p=query_start_loc_p,
        seq_lens=seq_lens,
        prep_initial_states=prep_initial_states,
        chunk_size=self.chunk_size,
        has_initial_states_p=has_initial_states_p,
        seq_idx_p=seq_idx_p,
        state_indices_tensor=state_indices_tensor,
        cu_chunk_seqlen_p=cu_chunk_seqlen_p,
        last_chunk_indices_p=last_chunk_indices_p,
        nums_dict=nums_dict,
        batch_ptr=batch_ptr,
        token_chunk_offset_ptr=token_chunk_offset_ptr,
        block_idx_last_scheduled_token=block_idx_last_scheduled_token,
        block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
        block_idx_last_computed_token=block_idx_last_computed_token,
        num_computed_tokens_p=num_computed_tokens_p,
    )
    return attn_metadata

compute_varlen_chunk_metadata ¶

compute_varlen_chunk_metadata(
    query_start_loc: Tensor, chunk_size: int
) -> tuple[Tensor, Tensor, Tensor]

Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.

Given per-sequence cumulative token starts query_start_loc of shape [B+1] and a physical chunk_size, returns three tensors on the same device:

- cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of logical-chunk lengths (each logical chunk never crosses a sequence or physical-chunk boundary).
- last_chunk_indices: (B,) int32 index of the last logical chunk for each sequence (=-1 for empty sequences).
- seq_idx_chunks: (nchunks,) int32 sequence index for each logical chunk in order.

This is intentionally lightweight and CPU-side; it mirrors the metadata produced by the V1 Mamba2 meta-data builder and is exported so tests (and other callers) can avoid duplicating the logic.

Source code in vllm/v1/attention/backends/mamba2_attn.py
def compute_varlen_chunk_metadata(
    query_start_loc: torch.Tensor,
    chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.

    Given per-sequence cumulative token starts `query_start_loc` of shape [B+1]
    and a physical `chunk_size`, returns three tensors on the same device:
      - cu_chunk_seqlens:  (nchunks+1,) int32   exclusive prefix-sum of
        logical-chunk lengths (each logical chunk never crosses a sequence or
        physical-chunk boundary).
      - last_chunk_indices: (B,)       int32   index of the last logical chunk
        for each sequence (=-1 for empty sequences).
      - seq_idx_chunks:     (nchunks,) int32   sequence index for each logical
        chunk in order.

    Physical chunk boundaries are the multiples of `chunk_size` in the packed
    token dimension, i.e. they are measured from position 0 of
    `query_start_loc`, not from the start of each sequence.

    This is intentionally lightweight and CPU-side; it mirrors the metadata
    produced by the V1 Mamba2 meta-data builder and is exported so tests
    (and other callers) can avoid duplicating the logic.

    Raises:
        AssertionError: If `query_start_loc` is not 1-D, does not start at 0,
            or the chunk lengths do not add up to its last entry.
        ValueError: If `chunk_size` is not positive.
    """
    assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
    assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
    # A negative chunk_size would make `take` negative and the loop below
    # would never terminate; zero would divide by zero. Reject both up front.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    device = query_start_loc.device

    # One host transfer; everything below works on plain Python ints.
    bounds: list[int] = query_start_loc.to(torch.int64).tolist()
    num_seqs = len(bounds) - 1
    total = bounds[-1]

    chunk_lens: list[int] = []
    seq_idx_chunks: list[int] = []
    last_chunk_indices: list[int] = [-1] * num_seqs

    for b in range(num_seqs):
        pos, end = bounds[b], bounds[b + 1]
        # Empty sequences (end <= pos) skip the loop and keep index -1.
        while pos < end:
            # split at both sequence boundaries and physical chunk boundaries
            room = chunk_size - (pos % chunk_size)
            take = min(room, end - pos)
            chunk_lens.append(take)
            seq_idx_chunks.append(b)
            last_chunk_indices[b] = len(chunk_lens) - 1
            pos += take

    # Exclusive prefix sum over logical-chunk lengths; [0] if there are none.
    cu_list = [0, *itertools.accumulate(chunk_lens)]
    if chunk_lens:
        # Final boundary must equal total tokens
        assert cu_list[-1] == total

    cu_chunk_seqlens = torch.tensor(cu_list, device=device, dtype=torch.int32)
    # torch.tensor([]) with an explicit dtype already has shape (0,), so the
    # B == 0 case needs no special handling.
    last_chunk_indices_t = torch.tensor(
        last_chunk_indices, device=device, dtype=torch.int32
    )
    seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32)
    return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t