Skip to content

vllm.v1.attention.backends.gdn_attn ¶

Backend for GatedDeltaNet attention.

GDNAttentionBackend ¶

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/gdn_attn.py
class GDNAttentionBackend(AttentionBackend):
    """Attention backend for GatedDeltaNet layers.

    Exposes the GDN metadata builder through ``get_builder_cls``.
    """

    @staticmethod
    def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
        """Return the class that builds ``GDNAttentionMetadata``."""
        builder_cls = GDNAttentionMetadataBuilder
        return builder_cls

get_builder_cls staticmethod ¶

get_builder_cls() -> type[GDNAttentionMetadataBuilder]
Source code in vllm/v1/attention/backends/gdn_attn.py
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
    """Return the metadata builder class used by the GDN backend."""
    builder_cls = GDNAttentionMetadataBuilder
    return builder_cls

GDNAttentionMetadata dataclass ¶

Source code in vllm/v1/attention/backends/gdn_attn.py
@dataclass
class GDNAttentionMetadata:
    """Per-batch metadata consumed by GatedDeltaNet (GDN) attention layers.

    Requests in a batch fall into three groups: prefills, non-spec decodes and
    speculative (spec) decodes. ``spec_*`` fields describe the spec-decode
    requests; ``non_spec_*`` fields describe the remaining prefill and decode
    requests. Shapes below are the unpadded ones; for decode-only batches under
    full CUDA graphs the builder pads the leading dimension.
    """

    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int
    num_actual_tokens: int

    # Only set when the batch contains prefills: True where a non-spec request
    # already has computed tokens (context length > 0).
    # shape: [batch - num_spec_decodes,]
    has_initial_state: Optional[torch.Tensor] = None

    spec_query_start_loc: Optional[torch.Tensor] = (
        None  # shape: [num_spec_decodes + 1,]
    )
    non_spec_query_start_loc: Optional[torch.Tensor] = (
        None  # shape: [batch - num_spec_decodes + 1,]
    )

    spec_state_indices_tensor: Optional[torch.Tensor] = (
        None  # shape: [num_spec_decodes, num_spec + 1]
    )
    non_spec_state_indices_tensor: Optional[torch.Tensor] = (
        None  # shape: [batch - num_spec_decodes,]
    )
    spec_sequence_masks: Optional[torch.Tensor] = None  # shape: [batch,]
    # True for tokens that belong to spec-decode requests.
    # shape: [num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens,]
    spec_token_masks: Optional[torch.Tensor] = None
    num_accepted_tokens: Optional[torch.Tensor] = None  # shape: [num_spec_decodes,]

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None

batch_ptr class-attribute instance-attribute ¶

batch_ptr: Optional[Tensor] = None

has_initial_state class-attribute instance-attribute ¶

has_initial_state: Optional[Tensor] = None

non_spec_query_start_loc class-attribute instance-attribute ¶

non_spec_query_start_loc: Optional[Tensor] = None

non_spec_state_indices_tensor class-attribute instance-attribute ¶

non_spec_state_indices_tensor: Optional[Tensor] = None

num_accepted_tokens class-attribute instance-attribute ¶

num_accepted_tokens: Optional[Tensor] = None

num_actual_tokens instance-attribute ¶

num_actual_tokens: int

num_decode_tokens instance-attribute ¶

num_decode_tokens: int

num_decodes instance-attribute ¶

num_decodes: int

num_prefill_tokens instance-attribute ¶

num_prefill_tokens: int

num_prefills instance-attribute ¶

num_prefills: int

num_spec_decode_tokens instance-attribute ¶

num_spec_decode_tokens: int

num_spec_decodes instance-attribute ¶

num_spec_decodes: int

nums_dict class-attribute instance-attribute ¶

nums_dict: Optional[dict] = None

spec_query_start_loc class-attribute instance-attribute ¶

spec_query_start_loc: Optional[Tensor] = None

spec_sequence_masks class-attribute instance-attribute ¶

spec_sequence_masks: Optional[Tensor] = None

spec_state_indices_tensor class-attribute instance-attribute ¶

spec_state_indices_tensor: Optional[Tensor] = None

spec_token_masks class-attribute instance-attribute ¶

spec_token_masks: Optional[Tensor] = None

token_chunk_offset_ptr class-attribute instance-attribute ¶

token_chunk_offset_ptr: Optional[Tensor] = None

__init__ ¶

__init__(
    num_prefills: int,
    num_prefill_tokens: int,
    num_decodes: int,
    num_decode_tokens: int,
    num_spec_decodes: int,
    num_spec_decode_tokens: int,
    num_actual_tokens: int,
    has_initial_state: Optional[Tensor] = None,
    spec_query_start_loc: Optional[Tensor] = None,
    non_spec_query_start_loc: Optional[Tensor] = None,
    spec_state_indices_tensor: Optional[Tensor] = None,
    non_spec_state_indices_tensor: Optional[Tensor] = None,
    spec_sequence_masks: Optional[Tensor] = None,
    spec_token_masks: Optional[Tensor] = None,
    num_accepted_tokens: Optional[Tensor] = None,
    nums_dict: Optional[dict] = None,
    batch_ptr: Optional[Tensor] = None,
    token_chunk_offset_ptr: Optional[Tensor] = None,
) -> None

GDNAttentionMetadataBuilder ¶

Bases: AttentionMetadataBuilder[GDNAttentionMetadata]

Source code in vllm/v1/attention/backends/gdn_attn.py
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
    """Builds :class:`GDNAttentionMetadata` for GatedDeltaNet layers.

    The builder splits each batch into prefills, (non-spec) decodes and
    speculative decodes. It owns persistent device buffers. When full CUDA
    graphs are enabled, decode-only batches are copied into these buffers and
    padded there.
    """

    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Read configuration and allocate the persistent CUDA-graph buffers.

        Args:
            kv_cache_spec: Cache spec for the GDN layers; must be a
                ``MambaSpec`` (asserted).
            layer_names: Layers served by this builder (not used here).
            vllm_config: Engine configuration. The compilation, speculative
                and scheduler settings are read from it.
            device: Device on which the persistent buffers are allocated.
        """
        assert isinstance(kv_cache_spec, MambaSpec)
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.speculative_config = vllm_config.speculative_config
        self.kv_cache_spec = kv_cache_spec
        # Number of draft tokens per request; 0 disables speculative decoding.
        if self.speculative_config:
            self.num_spec = self.speculative_config.num_speculative_tokens  # noqa: E501
        else:
            self.num_spec = 0
        self.use_spec_decode = self.num_spec > 0
        self._init_reorder_batch_threshold(1, self.use_spec_decode)

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        # Each sequence may contribute num_spec + 1 tokens. The result is
        # capped by the largest captured CUDA-graph size.
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
            self.compilation_config.max_capture_size,
        )

        # Persistent buffers. For decode-only batches with full CUDA graphs,
        # build() copies metadata into these and returns padded slices.
        self.spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, self.num_spec + 1),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )
        self.spec_sequence_masks = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.bool,
            device=device,
        )
        self.spec_token_masks = torch.empty(
            (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
            dtype=torch.bool,
            device=device,
        )
        self.spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.num_accepted_tokens = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )

    def build(  # type: ignore[override]
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        num_accepted_tokens: Optional[torch.Tensor] = None,
        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
        fast_build: bool = False,
    ) -> GDNAttentionMetadata:
        """Build ``GDNAttentionMetadata`` for the current batch.

        Requests whose entry in ``num_decode_draft_tokens_cpu`` is >= 0 are
        treated as speculative decodes. This applies only when speculative
        decoding is enabled and the draft-token total is non-zero. The
        remaining requests are split into decodes and prefills.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Per-batch metadata shared by backends.
            num_accepted_tokens: Per-request accepted-token counts. Must be
                provided whenever spec-decode requests are present.
            num_decode_draft_tokens_cpu: Per-request draft-token counts on
                CPU. Negative entries mark non-spec requests.
            fast_build: Not used by this builder.

        Returns:
            The populated metadata. For decode-only batches with full CUDA
            graphs enabled, tensor fields are slices of the persistent
            buffers, padded to the CUDA-graph size.
        """
        m = common_attn_metadata

        query_start_loc = m.query_start_loc
        context_lens = m.num_computed_tokens_cpu
        context_lens_tensor = context_lens.to(query_start_loc.device)
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        # The spec path is taken only when spec decoding is enabled and the
        # non-negative draft-token counts sum to more than zero.
        if (
            not self.use_spec_decode
            or num_decode_draft_tokens_cpu is None
            or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
            .sum()
            .item()
            == 0
        ):
            spec_sequence_masks = None
            num_spec_decodes = 0
        else:
            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
            num_spec_decodes = spec_sequence_masks.sum().item()
            # Defensive only: the positive sum checked above already implies
            # that at least one entry is >= 0.
            if num_spec_decodes == 0:
                spec_sequence_masks = None
            else:
                spec_sequence_masks = spec_sequence_masks.to(
                    query_start_loc.device, non_blocking=True
                )

        if spec_sequence_masks is None:
            # No spec decodes: the common tensors are used as-is.
            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                split_decodes_and_prefills(m, decode_threshold=1)
            )
            num_spec_decode_tokens = 0
            spec_token_masks = None
            spec_state_indices_tensor = None
            non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
            spec_query_start_loc = None
            non_spec_query_start_loc = query_start_loc
            num_accepted_tokens = None
        else:
            query_lens = query_start_loc[1:] - query_start_loc[:-1]

            # Among non-spec requests, a single-token query is a decode and
            # anything else is a prefill.
            non_spec_query_lens = query_lens[~spec_sequence_masks]
            num_decodes = (non_spec_query_lens == 1).sum().item()
            num_prefills = non_spec_query_lens.size(0) - num_decodes
            num_decode_tokens = num_decodes
            num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens

            if num_prefills == 0 and num_decodes == 0:
                # Every request is a spec decode; reuse the common tensors.
                spec_token_masks = torch.ones(
                    (
                        min(
                            num_spec_decodes * (self.num_spec + 1),
                            query_start_loc[-1].item(),
                        )
                    ),
                    dtype=torch.bool,
                    device=query_start_loc.device,
                )
                spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
                non_spec_state_indices_tensor = None
                spec_query_start_loc = query_start_loc
                non_spec_query_start_loc = None
            else:
                # Mixed batch: gather the spec and non-spec rows, then rebuild
                # the cumulative query offsets for each group.
                spec_token_masks = torch.repeat_interleave(
                    spec_sequence_masks, query_lens
                )
                spec_state_indices_tensor = m.block_table_tensor[
                    spec_sequence_masks, : self.num_spec + 1
                ]
                non_spec_state_indices_tensor = m.block_table_tensor[
                    ~spec_sequence_masks, 0
                ]

                spec_query_start_loc = torch.zeros(
                    num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device,
                )
                torch.cumsum(
                    query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
                )
                non_spec_query_start_loc = torch.zeros(
                    query_lens.size(0) - num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device,
                )
                torch.cumsum(
                    query_lens[~spec_sequence_masks],
                    dim=0,
                    out=non_spec_query_start_loc[1:],
                )

            # Tokens not taken by prefills or decodes belong to spec decodes.
            num_spec_decode_tokens = (
                query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
            )
            assert num_accepted_tokens is not None
            num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

        if num_prefills > 0:
            # A request has initial state when it already has computed tokens.
            has_initial_state = context_lens_tensor > 0
            if spec_sequence_masks is not None:
                has_initial_state = has_initial_state[~spec_sequence_masks]
            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(non_spec_query_start_loc)
            )
        else:
            has_initial_state = None
        num_actual_tokens = (
            num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
        )

        # Prepare tensors for full CUDA graphs (decode-only batches).
        #
        # With speculative decoding, the xgrammar backend may roll back
        # tokens. Some sequences can then have fewer draft tokens than
        # self.num_spec.
        #
        # Hence the maximum possible batch size for n tokens is
        # min(n, decode_cudagraph_max_bs).
        #
        # In each branch below the real values are copied into the persistent
        # buffers. The padding slots are then filled with neutral values.
        if (
            self.use_full_cuda_graph
            and num_prefills == 0
            and num_decodes == 0
            and num_spec_decodes <= self.decode_cudagraph_max_bs
            and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
        ):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

            self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                spec_state_indices_tensor, non_blocking=True
            )
            spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
            spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

            self.spec_sequence_masks[:num_spec_decodes].copy_(
                spec_sequence_masks, non_blocking=True
            )
            spec_sequence_masks = self.spec_sequence_masks[:batch_size]
            spec_sequence_masks[num_spec_decodes:].fill_(False)

            assert spec_token_masks is not None
            # Record the unpadded length *before* re-slicing. Filling
            # ``view[view.size(0):]`` on the padded view would select nothing,
            # leaving stale buffer contents in the padding slots.
            num_unpadded_tokens = spec_token_masks.size(0)
            self.spec_token_masks[:num_unpadded_tokens].copy_(
                spec_token_masks, non_blocking=True
            )
            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
            spec_token_masks[num_unpadded_tokens:].fill_(False)

            self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
                spec_query_start_loc, non_blocking=True
            )
            spec_num_query_tokens = spec_query_start_loc[-1]  # type: ignore[index]
            spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
            # Padded sequences get zero-length queries.
            spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)

            self.num_accepted_tokens[:num_spec_decodes].copy_(
                num_accepted_tokens, non_blocking=True
            )
            num_accepted_tokens = self.num_accepted_tokens[:batch_size]
            num_accepted_tokens[num_spec_decodes:].fill_(1)

        if (
            self.use_full_cuda_graph
            and num_prefills == 0
            and num_spec_decodes == 0
            and num_decodes <= self.decode_cudagraph_max_bs
        ):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
            batch_size = num_actual_tokens

            self.non_spec_state_indices_tensor[:num_decodes].copy_(
                non_spec_state_indices_tensor, non_blocking=True
            )
            non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
                :batch_size
            ]
            non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

            self.non_spec_query_start_loc[: num_decodes + 1].copy_(
                non_spec_query_start_loc, non_blocking=True
            )
            non_spec_num_query_tokens = non_spec_query_start_loc[-1]  # type: ignore[index]
            non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
            # Padded sequences get zero-length queries.
            non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)

        attn_metadata = GDNAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_spec_decodes=num_spec_decodes,
            num_spec_decode_tokens=num_spec_decode_tokens,
            num_actual_tokens=num_actual_tokens,
            has_initial_state=has_initial_state,
            spec_query_start_loc=spec_query_start_loc,
            non_spec_query_start_loc=non_spec_query_start_loc,
            spec_state_indices_tensor=spec_state_indices_tensor,
            non_spec_state_indices_tensor=non_spec_state_indices_tensor,
            spec_sequence_masks=spec_sequence_masks,
            spec_token_masks=spec_token_masks,
            num_accepted_tokens=num_accepted_tokens,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )
        return attn_metadata

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ):
        """Build metadata for full CUDA graph capture.

        Currently, only decode is supported for full cudagraphs with Mamba.
        Each request's query length is passed to ``build`` as its accepted
        token count, and ``query_len - 1`` as its draft-token count.
        ``common_attn_metadata.num_computed_tokens_cpu`` is overwritten in
        place.
        """
        m = common_attn_metadata

        assert (
            m.num_reqs <= self.decode_cudagraph_max_bs
            and m.num_actual_tokens <= self.decode_cudagraph_max_bs
        ), (
            f"GDN only supports decode-only full CUDAGraph capture. "
            f"Make sure batch size ({m.num_reqs}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
            f"and number of tokens ({m.num_actual_tokens}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
        )

        num_accepted_tokens = torch.diff(m.query_start_loc)
        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
        # Context length = sequence length minus the tokens in this step.
        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

        return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)

compilation_config instance-attribute ¶

compilation_config = compilation_config

cudagraph_support class-attribute instance-attribute ¶

cudagraph_support = UNIFORM_BATCH

decode_cudagraph_max_bs instance-attribute ¶

decode_cudagraph_max_bs = min(
    max_num_seqs * (num_spec + 1), max_capture_size
)

kv_cache_spec instance-attribute ¶

kv_cache_spec = kv_cache_spec

non_spec_query_start_loc instance-attribute ¶

non_spec_query_start_loc = empty(
    (decode_cudagraph_max_bs + 1,),
    dtype=int32,
    device=device,
)

non_spec_state_indices_tensor instance-attribute ¶

non_spec_state_indices_tensor = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

num_accepted_tokens instance-attribute ¶

num_accepted_tokens = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

num_spec instance-attribute ¶

num_spec = num_speculative_tokens

reorder_batch_threshold class-attribute instance-attribute ¶

reorder_batch_threshold: int = 1

spec_query_start_loc instance-attribute ¶

spec_query_start_loc = empty(
    (decode_cudagraph_max_bs + 1,),
    dtype=int32,
    device=device,
)

spec_sequence_masks instance-attribute ¶

spec_sequence_masks = empty(
    (decode_cudagraph_max_bs,), dtype=bool, device=device
)

spec_state_indices_tensor instance-attribute ¶

spec_state_indices_tensor = empty(
    (decode_cudagraph_max_bs, num_spec + 1),
    dtype=int32,
    device=device,
)

spec_token_masks instance-attribute ¶

spec_token_masks = empty(
    (decode_cudagraph_max_bs * (num_spec + 1),),
    dtype=bool,
    device=device,
)

speculative_config instance-attribute ¶

speculative_config = speculative_config

use_full_cuda_graph instance-attribute ¶

use_full_cuda_graph = has_full_cudagraphs()

use_spec_decode instance-attribute ¶

use_spec_decode = num_spec > 0

vllm_config instance-attribute ¶

vllm_config = vllm_config

__init__ ¶

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/gdn_attn.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Capture config handles and allocate the persistent CUDA-graph buffers.

    Args:
        kv_cache_spec: Cache spec for the GDN layers; must be a ``MambaSpec``
            (asserted).
        layer_names: Layers served by this builder (not used here).
        vllm_config: Engine configuration. The compilation, speculative and
            scheduler settings are read from it.
        device: Device on which the persistent buffers are allocated.
    """
    assert isinstance(kv_cache_spec, MambaSpec)
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config
    self.speculative_config = vllm_config.speculative_config
    self.kv_cache_spec = kv_cache_spec

    # Draft tokens per request; a missing speculative config means none.
    spec_cfg = self.speculative_config
    self.num_spec = spec_cfg.num_speculative_tokens if spec_cfg else 0
    self.use_spec_decode = self.num_spec > 0
    self._init_reorder_batch_threshold(1, self.use_spec_decode)

    compilation_cfg = self.compilation_config
    self.use_full_cuda_graph = compilation_cfg.cudagraph_mode.has_full_cudagraphs()

    # Each sequence may carry num_spec + 1 tokens; the result is capped by
    # the largest captured CUDA-graph size.
    tokens_per_req = self.num_spec + 1
    max_bs = min(
        self.vllm_config.scheduler_config.max_num_seqs * tokens_per_req,
        compilation_cfg.max_capture_size,
    )
    self.decode_cudagraph_max_bs = max_bs

    def _buffer(shape, dtype):
        # Uninitialised device storage, reused by build() for padded
        # decode-only metadata under full CUDA graphs.
        return torch.empty(shape, dtype=dtype, device=device)

    self.spec_state_indices_tensor = _buffer((max_bs, tokens_per_req), torch.int32)
    self.non_spec_state_indices_tensor = _buffer((max_bs,), torch.int32)
    self.spec_sequence_masks = _buffer((max_bs,), torch.bool)
    self.spec_token_masks = _buffer((max_bs * tokens_per_req,), torch.bool)
    self.spec_query_start_loc = _buffer((max_bs + 1,), torch.int32)
    self.non_spec_query_start_loc = _buffer((max_bs + 1,), torch.int32)
    self.num_accepted_tokens = _buffer((max_bs,), torch.int32)

build ¶

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    num_accepted_tokens: Optional[Tensor] = None,
    num_decode_draft_tokens_cpu: Optional[Tensor] = None,
    fast_build: bool = False,
) -> GDNAttentionMetadata
Source code in vllm/v1/attention/backends/gdn_attn.py
def build(  # type: ignore[override]
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
    fast_build: bool = False,
) -> GDNAttentionMetadata:
    """Build ``GDNAttentionMetadata`` for the current batch.

    Requests whose entry in ``num_decode_draft_tokens_cpu`` is >= 0 are treated
    as speculative decodes. This applies only when speculative decoding is
    enabled and the draft-token total is non-zero. The remaining requests are
    split into decodes and prefills.

    Args:
        common_prefix_len: Not used by this builder.
        common_attn_metadata: Per-batch metadata shared by backends.
        num_accepted_tokens: Per-request accepted-token counts. Must be
            provided whenever spec-decode requests are present.
        num_decode_draft_tokens_cpu: Per-request draft-token counts on CPU.
            Negative entries mark non-spec requests.
        fast_build: Not used by this builder.

    Returns:
        The populated metadata. For decode-only batches with full CUDA graphs
        enabled, tensor fields are slices of the persistent buffers, padded
        to the CUDA-graph size.
    """
    m = common_attn_metadata

    query_start_loc = m.query_start_loc
    context_lens = m.num_computed_tokens_cpu
    context_lens_tensor = context_lens.to(query_start_loc.device)
    nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

    # The spec path is taken only when spec decoding is enabled and the
    # non-negative draft-token counts sum to more than zero.
    if (
        not self.use_spec_decode
        or num_decode_draft_tokens_cpu is None
        or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
        .sum()
        .item()
        == 0
    ):
        spec_sequence_masks = None
        num_spec_decodes = 0
    else:
        spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
        num_spec_decodes = spec_sequence_masks.sum().item()
        # Defensive only: the positive sum checked above already implies that
        # at least one entry is >= 0.
        if num_spec_decodes == 0:
            spec_sequence_masks = None
        else:
            spec_sequence_masks = spec_sequence_masks.to(
                query_start_loc.device, non_blocking=True
            )

    if spec_sequence_masks is None:
        # No spec decodes: the common tensors are used as-is.
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(m, decode_threshold=1)
        )
        num_spec_decode_tokens = 0
        spec_token_masks = None
        spec_state_indices_tensor = None
        non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
        spec_query_start_loc = None
        non_spec_query_start_loc = query_start_loc
        num_accepted_tokens = None
    else:
        query_lens = query_start_loc[1:] - query_start_loc[:-1]

        # Among non-spec requests, a single-token query is a decode and
        # anything else is a prefill.
        non_spec_query_lens = query_lens[~spec_sequence_masks]
        num_decodes = (non_spec_query_lens == 1).sum().item()
        num_prefills = non_spec_query_lens.size(0) - num_decodes
        num_decode_tokens = num_decodes
        num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens

        if num_prefills == 0 and num_decodes == 0:
            # Every request is a spec decode; reuse the common tensors.
            spec_token_masks = torch.ones(
                (
                    min(
                        num_spec_decodes * (self.num_spec + 1),
                        query_start_loc[-1].item(),
                    )
                ),
                dtype=torch.bool,
                device=query_start_loc.device,
            )
            spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
            non_spec_state_indices_tensor = None
            spec_query_start_loc = query_start_loc
            non_spec_query_start_loc = None
        else:
            # Mixed batch: gather the spec and non-spec rows, then rebuild the
            # cumulative query offsets for each group.
            spec_token_masks = torch.repeat_interleave(
                spec_sequence_masks, query_lens
            )
            spec_state_indices_tensor = m.block_table_tensor[
                spec_sequence_masks, : self.num_spec + 1
            ]
            non_spec_state_indices_tensor = m.block_table_tensor[
                ~spec_sequence_masks, 0
            ]

            spec_query_start_loc = torch.zeros(
                num_spec_decodes + 1,
                dtype=torch.int32,
                device=query_start_loc.device,
            )
            torch.cumsum(
                query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
            )
            non_spec_query_start_loc = torch.zeros(
                query_lens.size(0) - num_spec_decodes + 1,
                dtype=torch.int32,
                device=query_start_loc.device,
            )
            torch.cumsum(
                query_lens[~spec_sequence_masks],
                dim=0,
                out=non_spec_query_start_loc[1:],
            )

        # Tokens not taken by prefills or decodes belong to spec decodes.
        num_spec_decode_tokens = (
            query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
        )
        assert num_accepted_tokens is not None
        num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

    if num_prefills > 0:
        # A request has initial state when it already has computed tokens.
        has_initial_state = context_lens_tensor > 0
        if spec_sequence_masks is not None:
            has_initial_state = has_initial_state[~spec_sequence_masks]
        nums_dict, batch_ptr, token_chunk_offset_ptr = (
            compute_causal_conv1d_metadata(non_spec_query_start_loc)
        )
    else:
        has_initial_state = None
    num_actual_tokens = (
        num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
    )

    # Prepare tensors for full CUDA graphs (decode-only batches).
    #
    # With speculative decoding, the xgrammar backend may roll back tokens.
    # Some sequences can then have fewer draft tokens than self.num_spec.
    #
    # Hence the maximum possible batch size for n tokens is
    # min(n, decode_cudagraph_max_bs).
    #
    # In each branch below the real values are copied into the persistent
    # buffers. The padding slots are then filled with neutral values.
    if (
        self.use_full_cuda_graph
        and num_prefills == 0
        and num_decodes == 0
        and num_spec_decodes <= self.decode_cudagraph_max_bs
        and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
    ):
        num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
        batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

        self.spec_state_indices_tensor[:num_spec_decodes].copy_(
            spec_state_indices_tensor, non_blocking=True
        )
        spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
        spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

        self.spec_sequence_masks[:num_spec_decodes].copy_(
            spec_sequence_masks, non_blocking=True
        )
        spec_sequence_masks = self.spec_sequence_masks[:batch_size]
        spec_sequence_masks[num_spec_decodes:].fill_(False)

        assert spec_token_masks is not None
        # Record the unpadded length *before* re-slicing. Filling
        # ``view[view.size(0):]`` on the padded view would select nothing,
        # leaving stale buffer contents in the padding slots.
        num_unpadded_tokens = spec_token_masks.size(0)
        self.spec_token_masks[:num_unpadded_tokens].copy_(
            spec_token_masks, non_blocking=True
        )
        spec_token_masks = self.spec_token_masks[:num_actual_tokens]
        spec_token_masks[num_unpadded_tokens:].fill_(False)

        self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
            spec_query_start_loc, non_blocking=True
        )
        spec_num_query_tokens = spec_query_start_loc[-1]  # type: ignore[index]
        spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
        # Padded sequences get zero-length queries.
        spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)

        self.num_accepted_tokens[:num_spec_decodes].copy_(
            num_accepted_tokens, non_blocking=True
        )
        num_accepted_tokens = self.num_accepted_tokens[:batch_size]
        num_accepted_tokens[num_spec_decodes:].fill_(1)

    if (
        self.use_full_cuda_graph
        and num_prefills == 0
        and num_spec_decodes == 0
        and num_decodes <= self.decode_cudagraph_max_bs
    ):
        num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
        batch_size = num_actual_tokens

        self.non_spec_state_indices_tensor[:num_decodes].copy_(
            non_spec_state_indices_tensor, non_blocking=True
        )
        non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
            :batch_size
        ]
        non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

        self.non_spec_query_start_loc[: num_decodes + 1].copy_(
            non_spec_query_start_loc, non_blocking=True
        )
        non_spec_num_query_tokens = non_spec_query_start_loc[-1]  # type: ignore[index]
        non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
        # Padded sequences get zero-length queries.
        non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)

    attn_metadata = GDNAttentionMetadata(
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_spec_decodes=num_spec_decodes,
        num_spec_decode_tokens=num_spec_decode_tokens,
        num_actual_tokens=num_actual_tokens,
        has_initial_state=has_initial_state,
        spec_query_start_loc=spec_query_start_loc,
        non_spec_query_start_loc=non_spec_query_start_loc,
        spec_state_indices_tensor=spec_state_indices_tensor,
        non_spec_state_indices_tensor=non_spec_state_indices_tensor,
        spec_sequence_masks=spec_sequence_masks,
        spec_token_masks=spec_token_masks,
        num_accepted_tokens=num_accepted_tokens,
        nums_dict=nums_dict,
        batch_ptr=batch_ptr,
        token_chunk_offset_ptr=token_chunk_offset_ptr,
    )
    return attn_metadata

build_for_cudagraph_capture ¶

build_for_cudagraph_capture(
    common_attn_metadata: CommonAttentionMetadata,
)

This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba.

Source code in vllm/v1/attention/backends/gdn_attn.py
def build_for_cudagraph_capture(
    self, common_attn_metadata: CommonAttentionMetadata
):
    """Build metadata for full CUDA graph capture of a decode-only batch.

    Only decode batches are supported for full CUDA graphs with Mamba. Each
    request's query length is forwarded to ``build`` as its accepted-token
    count, and ``query_len - 1`` as its draft-token count.
    ``common_attn_metadata.num_computed_tokens_cpu`` is overwritten in place.
    """
    meta = common_attn_metadata
    max_bs = self.decode_cudagraph_max_bs

    assert meta.num_reqs <= max_bs and meta.num_actual_tokens <= max_bs, (
        f"GDN only supports decode-only full CUDAGraph capture. "
        f"Make sure batch size ({meta.num_reqs}) <= "
        f"cudagraph capture sizes ({max_bs}), "
        f"and number of tokens ({meta.num_actual_tokens}) <= "
        f"cudagraph capture sizes ({max_bs})."
    )

    query_lens = torch.diff(meta.query_start_loc)
    draft_counts_cpu = (query_lens - 1).cpu()
    # Context length = sequence length minus the tokens in this step.
    meta.num_computed_tokens_cpu = meta.seq_lens_cpu - query_lens.cpu()

    return self.build(0, meta, query_lens, draft_counts_cpu)